【GiantPandaCV導(dǎo)語】與之前BoTNet不同,CvT雖然題目中有卷積的字樣,但是實(shí)際總體來說依然是以Transformer Block為主的,在Token的處理方面引入了卷積,從而為模型帶來的局部性。最終CvT最高拿下了87.7%的Top1準(zhǔn)確率。
引言
CvT架構(gòu)的Motivation也是將局部性引入Vision Transformer架構(gòu)中,期望通過引入局部性得到更高的性能和效率權(quán)衡。因此我們主要關(guān)注CvT是如何引入局部性的。具體來說提出了兩點(diǎn)改進(jìn):
- Convolutional token embedding
- Convolutional Projection
通過以上改進(jìn),模型不僅具有卷積的優(yōu)勢(局部感受野、權(quán)重共享、空間下采樣等特性帶來的優(yōu)勢),如平移不變形、尺度不變性、旋轉(zhuǎn)不變性等,也保持了Self Attention的優(yōu)勢,如動(dòng)態(tài)注意力、全局語義信息、更強(qiáng)的泛化能力等。
展開一點(diǎn)講,Convolutional Vision Transformer有兩點(diǎn)核心:
- 第一步,參考CNN的架構(gòu),將Transformer也設(shè)計(jì)為多階段的層次架構(gòu),每個(gè)stage之前使用convolutional token embedding,通過使用卷積+layer normalization能夠?qū)崿F(xiàn)降維的功能(注:逐漸降低序列長度的同時(shí),增加每個(gè)token的維度,可以類比卷積中feature map砍半,通道數(shù)增加的操作)
- 第二步,使用Convolutional Projection取代原來的Linear Projection,該模塊實(shí)際使用的是深度可分離卷積實(shí)現(xiàn),這樣也能有效捕獲局部語義信息。
需要注意的是:CvT去掉了Positional Embedding模塊,發(fā)現(xiàn)對(duì)模型性能沒有任何影響。認(rèn)為可以簡化架構(gòu)的設(shè)計(jì),并且可以在分辨率變化的情況下更容易適配。
比較
在相關(guān)工作部分,CvT總結(jié)了一個(gè)表格,比較方便對(duì)比:
方法
在引言部分已經(jīng)講得比較詳細(xì)了,下面對(duì)照架構(gòu)圖復(fù)盤一下(用盡可能通俗的語言描述):
- 綠色框是conv token embedding操作,通俗來講,使用了超大卷積核來提升局部性不足的問題。
- 右圖藍(lán)色框中展示的是改進(jìn)的self attention,通俗來講,使用了non local的操作,使用深度可分離卷積取代MLP做Projection,如下圖所示:
- 如圖(a)所示,Vision Transformer中使用的是MLP進(jìn)行Linear Projection, 這樣的信息是全局性的,但是計(jì)算量比較大。
- 如圖(b)所示,使用卷積進(jìn)行映射,這種操作類似Non Local Network,使用卷積進(jìn)行映射。
- 如圖(c)所示,使用的是帶stride的卷積進(jìn)行壓縮,這樣做是處于對(duì)效率的考量,token數(shù)量可以降低四倍,會(huì)帶來一定的性能損失。
Positional embedding探討:
由于Convolutional Projection在每個(gè)Transformer Block中都是用,配合Convolutional Token Embedding操作,能夠給模型足夠的能力來建模局部空間關(guān)系,因此可以去掉Transformer中的Positional Embedding操作。從下表發(fā)現(xiàn),pe對(duì)模型性能影響不大。
與其他工作的對(duì)比:
- 同期工作1:Tokens-to-Tokens ViT: 使用Progressive Tokenization整合臨近token,使用Transformer-based骨干網(wǎng)絡(luò)具有局部性的同時(shí),還能降低token序列長度。
- 區(qū)別:CvT使用的是multi-stage的過程,token長度降低的同時(shí),其維度在增加,從而保證模型的容量。同時(shí)計(jì)算量相比T2T有所改善。
- 同期工作2:Pyramid Vision Transformer(PVT): 引入了金字塔架構(gòu),使得PVT可以作為Backbone應(yīng)用于Dense prediction任務(wù)中。
- 區(qū)別:CvT也使用了金字塔架構(gòu),區(qū)別在于CvT中提出使用stride卷積來實(shí)現(xiàn)空間降采樣,進(jìn)一步融合了局部信息。
最終模型架構(gòu)如下:
實(shí)驗(yàn)
左圖中令人感興趣的是BiT,這篇是谷歌的文章big transfer,探究CNN架構(gòu)在大規(guī)模數(shù)據(jù)與訓(xùn)練的效果,可以看出即便是純CNN架構(gòu)模型參數(shù)量也可以非常巨大,而Vision Transformer還有CvT等在同等精度下模型參數(shù)量遠(yuǎn)小于BiT,這一定程度上說明了Transformer結(jié)合CNN在數(shù)據(jù)量足夠的情況下性能可以非常可觀,要比單純CNN架構(gòu)的模型性能更優(yōu)。
右圖展示了CvT和幾種vision transformer架構(gòu)的性能比較,可見CvT在權(quán)衡方面做的非常不錯(cuò)。
與SOTA比較:
有趣的是CvT-13-NAS也采用了搜索的方法DA-NAS,主要搜索對(duì)象是key和value的stride,以及MLP的Expansion Ratio, 最終搜索的結(jié)果要比Baseline略好。
在無需JFT數(shù)據(jù)集的情況下,CvT最高調(diào)整可以達(dá)到87.7%的top1 準(zhǔn)確率。
其他數(shù)據(jù)集結(jié)果:
消融實(shí)驗(yàn)
代碼
Convolutional Token Embedding代碼實(shí)現(xiàn):可以看出,實(shí)際上就是大卷積核+大Stride的滑動(dòng)引入的局部性。
class ConvEmbed(nn.Module):
""" Image to Conv Embedding
"""
def __init__(self,
patch_size=7,
in_chans=3,
embed_dim=64,
stride=4,
padding=2,
norm_layer=None):
super().__init__()
patch_size = to_2tuple(patch_size)
self.patch_size = patch_size
self.proj = nn.Conv2d(
in_chans, embed_dim,
kernel_size=patch_size,
stride=stride,
padding=padding
)
self.norm = norm_layer(embed_dim) if norm_layer else None
def forward(self, x):
x = self.proj(x)
B, C, H, W = x.shape
x = rearrange(x, 'b c h w -> b (h w) c')
if self.norm:
x = self.norm(x)
x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)
return x
Convolutional Projection代碼實(shí)現(xiàn),具體看_build_projection函數(shù):
class Attention(nn.Module):
def __init__(self,
dim_in,
dim_out,
num_heads,
qkv_bias=False,
attn_drop=0.,
proj_drop=0.,
method='dw_bn',
kernel_size=3,
stride_kv=1,
stride_q=1,
padding_kv=1,
padding_q=1,
with_cls_token=True,
**kwargs
):
super().__init__()
self.stride_kv = stride_kv
self.stride_q = stride_q
self.dim = dim_out
self.num_heads = num_heads
# head_dim = self.qkv_dim // num_heads
self.scale = dim_out ** -0.5
self.with_cls_token = with_cls_token
self.conv_proj_q = self._build_projection(
dim_in, dim_out, kernel_size, padding_q,
stride_q, 'linear' if method == 'avg' else method
)
self.conv_proj_k = self._build_projection(
dim_in, dim_out, kernel_size, padding_kv,
stride_kv, method
)
self.conv_proj_v = self._build_projection(
dim_in, dim_out, kernel_size, padding_kv,
stride_kv, method
)
self.proj_q = nn.Linear(dim_in, dim_out, bias=qkv_bias)
self.proj_k = nn.Linear(dim_in, dim_out, bias=qkv_bias)
self.proj_v = nn.Linear(dim_in, dim_out, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim_out, dim_out)
self.proj_drop = nn.Dropout(proj_drop)
def _build_projection(self,
dim_in,
dim_out,
kernel_size,
padding,
stride,
method):
if method == 'dw_bn':
proj = nn.Sequential(OrderedDict([
('conv', nn.Conv2d(
dim_in,
dim_in,
kernel_size=kernel_size,
padding=padding,
stride=stride,
bias=False,
groups=dim_in
)),
('bn', nn.BatchNorm2d(dim_in)),
('rearrage', Rearrange('b c h w -> b (h w) c')),
]))
elif method == 'avg':
proj = nn.Sequential(OrderedDict([
('avg', nn.AvgPool2d(
kernel_size=kernel_size,
padding=padding,
stride=stride,
ceil_mode=True
)),
('rearrage', Rearrange('b c h w -> b (h w) c')),
]))
elif method == 'linear':
proj = None
else:
raise ValueError('Unknown method ({})'.format(method))
return proj
def forward_conv(self, x, h, w):
if self.with_cls_token:
cls_token, x = torch.split(x, [1, h*w], 1)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
if self.conv_proj_q is not None:
q = self.conv_proj_q(x)
else:
q = rearrange(x, 'b c h w -> b (h w) c')
if self.conv_proj_k is not None:
k = self.conv_proj_k(x)
else:
k = rearrange(x, 'b c h w -> b (h w) c')
if self.conv_proj_v is not None:
v = self.conv_proj_v(x)
else:
v = rearrange(x, 'b c h w -> b (h w) c')
if self.with_cls_token:
q = torch.cat((cls_token, q), dim=1)
k = torch.cat((cls_token, k), dim=1)
v = torch.cat((cls_token, v), dim=1)
return q, k, v
def forward(self, x, h, w):
if (
self.conv_proj_q is not None
or self.conv_proj_k is not None
or self.conv_proj_v is not None
):
q, k, v = self.forward_conv(x, h, w)
q = rearrange(self.proj_q(q), 'b t (h d) -> b h t d', h=self.num_heads)
k = rearrange(self.proj_k(k), 'b t (h d) -> b h t d', h=self.num_heads)
v = rearrange(self.proj_v(v), 'b t (h d) -> b h t d', h=self.num_heads)
attn_score = torch.einsum('bhlk,bhtk->bhlt', [q, k]) * self.scale
attn = F.softmax(attn_score, dim=-1)
attn = self.attn_drop(attn)
x = torch.einsum('bhlt,bhtv->bhlv', [attn, v])
x = rearrange(x, 'b h t d -> b t (h d)')
x = self.proj(x)
x = self.proj_drop(x)
return x
參考
https://github.com/microsoft/CvT/blob/main/lib/models/cls_cvt.py
https://arxiv.org/pdf/2103.15808.pdf
https://zhuanlan.zhihu.com/p/142864566
筆者在cifar10數(shù)據(jù)集上修改了CvT中的Stride等參數(shù),在不用任何數(shù)據(jù)增強(qiáng)和Trick的情況下得到了下圖結(jié)果,Top1為84.74。雖然看上去性能比較差,但是這還沒有調(diào)參以及加上數(shù)據(jù)增強(qiáng)方法,只訓(xùn)練了200個(gè)epoch的結(jié)果。
python train.py --model 'cvt' --name "cvt" --sched 'cosine' --epochs 200 --lr 0.01
感興趣的可以點(diǎn)擊下面鏈接調(diào)參:
https://github.com/pprp/pytorch-cifar-model-zoo