CvT: 如何將卷積的優(yōu)勢融入Transformer

【GiantPandaCV導(dǎo)語】與之前BoTNet不同,CvT雖然題目中有卷積的字樣,但是實(shí)際總體來說依然是以Transformer Block為主的,在Token的處理方面引入了卷積,從而為模型帶來的局部性。最終CvT最高拿下了87.7%的Top1準(zhǔn)確率。

引言

CvT架構(gòu)的Motivation也是將局部性引入Vision Transformer架構(gòu)中,期望通過引入局部性得到更高的性能和效率權(quán)衡。因此我們主要關(guān)注CvT是如何引入局部性的。具體來說提出了兩點(diǎn)改進(jìn):

  • Convolutional token embedding
  • Convolutional Projection

通過以上改進(jìn),模型不僅具有卷積的優(yōu)勢(局部感受野、權(quán)重共享、空間下采樣等特性帶來的優(yōu)勢),如平移不變形、尺度不變性、旋轉(zhuǎn)不變性等,也保持了Self Attention的優(yōu)勢,如動(dòng)態(tài)注意力、全局語義信息、更強(qiáng)的泛化能力等。

展開一點(diǎn)講,Convolutional Vision Transformer有兩點(diǎn)核心:

  • 第一步,參考CNN的架構(gòu),將Transformer也設(shè)計(jì)為多階段的層次架構(gòu),每個(gè)stage之前使用convolutional token embedding,通過使用卷積+layer normalization能夠?qū)崿F(xiàn)降維的功能(注:逐漸降低序列長度的同時(shí),增加每個(gè)token的維度,可以類比卷積中feature map砍半,通道數(shù)增加的操作)
  • 第二步,使用Convolutional Projection取代原來的Linear Projection,該模塊實(shí)際使用的是深度可分離卷積實(shí)現(xiàn),這樣也能有效捕獲局部語義信息。

需要注意的是:CvT去掉了Positional Embedding模塊,發(fā)現(xiàn)對(duì)模型性能沒有任何影響。認(rèn)為可以簡化架構(gòu)的設(shè)計(jì),并且可以在分辨率變化的情況下更容易適配。

比較

在相關(guān)工作部分,CvT總結(jié)了一個(gè)表格,比較方便對(duì)比:

image

方法

在引言部分已經(jīng)講得比較詳細(xì)了,下面對(duì)照架構(gòu)圖復(fù)盤一下(用盡可能通俗的語言描述):

image
  • 綠色框是conv token embedding操作,通俗來講,使用了超大卷積核來提升局部性不足的問題。
  • 右圖藍(lán)色框中展示的是改進(jìn)的self attention,通俗來講,使用了non local的操作,使用深度可分離卷積取代MLP做Projection,如下圖所示:
image
  • 如圖(a)所示,Vision Transformer中使用的是MLP進(jìn)行Linear Projection, 這樣的信息是全局性的,但是計(jì)算量比較大。
  • 如圖(b)所示,使用卷積進(jìn)行映射,這種操作類似Non Local Network,使用卷積進(jìn)行映射。
  • 如圖(c)所示,使用的是帶stride的卷積進(jìn)行壓縮,這樣做是處于對(duì)效率的考量,token數(shù)量可以降低四倍,會(huì)帶來一定的性能損失。

Positional embedding探討:

由于Convolutional Projection在每個(gè)Transformer Block中都是用,配合Convolutional Token Embedding操作,能夠給模型足夠的能力來建模局部空間關(guān)系,因此可以去掉Transformer中的Positional Embedding操作。從下表發(fā)現(xiàn),pe對(duì)模型性能影響不大。

image

與其他工作的對(duì)比:

  • 同期工作1:Tokens-to-Tokens ViT: 使用Progressive Tokenization整合臨近token,使用Transformer-based骨干網(wǎng)絡(luò)具有局部性的同時(shí),還能降低token序列長度。
  • 區(qū)別:CvT使用的是multi-stage的過程,token長度降低的同時(shí),其維度在增加,從而保證模型的容量。同時(shí)計(jì)算量相比T2T有所改善。
  • 同期工作2:Pyramid Vision Transformer(PVT): 引入了金字塔架構(gòu),使得PVT可以作為Backbone應(yīng)用于Dense prediction任務(wù)中。
  • 區(qū)別:CvT也使用了金字塔架構(gòu),區(qū)別在于CvT中提出使用stride卷積來實(shí)現(xiàn)空間降采樣,進(jìn)一步融合了局部信息。

