
Commit e23e5ea

ngoyal2707 authored and facebook-github-bot committed
XLM-R code and model release (#900)
Summary:
TODO:
1) Need to update bibtex entry
2) Need to upload models, spm_vocab and dict.txt to public s3 location.

For Future:
1) I will probably add instructions to finetune on XNLI and NER, POS etc. but currently no timeline for that.

Pull Request resolved: fairinternal/fairseq-py#900
Reviewed By: myleott
Differential Revision: D18333076
Pulled By: myleott
fbshipit-source-id: 3f3d3716fcc41c78d2dd4525f60b519abbd0459c
1 parent 68dd3e1 commit e23e5ea

File tree: 4 files changed, +103 -0 lines changed

- README.md
- examples/roberta/README.md
- examples/xlmr/README.md
- fairseq/models/roberta/model.py

README.md

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@ modeling and other text generation tasks.
 ### What's New:
 
+- November 2019: [XLM-R models and code released](examples/xlmr/README.md)
 - September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md)
 - August 2019: [WMT'19 models released](examples/wmt19/README.md)
 - July 2019: fairseq relicensed under MIT license

examples/roberta/README.md

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ RoBERTa iterates on BERT's pretraining procedure, including training the model l
 ### What's New:
 
+- November 2019: Multilingual encoder (XLM-RoBERTa) is available: [XLM-R](https://github.com/pytorch/fairseq/examples/xlmr).
 - September 2019: TensorFlow and TPU support via the [transformers library](https://github.com/huggingface/transformers).
 - August 2019: RoBERTa is now supported in the [pytorch-transformers library](https://github.com/huggingface/pytorch-transformers).
 - August 2019: Added [tutorial for finetuning on WinoGrande](https://github.com/pytorch/fairseq/tree/master/examples/roberta/wsc#roberta-training-on-winogrande-dataset).

examples/xlmr/README.md

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
# Unsupervised Cross-lingual Representation Learning at Scale (XLM-RoBERTa)

## Introduction

XLM-R (XLM-RoBERTa) is a large-scale cross-lingual sentence encoder. It is trained on `2.5T` of data across `100` languages filtered from Common Crawl. XLM-R achieves state-of-the-art results on multiple cross-lingual benchmarks.

## Pre-trained models

Model | Description | # params | Download
---|---|---|---
`xlmr.base.v0` | XLM-R using the BERT-base architecture | 250M | [xlmr.base.v0.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.v0.tar.gz)
`xlmr.large.v0` | XLM-R using the BERT-large architecture | 560M | [xlmr.large.v0.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xlmr.large.v0.tar.gz)

(Note: the above models are still under training; we will update the weights once they are fully trained. The results below are based on the above checkpoints.)

## Results

**[XNLI (Conneau et al., 2018)](https://arxiv.org/abs/1809.05053)**

Model | en | fr | es | de | el | bg | ru | tr | ar | vi | th | zh | hi | sw | ur
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
`roberta.large.mnli` _(TRANSLATE-TEST)_ | 91.3 | 82.9 | 84.3 | 81.24 | 81.74 | 83.13 | 78.28 | 76.79 | 76.64 | 74.17 | 74.05 | 77.5 | 70.9 | 66.65 | 66.81
`xlmr.large.v0` _(TRANSLATE-TRAIN-ALL)_ | 88.7 | 85.2 | 85.6 | 84.6 | 83.6 | 85.5 | 82.4 | 81.6 | 80.9 | 83.4 | 80.9 | 83.3 | 79.8 | 75.9 | 74.3

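The _(TRANSLATE-TRAIN-ALL)_ row reflects fine-tuning XLM-R on translated NLI data; this release does not yet ship an XNLI fine-tuning recipe (see the commit summary). As a rough, hypothetical sketch only, the custom classification-head pattern documented for RoBERTa is assumed to carry over to a loaded XLM-R hub interface (see the Example usage section below for loading); the head name `'xnli'` and the example sentences are illustrative:

```python
# Hypothetical sketch (not part of this release): attach a 3-way NLI head to a
# loaded XLM-R model, mirroring RoBERTa's custom classification-head API.
xlmr.register_classification_head('xnli', num_classes=3)

# Encode a premise/hypothesis pair (any covered language); the head only
# produces meaningful scores after fine-tuning on NLI data.
tokens = xlmr.encode('La tour Eiffel se trouve à Paris.', 'La tour Eiffel est en France.')
logprobs = xlmr.predict('xnli', tokens)  # log-probabilities, shape [1, 3]
```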
## Example usage

##### Load XLM-R from torch.hub (PyTorch >= 1.1):
```python
import torch
xlmr = torch.hub.load('pytorch/fairseq', 'xlmr.large.v0')
xlmr.eval()  # disable dropout (or leave in train mode to finetune)
```

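The smaller checkpoint from the table above loads through the same hub entry point, assuming it is exposed under its listed name; a brief sketch with only the hub name changed:

```python
# Load the 250M-parameter base model instead (name taken from the pre-trained models table).
xlmr_base = torch.hub.load('pytorch/fairseq', 'xlmr.base.v0')
xlmr_base.eval()  # disable dropout here as well
```
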
##### Load XLM-R (for PyTorch 1.0 or custom models):
```bash
# Download the xlmr.large model
wget https://dl.fbaipublicfiles.com/fairseq/models/xlmr.large.v0.tar.gz
tar -xzvf xlmr.large.v0.tar.gz
```

```python
# Load the model in fairseq
from fairseq.models.roberta import XLMRModel
xlmr = XLMRModel.from_pretrained('/path/to/xlmr.large.v0', checkpoint_file='model.pt')
xlmr.eval()  # disable dropout (or leave in train mode to finetune)
```

##### Apply Byte-Pair Encoding (BPE) to input text:
```python
tokens = xlmr.encode('Hello world!')
assert tokens.tolist() == [0, 35378, 8999, 38, 2]
xlmr.decode(tokens)  # 'Hello world!'
```

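Because XLM-R uses a single sentencepiece vocabulary shared across all 100 languages, non-English text goes through the same `encode`/`decode` calls. A small illustrative sketch (the sentence is arbitrary and no specific token ids are asserted):

```python
# Encode a non-English sentence with the same shared vocabulary.
tokens_zh = xlmr.encode('你好，世界！')
xlmr.decode(tokens_zh)  # expected to round-trip to the original string
```
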
##### Extract features from XLM-R:
```python
# Extract the last layer's features
last_layer_features = xlmr.extract_features(tokens)
assert last_layer_features.size() == torch.Size([1, 5, 1024])

# Extract all layers' features (layer 0 is the embedding layer)
all_layers = xlmr.extract_features(tokens, return_all_hiddens=True)
assert len(all_layers) == 25
assert torch.all(all_layers[-1] == last_layer_features)
```

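`extract_features` returns per-token vectors; a common next step is to pool them into a single sentence representation. A minimal sketch, not prescribed by this release (first-token vs. mean pooling is the user's choice):

```python
# Two simple poolings over the [1, 5, 1024] features extracted above.
cls_embedding = last_layer_features[:, 0, :]      # vector for the leading <s> token, shape [1, 1024]
mean_embedding = last_layer_features.mean(dim=1)  # average over the 5 token positions, shape [1, 1024]
```
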
## Citation

```bibtex
@article{,
    title = {Unsupervised Cross-lingual Representation Learning at Scale},
    author = {Alexis Conneau and Kartikay Khandelwal and Naman Goyal
        and Vishrav Chaudhary and Guillaume Wenzek and Francisco Guzm\'an
        and Edouard Grave and Myle Ott and Luke Zettlemoyer and Veselin Stoyanov},
    journal = {},
    year = {2019},
}
```

fairseq/models/roberta/model.py

Lines changed: 24 additions & 0 deletions
@@ -194,6 +194,30 @@ def upgrade_state_dict_named(self, state_dict, name):
             state_dict[prefix + 'classification_heads.' + k] = v
 
 
+@register_model('xlmr')
+class XLMRModel(RobertaModel):
+    @classmethod
+    def hub_models(cls):
+        return {
+            'xlmr.base.v0': 'http://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.v0.tar.gz',
+            'xlmr.large.v0': 'http://dl.fbaipublicfiles.com/fairseq/models/xlmr.large.v0.tar.gz',
+        }
+
+    @classmethod
+    def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_name_or_path='.', bpe='sentencepiece', **kwargs):
+        from fairseq import hub_utils
+        x = hub_utils.from_pretrained(
+            model_name_or_path,
+            checkpoint_file,
+            data_name_or_path,
+            archive_map=cls.hub_models(),
+            bpe=bpe,
+            load_checkpoint_heads=True,
+            **kwargs,
+        )
+        return RobertaHubInterface(x['args'], x['task'], x['models'][0])
+
+
 class RobertaLMHead(nn.Module):
     """Head for masked language modeling."""
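
For context on the diff above: `hub_models()` is passed to `hub_utils.from_pretrained` as `archive_map`, so the short model names should also resolve without a manual download, following how other fairseq hub models behave. A sketch under that assumption (the local-path workflow in the README remains the documented route):

```python
# Sketch: let fairseq resolve the short name via archive_map, downloading and
# caching the tarball behind 'xlmr.base.v0' (assumes the public URL is live).
from fairseq.models.roberta import XLMRModel

xlmr = XLMRModel.from_pretrained('xlmr.base.v0')  # defaults: checkpoint_file='model.pt', bpe='sentencepiece'
xlmr.eval()
```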
