arxiv.org/abs/2010.11…

AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE (ViT, Vision Transformer)

Network Architecture

ViT applies the Transformer to the image domain.

ViT uses only the encoder part of the Transformer from the Attention Is All You Need paper; there is no decoder. In other words, it contains only self-attention and no cross-attention.

[Figure: ViT (Vision Transformer) network architecture]

Note: in the code there is a Dropout layer before the Transformer encoder and a Layer Norm layer after it.
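As a rough sketch of what one encoder block looks like (pre-norm multi-head self-attention followed by an MLP, each with a residual connection), together with the Dropout before the encoder stack and the Layer Norm after it mentioned above. This uses Keras' built-in layers.MultiHeadAttention instead of the custom Attention layer shown later, and the sizes are example values rather than the paper's exact configuration:

import tensorflow as tf
from tensorflow.keras import layers


class EncoderBlock(layers.Layer):
    """Pre-norm Transformer encoder block: LN -> MHSA -> residual, LN -> MLP -> residual."""
    def __init__(self, dim=768, num_heads=12, mlp_ratio=4, drop=0.):
        super().__init__()
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.attn = layers.MultiHeadAttention(num_heads=num_heads, key_dim=dim // num_heads)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
        self.mlp = tf.keras.Sequential([
            layers.Dense(dim * mlp_ratio, activation="gelu"),
            layers.Dropout(drop),
            layers.Dense(dim),
            layers.Dropout(drop),
        ])

    def call(self, x, training=None):
        y = self.norm1(x)
        x = x + self.attn(y, y, training=training)        # self-attention: query = key = value
        x = x + self.mlp(self.norm2(x), training=training)
        return x


# Dropout before the encoder stack and a final Layer Norm after it, as noted above
pos_drop = layers.Dropout(0.1)
encoder_blocks = [EncoderBlock() for _ in range(12)]
final_norm = layers.LayerNormalization(epsilon=1e-6)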

When training on ImageNet-21k, the MLP Head consists of Linear + tanh activation + Linear; when transferring to ImageNet-1k or to your own dataset, it is only a single Linear layer, so it can simply be treated as a linear layer.
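A minimal sketch of the two head variants just described (the hidden size and class count are example values, not fixed by the paper):

from tensorflow.keras import layers, Sequential

num_classes = 1000   # example: ImageNet-1k

# Pre-training head (ImageNet-21k): Linear -> tanh -> Linear
mlp_head_pretrain = Sequential([
    layers.Dense(768, activation="tanh", name="pre_logits"),   # hidden size = embed dim here
    layers.Dense(num_classes, name="head"),
])

# Fine-tuning head (ImageNet-1k or a custom dataset): just a single Linear layer
mlp_head_finetune = layers.Dense(num_classes, name="head")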

Overall ViT Pipeline

The input image is split into patches; each patch passes through a linear (embedding) layer to produce a sequence of tokens, position embeddings are added, the sequence goes through a stack of Transformer encoder blocks, and an MLP head finally produces the prediction.
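A minimal sketch of the patch-splitting and embedding step: a Conv2D whose kernel size and stride both equal the patch size is equivalent to cutting the image into non-overlapping patches and projecting each one linearly, after which a class token is prepended. The shapes below assume a 224x224 input with 16x16 patches; the zero class token is only for illustration (in ViT it is a learnable parameter):

import tensorflow as tf
from tensorflow.keras import layers


class PatchEmbed(layers.Layer):
    """Split the image into patches and linearly embed each patch."""
    def __init__(self, patch_size=16, embed_dim=768):
        super().__init__()
        self.proj = layers.Conv2D(embed_dim, kernel_size=patch_size, strides=patch_size)

    def call(self, x):
        x = self.proj(x)                       # [B, 224, 224, 3] -> [B, 14, 14, 768]
        b, h, w, c = x.shape
        return tf.reshape(x, [b, h * w, c])    # -> [B, 196, 768] patch tokens


images = tf.random.normal([2, 224, 224, 3])
tokens = PatchEmbed()(images)                  # [2, 196, 768]

# prepend a class token (zeros here just for illustration)
cls_token = tf.zeros([2, 1, 768])
tokens = tf.concat([cls_token, tokens], axis=1)    # [2, 197, 768]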

The position encoding here differs from the original Transformer's: ViT uses a learnable position embedding rather than a fixed, precomputed encoding.
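A minimal sketch of such a learnable position embedding: it is simply a trainable weight with one entry per token that is added to the token sequence (197 tokens and an embedding dimension of 768 are example values):

import tensorflow as tf
from tensorflow.keras import layers


class LearnablePositionEmbedding(layers.Layer):
    def __init__(self, num_tokens=197, embed_dim=768):
        super().__init__()
        # trained together with the rest of the network, not computed from a fixed formula
        self.pos_embed = self.add_weight(
            name="pos_embed", shape=[1, num_tokens, embed_dim],
            initializer=tf.keras.initializers.Zeros(), trainable=True)

    def call(self, tokens):
        return tokens + self.pos_embed     # broadcast over the batch dimension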

The paper evaluates several kinds of position encoding experimentally.

[Figure: comparison of different position encodings]

A self-attention implementation for ViT is shown below (TensorFlow/Keras):

import tensorflow as tf
from tensorflow.keras import layers, initializers


class Attention(layers.Layer):
    k_ini = initializers.GlorotUniform()
    b_ini = initializers.Zeros()
    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_ratio=0.,
                 proj_drop_ratio=0.,
                 name=None):
        super(Attention, self).__init__(name=name)
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = layers.Dense(dim * 3, use_bias=qkv_bias, name="qkv",
                                kernel_initializer=self.k_ini, bias_initializer=self.b_ini)
        self.attn_drop = layers.Dropout(attn_drop_ratio)
        self.proj = layers.Dense(dim, name="out",
                                 kernel_initializer=self.k_ini, bias_initializer=self.b_ini)
        self.proj_drop = layers.Dropout(proj_drop_ratio)
    def call(self, inputs, training=None):
        # [batch_size, num_patches + 1, total_embed_dim]
        B, N, C = inputs.shape
        # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
        qkv = self.qkv(inputs)
        # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]
        qkv = tf.reshape(qkv, [B, N, 3, self.num_heads, C // self.num_heads])
        # transpose: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        qkv = tf.transpose(qkv, [2, 0, 3, 1, 4])
        # [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        q, k, v = qkv[0], qkv[1], qkv[2]
        # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
        # multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
        attn = tf.matmul(a=q, b=k, transpose_b=True) * self.scale
        attn = tf.nn.softmax(attn, axis=-1)
        attn = self.attn_drop(attn, training=training)
        # multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        x = tf.matmul(attn, v)
        # transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
        x = tf.transpose(x, [0, 2, 1, 3])
        # reshape: -> [batch_size, num_patches + 1, total_embed_dim]
        x = tf.reshape(x, [B, N, C])
        x = self.proj(x)
        x = self.proj_drop(x, training=training)
        return x
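
A quick usage check of the layer above, with example shapes matching the comments in the code (197 = 196 patch tokens + 1 class token):

x = tf.random.normal([2, 197, 768])        # [batch, num_patches + 1, embed_dim]
attn_layer = Attention(dim=768, num_heads=12)
out = attn_layer(x, training=False)
print(out.shape)                           # (2, 197, 768)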

It is fairly simple to implement. Of course, if you would rather not write it yourself, you can use OpenMMLab's implementation. I use the one from mmcv, found in the package mmseg.models.utils:

self.attn = SelfAttentionBlock(
    key_in_channels=dim,
    query_in_channels=dim,
    channels=dim,
    out_channels=dim,
    share_key_query=False,
    query_downsample=None,
    key_downsample=None,
    key_query_num_convs=1,
    value_out_num_convs=1,
    key_query_norm=True,
    value_out_norm=True,
    matmul_norm=True,
    with_out=True,
    conv_cfg=None,
    norm_cfg=norm_cfg,
    act_cfg=act_cfg)

It can serve both as self-attention, x1 = self.attn(x, x), and as cross-attention, x2 = self.cross_attn(x, cls).
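A minimal, self-contained sketch of wiring this up, assuming the mmseg.models.utils import path mentioned above and 4D feature maps of shape [B, C, H, W] as mmseg's decode heads typically use; the norm_cfg/act_cfg dictionaries and all tensor shapes here are placeholder example values:

import torch
from mmseg.models.utils import SelfAttentionBlock

dim = 768                       # embedding / channel dimension (example value)
norm_cfg = dict(type='BN')      # placeholder configs; normally taken from the model config
act_cfg = dict(type='ReLU')

attn = SelfAttentionBlock(
    key_in_channels=dim, query_in_channels=dim, channels=dim, out_channels=dim,
    share_key_query=False, query_downsample=None, key_downsample=None,
    key_query_num_convs=1, value_out_num_convs=1, key_query_norm=True,
    value_out_norm=True, matmul_norm=True, with_out=True,
    conv_cfg=None, norm_cfg=norm_cfg, act_cfg=act_cfg)

x = torch.randn(2, dim, 14, 14)     # feature map [B, C, H, W]
cls = torch.randn(2, dim, 1, 1)     # another feature map to attend to

out_self = attn(x, x)       # same tensor twice -> self-attention
out_cross = attn(x, cls)    # two different tensors -> cross-attention; output follows the query's shape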

If you are on a recent PyTorch release (e.g., PyTorch 1.13), you can also use the official implementation:

torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)

multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
attn_output, attn_output_weights = multihead_attn(query, key, value)
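
For ViT-style self-attention over a token sequence, a minimal sketch could look like the following (batch_first=True is available in recent PyTorch releases; 197 tokens and an embedding dimension of 768 are example, ViT-Base-like values):

import torch
import torch.nn as nn

embed_dim, num_heads = 768, 12
attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

tokens = torch.randn(2, 197, embed_dim)        # [batch, num_patches + 1 (cls token), embed_dim]
out, weights = attn(tokens, tokens, tokens)    # query = key = value -> self-attention
print(out.shape)        # torch.Size([2, 197, 768])
print(weights.shape)    # torch.Size([2, 197, 197]), attention weights averaged over heads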

Experimental Results

[Figures: experimental results from the ViT paper]