最終模型架構(gòu)如下:

image

實(shí)驗(yàn)

image

左圖中令人感興趣的是BiT,這篇是谷歌的文章big transfer,探究CNN架構(gòu)在大規(guī)模數(shù)據(jù)與訓(xùn)練的效果,可以看出即便是純CNN架構(gòu)模型參數(shù)量也可以非常巨大,而Vision Transformer還有CvT等在同等精度下模型參數(shù)量遠(yuǎn)小于BiT,這一定程度上說明了Transformer結(jié)合CNN在數(shù)據(jù)量足夠的情況下性能可以非常可觀,要比單純CNN架構(gòu)的模型性能更優(yōu)。

右圖展示了CvT和幾種vision transformer架構(gòu)的性能比較,可見CvT在權(quán)衡方面做的非常不錯(cuò)。

與SOTA比較:

image

有趣的是CvT-13-NAS也采用了搜索的方法DA-NAS,主要搜索對(duì)象是key和value的stride,以及MLP的Expansion Ratio, 最終搜索的結(jié)果要比Baseline略好。

在無需JFT數(shù)據(jù)集的情況下,CvT最高調(diào)整可以達(dá)到87.7%的top1 準(zhǔn)確率。

其他數(shù)據(jù)集結(jié)果:

image

消融實(shí)驗(yàn)

image
image

代碼

Convolutional Token Embedding代碼實(shí)現(xiàn):可以看出,實(shí)際上就是大卷積核+大Stride的滑動(dòng)引入的局部性。

class ConvEmbed(nn.Module):
    """ Image to Conv Embedding
    """
    def __init__(self,
                 patch_size=7,
                 in_chans=3,
                 embed_dim=64,
                 stride=4,
                 padding=2,
                 norm_layer=None):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size

        self.proj = nn.Conv2d(
            in_chans, embed_dim,
            kernel_size=patch_size,
            stride=stride,
            padding=padding
        )
        self.norm = norm_layer(embed_dim) if norm_layer else None

    def forward(self, x):
        x = self.proj(x)

        B, C, H, W = x.shape
        x = rearrange(x, 'b c h w -> b (h w) c')
        if self.norm:
            x = self.norm(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)

        return x

Convolutional Projection代碼實(shí)現(xiàn),具體看_build_projection函數(shù):

class Attention(nn.Module):
    def __init__(self,
                 dim_in,
                 dim_out,
                 num_heads,
                 qkv_bias=False,
                 attn_drop=0.,
                 proj_drop=0.,
                 method='dw_bn',
                 kernel_size=3,
                 stride_kv=1,
                 stride_q=1,
                 padding_kv=1,
                 padding_q=1,
                 with_cls_token=True,
                 **kwargs
                 ):
        super().__init__()
        self.stride_kv = stride_kv
        self.stride_q = stride_q
        self.dim = dim_out
        self.num_heads = num_heads
        # head_dim = self.qkv_dim // num_heads
        self.scale = dim_out ** -0.5
        self.with_cls_token = with_cls_token

        self.conv_proj_q = self._build_projection(
            dim_in, dim_out, kernel_size, padding_q,
            stride_q, 'linear' if method == 'avg' else method
        )
        self.conv_proj_k = self._build_projection(
            dim_in, dim_out, kernel_size, padding_kv,
            stride_kv, method
        )
        self.conv_proj_v = self._build_projection(
            dim_in, dim_out, kernel_size, padding_kv,
            stride_kv, method
        )

        self.proj_q = nn.Linear(dim_in, dim_out, bias=qkv_bias)
        self.proj_k = nn.Linear(dim_in, dim_out, bias=qkv_bias)
        self.proj_v = nn.Linear(dim_in, dim_out, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim_out, dim_out)
        self.proj_drop = nn.Dropout(proj_drop)

    def _build_projection(self,
                          dim_in,
                          dim_out,
                          kernel_size,
                          padding,
                          stride,
                          method):
        if method == 'dw_bn':
            proj = nn.Sequential(OrderedDict([
                ('conv', nn.Conv2d(
                    dim_in,
                    dim_in,
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    bias=False,
                    groups=dim_in
                )),
                ('bn', nn.BatchNorm2d(dim_in)),
                ('rearrage', Rearrange('b c h w -> b (h w) c')),
            ]))
        elif method == 'avg':
            proj = nn.Sequential(OrderedDict([
                ('avg', nn.AvgPool2d(
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    ceil_mode=True
                )),
                ('rearrage', Rearrange('b c h w -> b (h w) c')),
            ]))
        elif method == 'linear':
            proj = None
        else:
            raise ValueError('Unknown method ({})'.format(method))

        return proj

    def forward_conv(self, x, h, w):
        if self.with_cls_token:
            cls_token, x = torch.split(x, [1, h*w], 1)

        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)

        if self.conv_proj_q is not None:
            q = self.conv_proj_q(x)
        else:
            q = rearrange(x, 'b c h w -> b (h w) c')

        if self.conv_proj_k is not None:
            k = self.conv_proj_k(x)
        else:
            k = rearrange(x, 'b c h w -> b (h w) c')

        if self.conv_proj_v is not None:
            v = self.conv_proj_v(x)
        else:
            v = rearrange(x, 'b c h w -> b (h w) c')

        if self.with_cls_token:
            q = torch.cat((cls_token, q), dim=1)
            k = torch.cat((cls_token, k), dim=1)
            v = torch.cat((cls_token, v), dim=1)

        return q, k, v

    def forward(self, x, h, w):
        if (
            self.conv_proj_q is not None
            or self.conv_proj_k is not None
            or self.conv_proj_v is not None
        ):
            q, k, v = self.forward_conv(x, h, w)

        q = rearrange(self.proj_q(q), 'b t (h d) -> b h t d', h=self.num_heads)
        k = rearrange(self.proj_k(k), 'b t (h d) -> b h t d', h=self.num_heads)
        v = rearrange(self.proj_v(v), 'b t (h d) -> b h t d', h=self.num_heads)

        attn_score = torch.einsum('bhlk,bhtk->bhlt', [q, k]) * self.scale
        attn = F.softmax(attn_score, dim=-1)
        attn = self.attn_drop(attn)

        x = torch.einsum('bhlt,bhtv->bhlv', [attn, v])
        x = rearrange(x, 'b h t d -> b t (h d)')

        x = self.proj(x)
        x = self.proj_drop(x)

        return x

