@@ -103,14 +103,20 @@ def forward(self, x, ds_y):
103
103
104
104
105
105
class BasicBlockERes2Net (nn .Module ):
106
- expansion = 2
107
106
108
- def __init__ (self , in_planes , planes , stride = 1 , baseWidth = 32 , scale = 2 ):
107
+ def __init__ (self ,
108
+ in_planes ,
109
+ planes ,
110
+ stride = 1 ,
111
+ baseWidth = 32 ,
112
+ scale = 2 ,
113
+ expansion = 2 ):
109
114
super (BasicBlockERes2Net , self ).__init__ ()
110
115
width = int (math .floor (planes * (baseWidth / 64.0 )))
111
116
self .conv1 = conv1x1 (in_planes , width * scale , stride )
112
117
self .bn1 = nn .BatchNorm2d (width * scale )
113
118
self .nums = scale
119
+ self .expansion = expansion
114
120
115
121
convs = []
116
122
bns = []
@@ -162,14 +168,20 @@ def forward(self, x):
162
168
163
169
164
170
class BasicBlockERes2Net_diff_AFF (nn .Module ):
165
- expansion = 2
166
171
167
- def __init__ (self , in_planes , planes , stride = 1 , baseWidth = 32 , scale = 2 ):
172
+ def __init__ (self ,
173
+ in_planes ,
174
+ planes ,
175
+ stride = 1 ,
176
+ baseWidth = 32 ,
177
+ scale = 2 ,
178
+ expansion = 2 ):
168
179
super (BasicBlockERes2Net_diff_AFF , self ).__init__ ()
169
180
width = int (math .floor (planes * (baseWidth / 64.0 )))
170
181
self .conv1 = conv1x1 (in_planes , width * scale , stride )
171
182
self .bn1 = nn .BatchNorm2d (width * scale )
172
183
self .nums = scale
184
+ self .expansion = expansion
173
185
174
186
# to meet the torch.jit.script export requirements
175
187
self .conv2_1 = conv3x3 (width , width )
@@ -232,6 +244,9 @@ class ERes2Net(nn.Module):
232
244
def __init__ (self ,
233
245
m_channels ,
234
246
num_blocks ,
247
+ baseWidth = 32 ,
248
+ scale = 2 ,
249
+ expansion = 2 ,
235
250
block = BasicBlockERes2Net ,
236
251
block_fuse = BasicBlockERes2Net_diff_AFF ,
237
252
feat_dim = 80 ,
@@ -244,6 +259,7 @@ def __init__(self,
244
259
self .embed_dim = embed_dim
245
260
self .stats_dim = int (feat_dim / 8 ) * m_channels * 8
246
261
self .two_emb_layer = two_emb_layer
262
+ self .expansion = expansion
247
263
248
264
self .conv1 = nn .Conv2d (1 ,
249
265
m_channels ,
@@ -255,48 +271,59 @@ def __init__(self,
255
271
self .layer1 = self ._make_layer (block ,
256
272
m_channels ,
257
273
num_blocks [0 ],
258
- stride = 1 )
274
+ stride = 1 ,
275
+ baseWidth = baseWidth ,
276
+ scale = scale ,
277
+ expansion = expansion )
259
278
self .layer2 = self ._make_layer (block ,
260
279
m_channels * 2 ,
261
280
num_blocks [1 ],
262
- stride = 2 )
281
+ stride = 2 ,
282
+ baseWidth = baseWidth ,
283
+ scale = scale ,
284
+ expansion = expansion )
263
285
self .layer3 = self ._make_layer (block_fuse ,
264
286
m_channels * 4 ,
265
287
num_blocks [2 ],
266
- stride = 2 )
288
+ stride = 2 ,
289
+ baseWidth = baseWidth ,
290
+ scale = scale ,
291
+ expansion = expansion )
267
292
self .layer4 = self ._make_layer (block_fuse ,
268
293
m_channels * 8 ,
269
294
num_blocks [3 ],
270
- stride = 2 )
295
+ stride = 2 ,
296
+ baseWidth = baseWidth ,
297
+ scale = scale ,
298
+ expansion = expansion )
271
299
272
300
# Downsampling module for each layer
273
- self .layer1_downsample = nn .Conv2d (m_channels * 2 ,
274
- m_channels * 4 ,
301
+ self .layer1_downsample = nn .Conv2d (m_channels * expansion ,
302
+ m_channels * expansion * 2 ,
275
303
kernel_size = 3 ,
276
304
stride = 2 ,
277
305
padding = 1 ,
278
306
bias = False )
279
- self .layer2_downsample = nn .Conv2d (m_channels * 4 ,
280
- m_channels * 8 ,
307
+ self .layer2_downsample = nn .Conv2d (m_channels * expansion * 2 ,
308
+ m_channels * expansion * 4 ,
281
309
kernel_size = 3 ,
282
310
padding = 1 ,
283
311
stride = 2 ,
284
312
bias = False )
285
- self .layer3_downsample = nn .Conv2d (m_channels * 8 ,
286
- m_channels * 16 ,
313
+ self .layer3_downsample = nn .Conv2d (m_channels * expansion * 4 ,
314
+ m_channels * expansion * 8 ,
287
315
kernel_size = 3 ,
288
316
padding = 1 ,
289
317
stride = 2 ,
290
318
bias = False )
291
319
292
320
# Bottom-up fusion module
293
- self .fuse_mode12 = AFF (channels = m_channels * 4 )
294
- self .fuse_mode123 = AFF (channels = m_channels * 8 )
295
- self .fuse_mode1234 = AFF (channels = m_channels * 16 )
321
+ self .fuse_mode12 = AFF (channels = m_channels * expansion * 2 )
322
+ self .fuse_mode123 = AFF (channels = m_channels * expansion * 4 )
323
+ self .fuse_mode1234 = AFF (channels = m_channels * expansion * 8 )
296
324
297
325
self .pool = getattr (pooling_layers ,
298
- pooling_func )(in_dim = self .stats_dim *
299
- block .expansion )
326
+ pooling_func )(in_dim = self .stats_dim * expansion )
300
327
self .pool_out_dim = self .pool .get_out_dim ()
301
328
self .seg_1 = nn .Linear (self .pool_out_dim , embed_dim )
302
329
if self .two_emb_layer :
@@ -306,12 +333,21 @@ def __init__(self,
306
333
self .seg_bn_1 = nn .Identity ()
307
334
self .seg_2 = nn .Identity ()
308
335
309
- def _make_layer (self , block , planes , num_blocks , stride ):
336
+ def _make_layer (self ,
337
+ block ,
338
+ planes ,
339
+ num_blocks ,
340
+ stride ,
341
+ baseWidth = 32 ,
342
+ scale = 2 ,
343
+ expansion = 2 ):
310
344
strides = [stride ] + [1 ] * (num_blocks - 1 )
311
345
layers = []
312
346
for stride in strides :
313
- layers .append (block (self .in_planes , planes , stride ))
314
- self .in_planes = planes * block .expansion
347
+ layers .append (
348
+ block (self .in_planes , planes , stride , baseWidth , scale ,
349
+ expansion ))
350
+ self .in_planes = planes * self .expansion
315
351
return nn .Sequential (* layers )
316
352
317
353
def forward (self , x ):
@@ -362,6 +398,23 @@ def ERes2Net34_Large(feat_dim,
362
398
two_emb_layer = two_emb_layer )
363
399
364
400
401
def ERes2Net34_aug(feat_dim,
                   embed_dim,
                   pooling_func='TSTP',
                   two_emb_layer=False,
                   expansion=4,
                   baseWidth=24,
                   scale=3):
    """Build an augmented-capacity ERes2Net with a 34-layer layout.

    Uses the ResNet34-style stage depths [3, 4, 6, 3] and 64 base channels,
    with wider defaults (expansion=4, baseWidth=24, scale=3) than the other
    factory functions in this file.

    Args:
        feat_dim: input feature dimension (e.g. number of mel bins).
        embed_dim: output speaker-embedding dimension.
        pooling_func: name of the pooling layer class to use.
        two_emb_layer: whether to add a second embedding layer.
        expansion: block channel expansion factor.
        baseWidth: Res2Net base width for each block.
        scale: Res2Net scale for each block.

    Returns:
        An ERes2Net model instance.
    """
    model_kwargs = dict(
        expansion=expansion,
        baseWidth=baseWidth,
        scale=scale,
        feat_dim=feat_dim,
        embed_dim=embed_dim,
        pooling_func=pooling_func,
        two_emb_layer=two_emb_layer,
    )
    return ERes2Net(64, [3, 4, 6, 3], **model_kwargs)
416
+
417
+
365
418
if __name__ == '__main__' :
366
419
x = torch .zeros (1 , 200 , 80 )
367
420
model = ERes2Net34_Base (feat_dim = 80 , embed_dim = 512 , two_emb_layer = False )
0 commit comments