
Commit 4ac2c5f

stephanpeitz authored and facebook-github-bot committed
Implementation of the WeCNLP abstract "Cross+Self-Attention for Transformer Models" (#1097)
Summary: This PR implements a new attention module that combines cross-attention (encoder-decoder attention) with decoder self-attention. This work was accepted as an abstract at WeCNLP 2019 (https://www.wecnlp.ai/wecnlp-2019). Cross+Self-Attention reduces the number of parameters and increases inference speed without any degradation in translation quality. More details can be found in the attached [abstract](https://github.com/pytorch/fairseq/files/3561282/paper.pdf).

Pull Request resolved: #1097

Differential Revision: D17653168

Pulled By: myleott

fbshipit-source-id: deb834c2c78a229d7418ffbfea20ba3ce252991c
1 parent ea1a410 commit 4ac2c5f
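
Before the diff, a minimal self-contained sketch of the idea may help (a toy single-head attention with illustrative names and shapes, not fairseq's MultiheadAttention API): instead of running decoder self-attention and encoder-decoder attention as two separate sub-layers, the decoder concatenates the encoder states with its own states and runs a single attention over the combined sequence.

```python
import torch
import torch.nn.functional as F


def cross_self_attention(decoder_x, encoder_out):
    """Toy single-head sketch of cross+self-attention.

    decoder_x:   (tgt_len, batch, dim)  decoder-side inputs (queries)
    encoder_out: (src_len, batch, dim)  encoder outputs
    Names and shapes are illustrative, not fairseq internals.
    """
    tgt_len, bsz, dim = decoder_x.size()
    src_len = encoder_out.size(0)

    # Keys/values come from [encoder states; decoder states]; queries from the decoder only.
    kv = torch.cat((encoder_out, decoder_x), dim=0)         # (src_len + tgt_len, batch, dim)
    q = decoder_x.transpose(0, 1)                           # (batch, tgt_len, dim)
    k = v = kv.transpose(0, 1)                              # (batch, src_len + tgt_len, dim)

    scores = torch.bmm(q, k.transpose(1, 2)) / dim ** 0.5   # (batch, tgt_len, src_len + tgt_len)

    # Decoder positions stay causal; every source position is always visible.
    causal = torch.triu(decoder_x.new_full((tgt_len, tgt_len), float('-inf')), diagonal=1)
    mask = torch.cat((scores.new_zeros(tgt_len, src_len), causal), dim=1)
    attn = F.softmax(scores + mask, dim=-1)

    return torch.bmm(attn, v).transpose(0, 1)               # (tgt_len, batch, dim)
```

With --no-cross-attention the separate encoder-attention sub-layer can then be dropped entirely, which is where the parameter and inference-speed savings described in the abstract come from.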

File tree

fairseq/models/transformer.py
fairseq/modules/multihead_attention.py
fairseq/modules/transformer_layer.py
tests/test_binaries.py

4 files changed: +120, -12 lines

fairseq/models/transformer.py

Lines changed: 62 additions & 4 deletions

@@ -122,6 +122,13 @@ def add_args(parser):
                                  'Must be used with adaptive_loss criterion'),
         parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
                             help='sets adaptive softmax dropout for the tail projections')
+        # args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019)
+        parser.add_argument('--no-cross-attention', default=False, action='store_true',
+                            help='do not perform cross-attention')
+        parser.add_argument('--cross-self-attention', default=False, action='store_true',
+                            help='perform cross+self-attention')
+        parser.add_argument('--layer-wise-attention', default=False, action='store_true',
+                            help='perform layer-wise attention (cross-attention or cross+self-attention)')
         # fmt: on
 
     @classmethod
@@ -180,7 +187,12 @@ def build_encoder(cls, args, src_dict, embed_tokens):
 
     @classmethod
     def build_decoder(cls, args, tgt_dict, embed_tokens):
-        return TransformerDecoder(args, tgt_dict, embed_tokens)
+        return TransformerDecoder(
+            args,
+            tgt_dict,
+            embed_tokens,
+            no_encoder_attn=getattr(args, 'no_cross_attention', False),
+        )
 
 
 class TransformerEncoder(FairseqEncoder):
@@ -211,6 +223,8 @@ def __init__(self, args, dictionary, embed_tokens):
             learned=args.encoder_learned_pos,
         ) if not args.no_token_positional_embeddings else None
 
+        self.layer_wise_attention = getattr(args, 'layer_wise_attention', False)
+
         self.layers = nn.ModuleList([])
         self.layers.extend([
             TransformerEncoderLayer(args)
@@ -230,21 +244,29 @@ def forward_embedding(self, src_tokens):
             x = F.dropout(x, p=self.dropout, training=self.training)
         return x, embed
 
-    def forward(self, src_tokens, src_lengths, cls_input=None):
+    def forward(self, src_tokens, src_lengths, cls_input=None, return_all_hiddens=False):
         """
         Args:
             src_tokens (LongTensor): tokens in the source language of shape
                 `(batch, src_len)`
             src_lengths (torch.LongTensor): lengths of each source sentence of
                 shape `(batch)`
+            return_all_hiddens (bool, optional): also return all of the
+                intermediate hidden states (default: False).
 
         Returns:
             dict:
                 - **encoder_out** (Tensor): the last encoder layer's output of
                   shape `(src_len, batch, embed_dim)`
                 - **encoder_padding_mask** (ByteTensor): the positions of
                   padding elements of shape `(batch, src_len)`
+                - **encoder_states** (List[Tensor]): all intermediate
+                  hidden states of shape `(src_len, batch, embed_dim)`.
+                  Only populated if *return_all_hiddens* is True.
         """
+        if self.layer_wise_attention:
+            return_all_hiddens = True
+
         x, encoder_embedding = self.forward_embedding(src_tokens)
 
         # B x T x C -> T x B x C
@@ -255,17 +277,24 @@ def forward(self, src_tokens, src_lengths, cls_input=None):
         if not encoder_padding_mask.any():
             encoder_padding_mask = None
 
+        encoder_states = [] if return_all_hiddens else None
+
         # encoder layers
         for layer in self.layers:
             x = layer(x, encoder_padding_mask)
+            if return_all_hiddens:
+                encoder_states.append(x)
 
         if self.layer_norm:
             x = self.layer_norm(x)
+            if return_all_hiddens:
+                encoder_states[-1] = x
 
         return {
             'encoder_out': x,  # T x B x C
             'encoder_padding_mask': encoder_padding_mask,  # B x T
             'encoder_embedding': encoder_embedding,  # B x T x C
+            'encoder_states': encoder_states,  # List[T x B x C]
         }
 
     def reorder_encoder_out(self, encoder_out, new_order):
@@ -285,6 +314,9 @@ def reorder_encoder_out(self, encoder_out, new_order):
         if encoder_out['encoder_padding_mask'] is not None:
             encoder_out['encoder_padding_mask'] = \
                 encoder_out['encoder_padding_mask'].index_select(0, new_order)
+        if encoder_out.get('encoder_states', None) is not None:
+            for idx, state in enumerate(encoder_out['encoder_states']):
+                encoder_out['encoder_states'][idx] = state.index_select(1, new_order)
         return encoder_out
 
     def max_positions(self):
@@ -293,6 +325,14 @@ def max_positions(self):
             return self.max_source_positions
         return min(self.max_source_positions, self.embed_positions.max_positions())
 
+    def buffered_future_mask(self, tensor):
+        dim = tensor.size(0)
+        if not hasattr(self, '_future_mask') or self._future_mask is None or self._future_mask.device != tensor.device:
+            self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
+        if self._future_mask.size(0) < dim:
+            self._future_mask = torch.triu(utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1)
+        return self._future_mask[:dim, :dim]
+
     def upgrade_state_dict_named(self, state_dict, name):
         """Upgrade a (possibly old) state dict for new versions of fairseq."""
         if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
@@ -350,6 +390,9 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
             learned=args.decoder_learned_pos,
         ) if not args.no_token_positional_embeddings else None
 
+        self.cross_self_attention = getattr(args, 'cross_self_attention', False)
+        self.layer_wise_attention = getattr(args, 'layer_wise_attention', False)
+
         self.layers = nn.ModuleList([])
         self.layers.extend([
             TransformerDecoderLayer(args, no_encoder_attn)
@@ -435,14 +478,26 @@ def extract_features(self, prev_output_tokens, encoder_out=None, incremental_sta
 
         inner_states = [x]
 
+        self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
+        if not self_attn_padding_mask.any() and not self.cross_self_attention:
+            self_attn_padding_mask = None
+
         # decoder layers
-        for layer in self.layers:
+        for idx, layer in enumerate(self.layers):
+            encoder_state = None
+            if encoder_out is not None:
+                if self.layer_wise_attention:
+                    encoder_state = encoder_out['encoder_states'][idx]
+                else:
+                    encoder_state = encoder_out['encoder_out']
+
             x, attn = layer(
                 x,
-                encoder_out['encoder_out'] if encoder_out is not None else None,
+                encoder_state,
                 encoder_out['encoder_padding_mask'] if encoder_out is not None else None,
                 incremental_state,
                 self_attn_mask=self.buffered_future_mask(x) if incremental_state is None else None,
+                self_attn_padding_mask=self_attn_padding_mask,
             )
             inner_states.append(x)
 
@@ -553,6 +608,9 @@ def base_architecture(args):
     args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
    args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False)
     args.adaptive_input = getattr(args, 'adaptive_input', False)
+    args.no_cross_attention = getattr(args, 'no_cross_attention', False)
+    args.cross_self_attention = getattr(args, 'cross_self_attention', False)
+    args.layer_wise_attention = getattr(args, 'layer_wise_attention', False)
 
     args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim)
     args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim)
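
A hedged note on usage (not part of this commit): the three options are ordinary command-line flags on the transformer model, so a training invocation could look roughly like the following. Only the three attention flags come from this diff; the data path, architecture, and optimizer settings below are placeholders.

```bash
# Hypothetical invocation; everything except the three attention flags is a placeholder.
# --cross-self-attention  folds the source states into the decoder self-attention,
# --no-cross-attention    drops the separate encoder-decoder attention sub-layer,
# --layer-wise-attention  lets decoder layer i attend to encoder layer i's hidden states.
python train.py data-bin/iwslt14.tokenized.de-en \
    --arch transformer_iwslt_de_en \
    --cross-self-attention --no-cross-attention --layer-wise-attention \
    --optimizer adam --lr 0.0005 --max-tokens 4000
```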

fairseq/modules/multihead_attention.py

Lines changed: 9 additions & 1 deletion

@@ -186,8 +186,15 @@ def forward(self, query, key, value, key_padding_mask=None, incremental_state=No
                     v = prev_value
                 else:
                     v = torch.cat((prev_value, v), dim=1)
+            if 'prev_key_padding_mask' in saved_state and saved_state['prev_key_padding_mask'] is not None:
+                prev_key_padding_mask = saved_state['prev_key_padding_mask']
+                if static_kv:
+                    key_padding_mask = prev_key_padding_mask
+                else:
+                    key_padding_mask = torch.cat((prev_key_padding_mask, key_padding_mask), dim=1)
             saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
             saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)
+            saved_state['prev_key_padding_mask'] = key_padding_mask
 
             self._set_input_buffer(incremental_state, saved_state)
 
@@ -311,7 +318,8 @@ def reorder_incremental_state(self, incremental_state, new_order):
         input_buffer = self._get_input_buffer(incremental_state)
         if input_buffer is not None:
             for k in input_buffer.keys():
-                input_buffer[k] = input_buffer[k].index_select(0, new_order)
+                if input_buffer[k] is not None:
+                    input_buffer[k] = input_buffer[k].index_select(0, new_order)
             self._set_input_buffer(incremental_state, input_buffer)
 
     def _get_input_buffer(self, incremental_state):
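
Why the extra bookkeeping: during incremental decoding the attention receives only the newest query each step, so once padded positions (e.g. source tokens under cross+self-attention) are part of the cached keys/values, the key padding mask must be cached and re-extended step by step alongside prev_key and prev_value. A standalone sketch of that rule, written as a free function for clarity (illustrative helper, not the MultiheadAttention internals):

```python
import torch


def extend_cached_key_padding_mask(saved_state, key_padding_mask, static_kv=False):
    """Extend a cached key padding mask across incremental decoding steps.

    saved_state: per-layer incremental-decoding cache (a plain dict here).
    key_padding_mask: (batch, new_key_len) mask for the keys added this step, or None.
    """
    prev = saved_state.get('prev_key_padding_mask')
    if prev is not None:
        if static_kv:
            # Keys did not grow (e.g. a fixed encoder side): reuse the stored mask.
            key_padding_mask = prev
        else:
            # Keys grew by one step: stored mask for old steps ++ mask for the new step.
            key_padding_mask = torch.cat((prev, key_padding_mask), dim=1)
    saved_state['prev_key_padding_mask'] = key_padding_mask
    return key_padding_mask
```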

fairseq/modules/transformer_layer.py

Lines changed: 28 additions & 6 deletions

@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from fairseq import utils
@@ -134,13 +135,14 @@ class TransformerDecoderLayer(nn.Module):
     def __init__(self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False):
         super().__init__()
         self.embed_dim = args.decoder_embed_dim
+        self.cross_self_attention = getattr(args, 'cross_self_attention', False)
         self.self_attn = MultiheadAttention(
             embed_dim=self.embed_dim,
             num_heads=args.decoder_attention_heads,
             dropout=args.attention_dropout,
             add_bias_kv=add_bias_kv,
             add_zero_attn=add_zero_attn,
-            self_attention=True
+            self_attention=not self.cross_self_attention,
         )
         self.dropout = args.dropout
         self.activation_fn = utils.get_activation_fn(
@@ -208,13 +210,27 @@ def forward(
         if prev_self_attn_state is not None:
             if incremental_state is None:
                 incremental_state = {}
-            prev_key, prev_value = prev_self_attn_state
+            prev_key, prev_value = prev_self_attn_state[:2]
             saved_state = {"prev_key": prev_key, "prev_value": prev_value}
+            if len(prev_self_attn_state) >= 3:
+                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
             self.self_attn._set_input_buffer(incremental_state, saved_state)
+
+        if self.cross_self_attention and not (incremental_state is not None and "prev_key" in self.self_attn._get_input_buffer(incremental_state)):
+            if self_attn_mask is not None:
+                self_attn_mask = torch.cat((x.new(x.size(0), encoder_out.size(0)).zero_(), self_attn_mask), dim=1)
+            if self_attn_padding_mask is not None:
+                if encoder_padding_mask is None:
+                    encoder_padding_mask = self_attn_padding_mask.new(encoder_out.size(1), encoder_out.size(0)).zero_()
+                self_attn_padding_mask = torch.cat((encoder_padding_mask, self_attn_padding_mask), dim=1)
+            y = torch.cat((encoder_out, x), dim=0)
+        else:
+            y = x
+
         x, attn = self.self_attn(
             query=x,
-            key=x,
-            value=x,
+            key=y,
+            value=y,
             key_padding_mask=self_attn_padding_mask,
             incremental_state=incremental_state,
             need_weights=False,
@@ -230,9 +246,12 @@ def forward(
         if prev_attn_state is not None:
             if incremental_state is None:
                 incremental_state = {}
-            prev_key, prev_value = prev_attn_state
+            prev_key, prev_value = prev_attn_state[:2]
             saved_state = {"prev_key": prev_key, "prev_value": prev_value}
+            if len(prev_attn_state) >= 3:
+                saved_state["prev_key_padding_mask"] = prev_attn_state[2]
             self.encoder_attn._set_input_buffer(incremental_state, saved_state)
+
         x, attn = self.encoder_attn(
             query=x,
             key=encoder_out,
@@ -256,7 +275,10 @@ def forward(
         x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
         if self.onnx_trace and incremental_state is not None:
             saved_state = self.self_attn._get_input_buffer(incremental_state)
-            self_attn_state = saved_state["prev_key"], saved_state["prev_value"]
+            if self_attn_padding_mask is not None:
+                self_attn_state = saved_state["prev_key"], saved_state["prev_value"], saved_state["prev_key_padding_mask"]
+            else:
+                self_attn_state = saved_state["prev_key"], saved_state["prev_value"]
             return x, attn, self_attn_state
         return x, attn
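
To make the mask handling above easier to follow, here is a standalone shape check that mirrors the three torch.cat calls with toy sizes (dummy tensors only, not the layer's actual code path): the keys/values become the concatenation [encoder_out; x], the causal mask gains an all-zero block over the source positions (so every decoder position can attend to every source token), and the key padding mask becomes the source padding mask followed by the target padding mask.

```python
import torch

src_len, tgt_len, bsz, dim = 5, 3, 2, 8
x = torch.zeros(tgt_len, bsz, dim)            # decoder states, T x B x C
encoder_out = torch.zeros(src_len, bsz, dim)  # encoder states, T x B x C

# Causal mask over decoder positions, prefixed by an all-zero block over source positions.
self_attn_mask = torch.triu(torch.full((tgt_len, tgt_len), float('-inf')), diagonal=1)
self_attn_mask = torch.cat((x.new_zeros(tgt_len, src_len), self_attn_mask), dim=1)

# Source padding mask (here all False) prepended to the target padding mask.
self_attn_padding_mask = torch.zeros(bsz, tgt_len, dtype=torch.bool)
encoder_padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)
self_attn_padding_mask = torch.cat((encoder_padding_mask, self_attn_padding_mask), dim=1)

# Keys/values for the single cross+self-attention call.
y = torch.cat((encoder_out, x), dim=0)

assert y.shape == (src_len + tgt_len, bsz, dim)
assert self_attn_mask.shape == (tgt_len, src_len + tgt_len)
assert self_attn_padding_mask.shape == (bsz, src_len + tgt_len)
```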

tests/test_binaries.py

Lines changed: 21 additions & 1 deletion

@@ -154,6 +154,23 @@ def test_transformer(self):
             ], run_validation=True)
             generate_main(data_dir)
 
+    def test_transformer_cross_self_attention(self):
+        with contextlib.redirect_stdout(StringIO()):
+            with tempfile.TemporaryDirectory('test_transformer_cross_self_attention') as data_dir:
+                create_dummy_data(data_dir)
+                preprocess_translation_data(data_dir)
+                train_translation_model(data_dir, 'transformer_iwslt_de_en', [
+                    '--encoder-layers', '2',
+                    '--decoder-layers', '2',
+                    '--encoder-embed-dim', '8',
+                    '--decoder-embed-dim', '8',
+                    '--decoder-embed-dim', '8',
+                    '--no-cross-attention',
+                    '--cross-self-attention',
+                    '--layer-wise-attention',
+                ], run_validation=True)
+                generate_main(data_dir, extra_flags=[])
+
     def test_lightconv(self):
         with contextlib.redirect_stdout(StringIO()):
             with tempfile.TemporaryDirectory('test_lightconv') as data_dir:
@@ -543,6 +560,10 @@ def train_translation_model(data_dir, arch, extra_flags=None, task='translation'
 
 
 def generate_main(data_dir, extra_flags=None):
+    if extra_flags is None:
+        extra_flags = [
+            '--print-alignment',
+        ]
     generate_parser = options.get_generation_parser()
     generate_args = options.parse_args_and_arch(
         generate_parser,
@@ -554,7 +575,6 @@ def generate_main(data_dir, extra_flags=None):
            '--max-len-b', '5',
            '--gen-subset', 'valid',
            '--no-progress-bar',
-           '--print-alignment',
        ] + (extra_flags or []),
    )
