Commit 6ec8fda

make sure global average pool can be used for vivit in place of cls token
1 parent 13fabf9

File tree

setup.py
vit_pytorch/simple_vit_3d.py
vit_pytorch/vit_3d.py
vit_pytorch/vivit.py

4 files changed: +35 −21 lines changed

setup.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.37.0',
+  version = '0.37.1',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description_content_type = 'text/markdown',

vit_pytorch/simple_vit_3d.py

Lines changed: 3 additions & 3 deletions

@@ -114,10 +114,10 @@ def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, nu
             nn.Linear(dim, num_classes)
         )

-    def forward(self, img):
-        *_, h, w, dtype = *img.shape, img.dtype
+    def forward(self, video):
+        *_, h, w, dtype = *video.shape, video.dtype

-        x = self.to_patch_embedding(img)
+        x = self.to_patch_embedding(video)
         pe = posemb_sincos_3d(x)
         x = rearrange(x, 'b ... d -> b (...) d') + pe
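The two trailing context lines show where the sincos positional embedding is applied: the patch embedding keeps the frame and spatial patch axes separate, and the rearrange collapses them into a single token axis before the addition. A tiny sketch of that flattening; the tensor sizes below are illustrative, only the einops pattern comes from the diff:

import torch
from einops import rearrange

# e.g. (batch, frame patches, height patches, width patches, dim)
x = torch.randn(2, 8, 8, 8, 512)

# 'b ... d -> b (...) d' folds every axis between batch and dim into one token axis
x = rearrange(x, 'b ... d -> b (...) d')
print(x.shape)  # torch.Size([2, 512, 512])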

vit_pytorch/vit_3d.py

Lines changed: 2 additions & 2 deletions

@@ -112,8 +112,8 @@ def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, nu
             nn.Linear(dim, num_classes)
         )

-    def forward(self, img):
-        x = self.to_patch_embedding(img)
+    def forward(self, video):
+        x = self.to_patch_embedding(video)
         b, n, _ = x.shape

         cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
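Both 3-D variants rename the forward argument from img to video, making the input contract explicit: a 5-D tensor of shape (batch, channels, frames, height, width). A minimal usage sketch; the ViT class name and the constructor keywords past the truncated hunk header (num_classes, dim, depth, heads, mlp_dim) are assumptions, and all values are illustrative:

import torch
from vit_pytorch.vit_3d import ViT  # class name assumed

model = ViT(
    image_size = 128,        # spatial size of each frame
    image_patch_size = 16,   # spatial patch size
    frames = 16,             # frames per clip
    frame_patch_size = 2,    # temporal patch size
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

video = torch.randn(4, 3, 16, 128, 128)  # (batch, channels, frames, height, width)
preds = model(video)                      # (4, 1000)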

vit_pytorch/vivit.py

Lines changed: 29 additions & 15 deletions

@@ -1,11 +1,14 @@
 import torch
 from torch import nn

-from einops import rearrange, repeat
+from einops import rearrange, repeat, reduce
 from einops.layers.torch import Rearrange

 # helpers

+def exists(val):
+    return val is not None
+
 def pair(t):
     return t if isinstance(t, tuple) else (t, t)

@@ -106,20 +109,25 @@ def __init__(
         assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
         assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'

-        num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size)
+        num_image_patches = (image_height // patch_height) * (image_width // patch_width)
+        num_frame_patches = (frames // frame_patch_size)
+
         patch_dim = channels * patch_height * patch_width * frame_patch_size

         assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

+        self.global_average_pool = pool == 'mean'
+
         self.to_patch_embedding = nn.Sequential(
             Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
             nn.Linear(patch_dim, dim),
         )

-        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
+        self.pos_embedding = nn.Parameter(torch.randn(1, num_frame_patches, num_image_patches, dim))
         self.dropout = nn.Dropout(emb_dropout)
-        self.spatial_cls_token = nn.Parameter(torch.randn(1, 1, dim))
-        self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim))
+
+        self.spatial_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
+        self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None

         self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
         self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout)

@@ -132,13 +140,16 @@ def __init__(
             nn.Linear(dim, num_classes)
         )

-    def forward(self, img):
-        x = self.to_patch_embedding(img)
+    def forward(self, video):
+        x = self.to_patch_embedding(video)
         b, f, n, _ = x.shape

-        spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = b, f = f)
-        x = torch.cat((spatial_cls_tokens, x), dim = 2)
-        x += self.pos_embedding[:, :(n + 1)]
+        x = x + self.pos_embedding
+
+        if exists(self.spatial_cls_token):
+            spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = b, f = f)
+            x = torch.cat((spatial_cls_tokens, x), dim = 2)
+
         x = self.dropout(x)

         x = rearrange(x, 'b f n d -> (b f) n d')

@@ -149,21 +160,24 @@ def forward(self, img):

         x = rearrange(x, '(b f) n d -> b f n d', b = b)

-        # excise out the spatial cls tokens for temporal attention
+        # excise out the spatial cls tokens or average pool for temporal attention

-        x = x[:, :, 0]
+        x = x[:, :, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b f d', 'mean')

         # append temporal CLS tokens

-        temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)
+        if exists(self.temporal_cls_token):
+            temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)

-        x = torch.cat((temporal_cls_tokens, x), dim = 1)
+            x = torch.cat((temporal_cls_tokens, x), dim = 1)

         # attend across time

         x = self.temporal_transformer(x)

-        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
+        # excise out temporal cls token or average pool
+
+        x = x[:, 0] if not self.global_average_pool else reduce(x, 'b f d -> b d', 'mean')

         x = self.to_latent(x)
         return self.mlp_head(x)
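Taken together, pool = 'mean' now works end to end for ViViT: both cls tokens are set to None, spatial tokens are mean-pooled per frame (via einops reduce) before temporal attention, and the frame tokens are mean-pooled before the classification head. The positional embedding is also factorized to shape (1, num_frame_patches, num_image_patches, dim), so it broadcasts over the batch and no longer needs the extra cls slot or the [:, :(n + 1)] slicing. A minimal usage sketch; the ViT class name and the constructor keywords not visible in the hunk headers (num_classes, dim, spatial_depth, temporal_depth, heads, mlp_dim) are assumptions, with illustrative values:

import torch
from vit_pytorch.vivit import ViT  # class name assumed

model = ViT(
    image_size = 128,
    image_patch_size = 16,
    frames = 16,
    frame_patch_size = 2,
    num_classes = 1000,
    dim = 512,
    spatial_depth = 6,     # depth of the spatial transformer
    temporal_depth = 6,    # depth of the temporal transformer
    heads = 8,
    mlp_dim = 2048,
    pool = 'mean'          # global average pooling; 'cls' restores both cls tokens
)

video = torch.randn(2, 3, 16, 128, 128)  # (batch, channels, frames, height, width)
preds = model(video)                      # (2, 1000)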
