class ViT(nn.Module):
def __init__(self, image_size=224, patch_size=16, num_classes=1000, dim=768, depth=12, heads=12, mlp_dim=3072):
super().__init__()
self.patch_size = patch_size
self.num_patches = (image_size // patch_size) ** 2
self.patch_to_embedding = nn.Linear(patch_size * patch_size * 3, dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, dim))
self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=mlp_dim), num_layers=depth)
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, x):
p = self.patch_size
x = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
x = self.patch_to_embedding(x)
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding
x = self.transformer(x)
x = self.mlp_head(x[:, 0])
return x
paper: vit
?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。
推薦閱讀更多精彩內容
- 目錄: ?ViT (原始ViT, 發表于2020年10月) ?ImageNet1K上更好的ViT基線 (發表于20...
- 時間過去,論文還相當于爛尾,還有很多篇想寫的,不是漂浮,要落地,要靜下來,要沉下心,要專注,要深入,不要任何的表面...
- @[toc] 參考文獻 An Image is Worth 16x16 Words: Transformers f...