import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

def pair(t):
    '''Return (t, t) if t is not already a tuple.

    Used to normalize the image/patch size: if img_size is passed as (h, w) it is
    returned unchanged; otherwise the image is assumed to be square and converted.
    '''
    return t if isinstance(t, tuple) else (t, t)

class FeedForward(nn.Module):
    '''Feed-forward network (the MLP block of a Transformer layer).'''
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    '''Multi-head self-attention layer.'''
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
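
# A minimal shape sketch, not part of the original listing: it shows how tensors flow
# through Attention. The helper name _attention_shape_demo and the sizes (b=2, n=65,
# dim=128, heads=8, dim_head=64) are illustrative assumptions, not prescribed values.
def _attention_shape_demo():
    attn = Attention(dim = 128, heads = 8, dim_head = 64)
    x = torch.randn(2, 65, 128)        # (b, n, dim)
    # inside forward: to_qkv gives (b, n, inner_dim * 3) = (2, 65, 1536);
    # each of q, k, v is rearranged to (b, heads, n, dim_head) = (2, 8, 65, 64)
    out = attn(x)
    assert out.shape == (2, 65, 128)   # heads merged back and projected to dim
    return out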

class Transformer(nn.Module):
    '''Transformer encoder: a stack of depth layers, each a pre-norm Attention + FFN
    block with residual (skip) connections.
    '''
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)
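
# A minimal sketch, not part of the original listing: the encoder keeps the sequence
# shape (b, n, dim) unchanged, which is why depth blocks can be stacked freely.
# The helper name _transformer_shape_demo and the sizes below are assumptions.
def _transformer_shape_demo():
    encoder = Transformer(dim = 128, depth = 2, heads = 8, dim_head = 64, mlp_dim = 256)
    x = torch.randn(2, 65, 128)        # (b, n, dim)
    out = encoder(x)                   # residual Attention + FFN blocks, then a final LayerNorm
    assert out.shape == x.shape        # shape preserved layer to layer
    return out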

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, channels = 3, num_classes = 64, dim, depth, heads, mlp_dim, pool = 'cls', dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean', 'all'}, 'pool type must be cls (cls token), mean (mean pooling) or all (keep all)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        if self.pool == 'mean':
            x = x.mean(dim = 1)
            return self.mlp_head(x)
        if self.pool == 'cls':
            x = x[:, 0]
            return self.mlp_head(x)

        # pool == 'all': return every token together with the positional embedding
        return x, self.pos_embedding
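
# A minimal usage sketch, not part of the original listing; the hyperparameters below
# are illustrative assumptions, not values prescribed by this code.
if __name__ == '__main__':
    model = ViT(
        image_size = 224,
        patch_size = 16,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 16,
        mlp_dim = 2048,
        pool = 'cls',
        dropout = 0.1,
        emb_dropout = 0.1
    )
    img = torch.randn(1, 3, 224, 224)   # (b, c, h, w)
    logits = model(img)                  # (1, 1000) class logits from the cls token
    print(logits.shape)

    # With pool = 'all' the model instead returns every token plus the positional
    # embedding, e.g. for downstream use of the full sequence:
    model_all = ViT(image_size = 224, patch_size = 16, num_classes = 1000, dim = 1024,
                    depth = 6, heads = 16, mlp_dim = 2048, pool = 'all')
    tokens, pos_emb = model_all(img)     # tokens: (1, 197, 1024), pos_emb: (1, 197, 1024)
    print(tokens.shape, pos_emb.shape)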