---
language:
- multilingual
- en
- fr
- es
- de
- el
- bg
- ru
- tr
- ar
- vi
- th
- zh
- hi
- sw
- ur
tags:
- pytorch
license: apache-2.0
datasets:
- multi_nli
- xnli
metrics:
- xnli
---
# mt5-large-finetuned-mnli-xtreme-xnli
## Model Description
This model takes the pretrained large [multilingual-t5](https://github.com/google-research/multilingual-t5) model (also available on the [Hugging Face Hub](https://huggingface.co/google/mt5-large)) and fine-tunes it on English MNLI and the [xtreme_xnli](https://www.tensorflow.org/datasets/catalog/xtreme_xnli) training set. It is intended for zero-shot text classification and was inspired by [xlm-roberta-large-xnli](https://huggingface.co/joeddav/xlm-roberta-large-xnli).
## Intended Use
This model is intended to be used for zero-shot text classification, especially in languages other than English. It is fine-tuned on English MNLI and the [xtreme_xnli](https://www.tensorflow.org/datasets/catalog/xtreme_xnli) training set, a multilingual NLI dataset. The model can therefore be used with any of the languages in the XNLI corpus:
- Arabic
- Bulgarian
- Chinese
- English
- French
- German
- Greek
- Hindi
- Russian
- Spanish
- Swahili
- Thai
- Turkish
- Urdu
- Vietnamese

Following the recommendations in [xlm-roberta-large-xnli](https://huggingface.co/joeddav/xlm-roberta-large-xnli), for English-only classification you might want to check out:
- [bart-large-mnli](https://huggingface.co/facebook/bart-large-mnli)
- [a distilled bart MNLI model](https://huggingface.co/models?filter=pipeline_tag%3Azero-shot-classification&search=valhalla)
### Zero-shot example:
The model retains its text-to-text nature after fine-tuning, so its outputs are text. During fine-tuning, the model learns to answer the NLI task with a single token, `▁0`, `▁1`, or `▁2`, which map to entailment, neutral, and contradiction respectively. The NLI task is indicated with a fixed prefix, `xnli:`.
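As a minimal, dependency-free sketch of that input format and label mapping, the helpers below build a prefixed input string and map an answer token back to its class name. The names `format_xnli_input`, `token_to_label`, `TASK_PREFIX`, and `LABEL_TOKENS` are hypothetical; the prefix layout and the `▁0`/`▁1`/`▁2` tokens are taken from the PyTorch example in this section.

```python
# Hypothetical helpers mirroring the input format used in the PyTorch example.
TASK_PREFIX = "xnli:"

# SentencePiece tokens the fine-tuned model emits as its single answer token.
# Note: "▁" is U+2581, SentencePiece's word-boundary marker, not an underscore.
LABEL_TOKENS = {"▁0": "entailment", "▁1": "neutral", "▁2": "contradiction"}


def format_xnli_input(premise: str, hypothesis: str) -> str:
    """Build the prefixed text-to-text input for the XNLI task."""
    return f"{TASK_PREFIX} premise: {premise} hypothesis: {hypothesis}"


def token_to_label(token: str) -> str:
    """Map a generated answer token back to its NLI class name."""
    return LABEL_TOKENS[token]


print(format_xnli_input("¿A quién vas a votar en 2020?", "Este ejemplo es política."))
# xnli: premise: ¿A quién vas a votar en 2020? hypothesis: Este ejemplo es política.
print(token_to_label("▁0"))
# entailment
```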
Below is an example, using PyTorch, of how to use the model in a similar fashion to the `zero-shot-classification` pipeline. We use the logits for the first generated token to represent confidence.
```python
from torch.nn.functional import softmax
from transformers import MT5ForConditionalGeneration, MT5Tokenizer
model_name = "alan-turing-institute/mt5-large-finetuned-mnli-xtreme-xnli"
tokenizer = MT5Tokenizer.from_pretrained(model_name)
model = MT5ForConditionalGeneration.from_pretrained(model_name)
model.eval()
sequence_to_classify = "¿A quién vas a votar en 2020?"
candidate_labels = ["Europa", "salud pública", "política"]
hypothesis_template = "Este ejemplo es {}."
# single-token answers (SentencePiece pieces) the model emits for each NLI class
ENTAILS_LABEL = "▁0"
NEUTRAL_LABEL = "▁1"
CONTRADICTS_LABEL = "▁2"
label_inds = tokenizer.convert_tokens_to_ids(
[ENTAILS_LABEL, NEUTRAL_LABEL, CONTRADICTS_LABEL])
def process_nli(premise: str, hypothesis: str) -> str:
    """Format a premise/hypothesis pair as xnli input with the task prefix."""
    return "".join(["xnli: premise: ", premise, " hypothesis: ", hypothesis])
# construct sequence of premise, hypothesis pairs
pairs = [(sequence_to_classify, hypothesis_template.format(label)) for label in
candidate_labels]
# format for mt5 xnli task
seqs = [process_nli(premise=premise, hypothesis=hypothesis) for
premise, hypothesis in pairs]
print(seqs)
# ['xnli: premise: ¿A quién vas a votar en 2020? hypothesis: Este ejemplo es Europa.',
# 'xnli: premise: ¿A quién vas a votar en 2020? hypothesis: Este ejemplo es salud pública.',
# 'xnli: premise: ¿A quién vas a votar en 2020? hypothesis: Este ejemplo es política.']
inputs = tokenizer(seqs, return_tensors="pt", padding=True)
out = model.generate(**inputs, output_scores=True, return_dict_in_generate=True,
num_beams=1)
# sanity check that each generated sequence has the expected length:
# decoder start token + 1 label token + end-of-sequence token = 3
for i, seq in enumerate(out.sequences):
    assert len(seq) == 3, (
        f"generated sequence {i} not of expected length, 3."
        f" Actual length: {len(seq)}"
    )
# get the scores for our only token of interest
# we'll now treat these like the output logits of a `*ForSequenceClassification` model
scores = out.scores[0]
# each row of scores spans the model's full vocabulary.
# However, for this task we have a fixed set of labels,
# so sanity check that these labels are always the top 3 scoring tokens
for i, sequence_scores in enumerate(scores):
    top_scores = sequence_scores.argsort()[-3:]
    assert set(top_scores.tolist()) == set(label_inds), (
        "top scoring tokens are not expected for this task."
        f" Expected: {label_inds}. Got: {top_scores.tolist()}."
    )
# cut down scores to our task labels
scores = scores[:, label_inds]
print(scores)
# tensor([[-2.5697, 1.0618, 0.2088],
# [-5.4492, -2.1805, -0.1473],
# [ 2.2973, 3.7595, -0.1769]])
# new indices of entailment and contradiction in scores
entailment_ind = 0
contradiction_ind = 2
# we can show, per item, the entailment vs contradiction probas
entail_vs_contra_scores = scores[:, [entailment_ind, contradiction_ind]]
entail_vs_contra_probas = softmax(entail_vs_contra_scores, dim=1)
print(entail_vs_contra_probas)
# tensor([[0.0585, 0.9415],
# [0.0050, 0.9950],
# [0.9223, 0.0777]])
# or we can show probas similar to `ZeroShotClassificationPipeline`
# this gives a zero-shot classification style output across labels
entail_scores = scores[:, entailment_ind]
entail_probas = softmax(entail_scores, dim=0)
print(entail_probas)
# tensor([7.6341e-03, 4.2873e-04, 9.9194e-01])
print(dict(zip(candidate_labels, entail_probas.tolist())))
# {'Europa': 0.007634134963154793,
# 'salud pública': 0.0004287279152777046,
# 'política': 0.9919371604919434}
```
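To make the two normalizations at the end of the example explicit, here is a dependency-free sketch that applies them to the `scores` values printed above (copied by hand, so rounded to four decimals). It swaps in a plain-Python `softmax` helper for `torch.nn.functional.softmax`; all names here are illustrative.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


# Three-way scores copied from the printed `scores` tensor in the example:
# one row per candidate label, columns = (entailment, neutral, contradiction).
scores = [
    [-2.5697, 1.0618, 0.2088],
    [-5.4492, -2.1805, -0.1473],
    [2.2973, 3.7595, -0.1769],
]
candidate_labels = ["Europa", "salud pública", "política"]
ENTAIL, CONTRA = 0, 2

# Option 1: per row, entailment vs contradiction (the neutral column is dropped).
entail_vs_contra = [softmax([row[ENTAIL], row[CONTRA]]) for row in scores]

# Option 2: across rows, softmax over the entailment column only,
# so the candidate labels compete with each other.
entail_probas = softmax([row[ENTAIL] for row in scores])

print([round(p[0], 4) for p in entail_vs_contra])
print({label: round(p, 4) for label, p in zip(candidate_labels, entail_probas)})
```

Option 1 scores each label independently, roughly like the pipeline's `multi_label=True` mode, while Option 2 forces the labels to share probability mass, like its default single-label mode. Up to rounding, the results agree with the `entail_vs_contra_probas` and `entail_probas` values printed in the example.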
Unfortunately, the `generate` function of the TensorFlow equivalent model doesn't exactly mirror the PyTorch version, so the code above won't transfer directly.
The model is currently not compatible with the existing `zero-shot-classification` pipeline.
## Training
This model was pre-trained on the 101 languages of mC4, as described in [the mT5 paper](https://arxiv.org/abs/2010.11934). It was then fine-tuned on the [mt5_xnli_translate_train](https://github.com/google-research/multilingual-t5/blob/78d102c830d76bd68f27596a97617e2db2bfc887/multilingual_t5/tasks.py#L190) task for 8k steps, in a similar manner to that described in the [official repo](https://github.com/google-research/multilingual-t5#fine-tuning), with guidance from [Stephen Mayhew's notebook](https://github.com/mayhewsw/multilingual-t5/blob/master/notebooks/mt5-xnli.ipynb). The resulting model was then converted to Hugging Face Transformers format.
## Eval results
Accuracy (%) on the XNLI test set:
| ar | bg | de | el | en | es | fr | hi | ru | sw | th | tr | ur | vi | zh | average |
|------|------|------|------|------|------|------|------|------|------|------|------|------|------|------|------|
| 81.0 | 85.0 | 84.3 | 84.3 | 88.8 | 85.3 | 83.9 | 79.9 | 82.6 | 78.0 | 81.0 | 81.6 | 76.4 | 81.7 | 82.3 | 82.4 |
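Assuming the `average` column is the unweighted mean over the 15 languages (the card does not state how it is computed), a quick check reproduces it from the per-language figures:

```python
# XNLI test accuracies (%) copied from the table above.
accuracy = {
    "ar": 81.0, "bg": 85.0, "de": 84.3, "el": 84.3, "en": 88.8,
    "es": 85.3, "fr": 83.9, "hi": 79.9, "ru": 82.6, "sw": 78.0,
    "th": 81.0, "tr": 81.6, "ur": 76.4, "vi": 81.7, "zh": 82.3,
}

# Unweighted (macro) mean over the 15 XNLI languages.
average = sum(accuracy.values()) / len(accuracy)
print(f"{average:.1f}")  # 82.4
```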