@@ -122,6 +122,13 @@ def add_args(parser):
                                  'Must be used with adaptive_loss criterion'),
         parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
                             help='sets adaptive softmax dropout for the tail projections')
+        # args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019)
+        parser.add_argument('--no-cross-attention', default=False, action='store_true',
+                            help='do not perform cross-attention')
+        parser.add_argument('--cross-self-attention', default=False, action='store_true',
+                            help='perform cross+self-attention')
+        parser.add_argument('--layer-wise-attention', default=False, action='store_true',
+                            help='perform layer-wise attention (cross-attention or cross+self-attention)')
         # fmt: on

     @classmethod
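A quick illustration of how the three new `store_true` flags resolve to booleans. The bare `argparse` parser below is a stand-in, not fairseq's real parser; only the flag definitions are taken from the hunk above.

```python
# Stand-in sketch, not part of the patch: shows the flags' defaults and how
# dashes map to underscores on the resulting namespace.
import argparse

p = argparse.ArgumentParser()
p.add_argument('--no-cross-attention', default=False, action='store_true')
p.add_argument('--cross-self-attention', default=False, action='store_true')
p.add_argument('--layer-wise-attention', default=False, action='store_true')

args = p.parse_args(['--cross-self-attention', '--layer-wise-attention'])
print(args.no_cross_attention, args.cross_self_attention, args.layer_wise_attention)
# -> False True True
```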
@@ -180,7 +187,12 @@ def build_encoder(cls, args, src_dict, embed_tokens):

     @classmethod
     def build_decoder(cls, args, tgt_dict, embed_tokens):
-        return TransformerDecoder(args, tgt_dict, embed_tokens)
+        return TransformerDecoder(
+            args,
+            tgt_dict,
+            embed_tokens,
+            no_encoder_attn=getattr(args, 'no_cross_attention', False),
+        )


 class TransformerEncoder(FairseqEncoder):
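Note the `getattr(args, 'no_cross_attention', False)` access rather than `args.no_cross_attention`; presumably this keeps argument namespaces that predate the new flag working. A minimal sketch of that behavior, with `argparse.Namespace` standing in for fairseq's args object:

```python
# Minimal sketch (assumption: this is why getattr with a default is used here):
# namespaces created before the flag existed simply fall back to False.
from argparse import Namespace

old_args = Namespace()                         # no 'no_cross_attention' attribute at all
new_args = Namespace(no_cross_attention=True)

print(getattr(old_args, 'no_cross_attention', False))  # False
print(getattr(new_args, 'no_cross_attention', False))  # True
```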
@@ -211,6 +223,8 @@ def __init__(self, args, dictionary, embed_tokens):
             learned=args.encoder_learned_pos,
         ) if not args.no_token_positional_embeddings else None

+        self.layer_wise_attention = getattr(args, 'layer_wise_attention', False)
+
         self.layers = nn.ModuleList([])
         self.layers.extend([
             TransformerEncoderLayer(args)
@@ -230,21 +244,29 @@ def forward_embedding(self, src_tokens):
         x = F.dropout(x, p=self.dropout, training=self.training)
         return x, embed

-    def forward(self, src_tokens, src_lengths, cls_input=None):
+    def forward(self, src_tokens, src_lengths, cls_input=None, return_all_hiddens=False):
         """
         Args:
             src_tokens (LongTensor): tokens in the source language of shape
                 `(batch, src_len)`
             src_lengths (torch.LongTensor): lengths of each source sentence of
                 shape `(batch)`
+            return_all_hiddens (bool, optional): also return all of the
+                intermediate hidden states (default: False).

         Returns:
             dict:
                 - **encoder_out** (Tensor): the last encoder layer's output of
                   shape `(src_len, batch, embed_dim)`
                 - **encoder_padding_mask** (ByteTensor): the positions of
                   padding elements of shape `(batch, src_len)`
+                - **encoder_states** (List[Tensor]): all intermediate
+                  hidden states of shape `(src_len, batch, embed_dim)`.
+                  Only populated if *return_all_hiddens* is True.
         """
+        if self.layer_wise_attention:
+            return_all_hiddens = True
+
         x, encoder_embedding = self.forward_embedding(src_tokens)

         # B x T x C -> T x B x C
@@ -255,17 +277,24 @@ def forward(self, src_tokens, src_lengths, cls_input=None):
         if not encoder_padding_mask.any():
             encoder_padding_mask = None

+        encoder_states = [] if return_all_hiddens else None
+
         # encoder layers
         for layer in self.layers:
             x = layer(x, encoder_padding_mask)
+            if return_all_hiddens:
+                encoder_states.append(x)

         if self.layer_norm:
             x = self.layer_norm(x)
+            if return_all_hiddens:
+                encoder_states[-1] = x

         return {
             'encoder_out': x,  # T x B x C
             'encoder_padding_mask': encoder_padding_mask,  # B x T
             'encoder_embedding': encoder_embedding,  # B x T x C
+            'encoder_states': encoder_states,  # List[T x B x C]
         }

     def reorder_encoder_out(self, encoder_out, new_order):
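The bookkeeping above records one state per encoder layer, and the final `layer_norm` output overwrites the last entry so that `encoder_states[-1]` matches `encoder_out`. A toy sketch of the same pattern, with plain `nn.Linear` layers standing in for `TransformerEncoderLayer`:

```python
# Toy sketch of the encoder_states bookkeeping; the layers are stand-ins,
# only the collection logic mirrors the patch.
import torch
import torch.nn as nn

layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])
layer_norm = nn.LayerNorm(8)

x = torch.randn(5, 2, 8)  # T x B x C
encoder_states = []
for layer in layers:
    x = layer(x)
    encoder_states.append(x)
x = layer_norm(x)
encoder_states[-1] = x    # keep the normalized output as the last state

assert len(encoder_states) == len(layers)
assert torch.equal(encoder_states[-1], x)
```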
@@ -285,6 +314,9 @@ def reorder_encoder_out(self, encoder_out, new_order):
         if encoder_out['encoder_padding_mask'] is not None:
             encoder_out['encoder_padding_mask'] = \
                 encoder_out['encoder_padding_mask'].index_select(0, new_order)
+        if encoder_out.get('encoder_states', None) is not None:
+            for idx, state in enumerate(encoder_out['encoder_states']):
+                encoder_out['encoder_states'][idx] = state.index_select(1, new_order)
         return encoder_out

     def max_positions(self):
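The new reordering uses `index_select(1, ...)` because each entry of `encoder_states` is laid out `T x B x C` (batch on dim 1), unlike `encoder_padding_mask`, which is `B x T` (batch on dim 0, matching the existing line above it). A small shape check under those assumptions:

```python
# Shape illustration only: batch reordering of a T x B x C state happens on dim 1.
import torch

T, B, C = 7, 3, 4
state = torch.randn(T, B, C)           # one entry of encoder_states
new_order = torch.tensor([2, 0, 1])    # e.g. a beam-search reordering

reordered = state.index_select(1, new_order)
assert reordered.shape == (T, B, C)
assert torch.equal(reordered[:, 0], state[:, 2])
```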
@@ -293,6 +325,14 @@ def max_positions(self):
             return self.max_source_positions
         return min(self.max_source_positions, self.embed_positions.max_positions())

+    def buffered_future_mask(self, tensor):
+        dim = tensor.size(0)
+        if not hasattr(self, '_future_mask') or self._future_mask is None or self._future_mask.device != tensor.device:
+            self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
+        if self._future_mask.size(0) < dim:
+            self._future_mask = torch.triu(utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1)
+        return self._future_mask[:dim, :dim]
+
     def upgrade_state_dict_named(self, state_dict, name):
         """Upgrade a (possibly old) state dict for new versions of fairseq."""
         if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
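The cached mask is upper-triangular: `-inf` above the diagonal and zeros elsewhere, so when it is added to attention scores, position *i* cannot attend to any position *j > i*. An illustrative equivalent in plain torch (fairseq's `utils.fill_with_neg_inf`, as the name suggests, fills a tensor with `-inf`; the caching is omitted here):

```python
# Illustrative future mask for dim = 4 (plain torch, no caching).
import torch

dim = 4
mask = torch.triu(torch.full((dim, dim), float('-inf')), 1)
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])
```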
@@ -350,6 +390,9 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
             learned=args.decoder_learned_pos,
         ) if not args.no_token_positional_embeddings else None

+        self.cross_self_attention = getattr(args, 'cross_self_attention', False)
+        self.layer_wise_attention = getattr(args, 'layer_wise_attention', False)
+
         self.layers = nn.ModuleList([])
         self.layers.extend([
             TransformerDecoderLayer(args, no_encoder_attn)
@@ -435,14 +478,26 @@ def extract_features(self, prev_output_tokens, encoder_out=None, incremental_sta

         inner_states = [x]

+        self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
+        if not self_attn_padding_mask.any() and not self.cross_self_attention:
+            self_attn_padding_mask = None
+
         # decoder layers
-        for layer in self.layers:
+        for idx, layer in enumerate(self.layers):
+            encoder_state = None
+            if encoder_out is not None:
+                if self.layer_wise_attention:
+                    encoder_state = encoder_out['encoder_states'][idx]
+                else:
+                    encoder_state = encoder_out['encoder_out']
+
             x, attn = layer(
                 x,
-                encoder_out['encoder_out'] if encoder_out is not None else None,
+                encoder_state,
                 encoder_out['encoder_padding_mask'] if encoder_out is not None else None,
                 incremental_state,
                 self_attn_mask=self.buffered_future_mask(x) if incremental_state is None else None,
+                self_attn_padding_mask=self_attn_padding_mask,
             )
             inner_states.append(x)

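With `--layer-wise-attention`, decoder layer *idx* attends over the *idx*-th entry of `encoder_states` instead of the final encoder output, which implicitly assumes at least one recorded encoder state per decoder layer. A dummy-tensor sketch of just that selection logic:

```python
# Dummy-tensor sketch of the per-layer source selection; assumes one encoder
# state is available per decoder layer, as the patch does.
import torch

num_layers = 3
encoder_out = {
    'encoder_out': torch.randn(7, 2, 8),                                  # T x B x C
    'encoder_states': [torch.randn(7, 2, 8) for _ in range(num_layers)],  # one per encoder layer
}

layer_wise_attention = True
for idx in range(num_layers):
    if layer_wise_attention:
        encoder_state = encoder_out['encoder_states'][idx]
    else:
        encoder_state = encoder_out['encoder_out']
    assert encoder_state.shape == (7, 2, 8)
```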
@@ -553,6 +608,9 @@ def base_architecture(args):
     args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
     args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False)
     args.adaptive_input = getattr(args, 'adaptive_input', False)
+    args.no_cross_attention = getattr(args, 'no_cross_attention', False)
+    args.cross_self_attention = getattr(args, 'cross_self_attention', False)
+    args.layer_wise_attention = getattr(args, 'layer_wise_attention', False)

     args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim)
     args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim)