import torch
from torch import nn

-from einops import rearrange, repeat
+from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange

# helpers

+def exists(val):
+    return val is not None
+
def pair(t):
    return t if isinstance(t, tuple) else (t, t)
@@ -106,20 +109,25 @@ def __init__(
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'

-        num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size)
+        num_image_patches = (image_height // patch_height) * (image_width // patch_width)
+        num_frame_patches = (frames // frame_patch_size)
+
        patch_dim = channels * patch_height * patch_width * frame_patch_size

        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

+        self.global_average_pool = pool == 'mean'
+
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
            nn.Linear(patch_dim, dim),
        )

-        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
+        self.pos_embedding = nn.Parameter(torch.randn(1, num_frame_patches, num_image_patches, dim))
        self.dropout = nn.Dropout(emb_dropout)
-        self.spatial_cls_token = nn.Parameter(torch.randn(1, 1, dim))
-        self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim))
+
+        self.spatial_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
+        self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None

        self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
        self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout)
@@ -132,13 +140,16 @@ def __init__(
            nn.Linear(dim, num_classes)
        )

-    def forward(self, img):
-        x = self.to_patch_embedding(img)
+    def forward(self, video):
+        x = self.to_patch_embedding(video)
        b, f, n, _ = x.shape

-        spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = b, f = f)
-        x = torch.cat((spatial_cls_tokens, x), dim = 2)
-        x += self.pos_embedding[:, :(n + 1)]
+        x = x + self.pos_embedding
+
+        if exists(self.spatial_cls_token):
+            spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = b, f = f)
+            x = torch.cat((spatial_cls_tokens, x), dim = 2)
+
        x = self.dropout(x)

        x = rearrange(x, 'b f n d -> (b f) n d')
@@ -149,21 +160,24 @@ def forward(self, img):

        x = rearrange(x, '(b f) n d -> b f n d', b = b)

-        # excise out the spatial cls tokens for temporal attention
+        # excise out the spatial cls tokens or average pool for temporal attention

-        x = x[:, :, 0]
+        x = x[:, :, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b f d', 'mean')

        # append temporal CLS tokens

-        temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)
+        if exists(self.temporal_cls_token):
+            temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)

-        x = torch.cat((temporal_cls_tokens, x), dim = 1)
+            x = torch.cat((temporal_cls_tokens, x), dim = 1)

        # attend across time

        x = self.temporal_transformer(x)

-        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
+        # excise out temporal cls token or average pool
+
+        x = x[:, 0] if not self.global_average_pool else reduce(x, 'b f d -> b d', 'mean')

        x = self.to_latent(x)
        return self.mlp_head(x)
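
For reference, a minimal usage sketch of the model after this change, exercising the new pool = 'mean' path (no cls tokens are created and both the spatial and temporal stages fall back to einops reduce averaging). The class name, module path, and constructor keyword names below are assumptions inferred from the variables referenced in the hunks; the full __init__ signature is not part of this diff.

# Hypothetical smoke test; import path and keyword names are assumed, not confirmed by this diff.
import torch
from vit_pytorch.vivit import ViT  # assumed location of the class edited above

model = ViT(
    image_size = 128,          # pair() -> image_height = image_width = 128
    image_patch_size = 16,     # patch_height = patch_width = 16
    frames = 16,
    frame_patch_size = 2,      # num_frame_patches = 16 // 2 = 8
    num_classes = 1000,
    dim = 512,
    spatial_depth = 6,
    temporal_depth = 6,
    heads = 8,
    mlp_dim = 1024,
    pool = 'mean'              # new path: cls tokens are skipped, global average pooling is used
)

video = torch.randn(2, 3, 16, 128, 128)   # (batch, channels, frames, height, width)
logits = model(video)                      # -> (2, 1000)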