參考

https://github.com/microsoft/CvT/blob/main/lib/models/cls_cvt.py

https://arxiv.org/pdf/2103.15808.pdf

https://zhuanlan.zhihu.com/p/142864566

筆者在cifar10數(shù)據(jù)集上修改了CvT中的Stride等參數(shù),在不用任何數(shù)據(jù)增強(qiáng)和Trick的情況下得到了下圖結(jié)果,Top1為84.74。雖然看上去性能比較差,但是這還沒有調(diào)參以及加上數(shù)據(jù)增強(qiáng)方法,只訓(xùn)練了200個(gè)epoch的結(jié)果。

python train.py --model 'cvt' --name "cvt" --sched 'cosine' --epochs 200 --lr 0.01

感興趣的可以點(diǎn)擊下面鏈接調(diào)參:

https://github.com/pprp/pytorch-cifar-model-zoo

image
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌,老刑警劉巖,帶你破解...
    沈念sama閱讀 229,117評(píng)論 6 537
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異,居然都是意外死亡,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 98,860評(píng)論 3 423
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人,你說我怎么就攤上這事。” “怎么了?”我有些...
    開封第一講書人閱讀 177,128評(píng)論 0 381
  • 文/不壞的土叔 我叫張陵,是天一觀的道長。 經(jīng)常有香客問我,道長,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 63,291評(píng)論 1 315
  • 正文 為了忘掉前任,我火速辦了婚禮,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘。我一直安慰自己,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 72,025評(píng)論 6 410
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著,像睡著了一般。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 55,421評(píng)論 1 324
  • 那天,我揣著相機(jī)與錄音,去河邊找鬼。 笑死,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播,決...
    沈念sama閱讀 43,477評(píng)論 3 444
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 42,642評(píng)論 0 289
  • 序言:老撾萬榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎,沒想到半個(gè)月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 49,177評(píng)論 1 335
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 40,970評(píng)論 3 356
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點(diǎn)故事閱讀 43,157評(píng)論 1 371
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情,我是刑警寧澤,帶...
    沈念sama閱讀 38,717評(píng)論 5 362
  • 正文 年R本政府宣布,位于F島的核電站,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 44,410評(píng)論 3 347
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧,春花似錦、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 34,821評(píng)論 0 28
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 36,053評(píng)論 1 289
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 51,896評(píng)論 3 395
  • 正文 我出身青樓,卻偏偏與公主長得像,于是被迫代替她去往敵國和親。 傳聞我的和親對(duì)象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 48,157評(píng)論 2 375

推薦閱讀更多精彩內(nèi)容