---
language:
- en
license: mit
tags:
- text-classification
- zero-shot-classification
metrics:
- accuracy
datasets:
- multi_nli
- anli
- fever
pipeline_tag: zero-shot-classification
model-index:
- name: MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli
  results:
  - task:
      type: natural-language-inference
      name: Natural Language Inference
    dataset:
      name: anli
      type: anli
      config: plain_text
      split: test_r3
    metrics:
    - name: Accuracy
      type: accuracy
      value: 0.495
      verified: true
    - name: Precision Macro
      type: precision
      value: 0.4984740618243923
      verified: true
    - name: Precision Micro
      type: precision
      value: 0.495
      verified: true
    - name: Precision Weighted
      type: precision
      value: 0.4984357572868885
      verified: true
    - name: Recall Macro
      type: recall
      value: 0.49461028192371476
      verified: true
    - name: Recall Micro
      type: recall
      value: 0.495
      verified: true
    - name: Recall Weighted
      type: recall
      value: 0.495
      verified: true
    - name: F1 Macro
      type: f1
      value: 0.4942810999491704
      verified: true
    - name: F1 Micro
      type: f1
      value: 0.495
      verified: true
    - name: F1 Weighted
      type: f1
      value: 0.4944671868893595
      verified: true
    - name: loss
      type: loss
      value: 1.8788293600082397
      verified: true
  - task:
      type: natural-language-inference
      name: Natural Language Inference
    dataset:
      name: anli
      type: anli
      config: plain_text
      split: test_r1
    metrics:
    - name: Accuracy
      type: accuracy
      value: 0.712
      verified: true
    - name: Precision Macro
      type: precision
      value: 0.7134839439315348
      verified: true
    - name: Precision Micro
      type: precision
      value: 0.712
      verified: true
    - name: Precision Weighted
      type: precision
      value: 0.7134676028447461
      verified: true
    - name: Recall Macro
      type: recall
      value: 0.7119814425203647
      verified: true
    - name: Recall Micro
      type: recall
      value: 0.712
      verified: true
    - name: Recall Weighted
      type: recall
      value: 0.712
      verified: true
    - name: F1 Macro
      type: f1
      value: 0.7119226991285647
      verified: true
    - name: F1 Micro
      type: f1
      value: 0.712
      verified: true
    - name: F1 Weighted
      type: f1
      value: 0.7119242267218338
      verified: true
    - name: loss
      type: loss
      value: 1.0105403661727905
      verified: true
  - task:
      type: natural-language-inference
      name: Natural Language Inference
    dataset:
      name: multi_nli
      type: multi_nli
      config: default
      split: validation_mismatched
    metrics:
    - name: Accuracy
      type: accuracy
      value: 0.902766476810415
      verified: true
    - name: Precision Macro
      type: precision
      value: 0.9023816542652491
      verified: true
    - name: Precision Micro
      type: precision
      value: 0.902766476810415
      verified: true
    - name: Precision Weighted
      type: precision
      value: 0.9034597464719761
      verified: true
    - name: Recall Macro
      type: recall
      value: 0.9024304801555488
      verified: true
    - name: Recall Micro
      type: recall
      value: 0.902766476810415
      verified: true
    - name: Recall Weighted
      type: recall
      value: 0.902766476810415
      verified: true
    - name: F1 Macro
      type: f1
      value: 0.9023086094638595
      verified: true
    - name: F1 Micro
      type: f1
      value: 0.902766476810415
      verified: true
    - name: F1 Weighted
      type: f1
      value: 0.9030161011457231
      verified: true
    - name: loss
      type: loss
      value: 0.3283354640007019
      verified: true
  - task:
      type: natural-language-inference
      name: Natural Language Inference
    dataset:
      name: anli
      type: anli
      config: plain_text
      split: dev_r1
    metrics:
    - name: Accuracy
      type: accuracy
      value: 0.737
      verified: true
    - name: Precision Macro
      type: precision
      value: 0.737681071614645
      verified: true
    - name: Precision Micro
      type: precision
      value: 0.737
      verified: true
    - name: Precision Weighted
      type: precision
      value: 0.7376755842752241
      verified: true
    - name: Recall Macro
      type: recall
      value: 0.7369675064285843
      verified: true
    - name: Recall Micro
      type: recall
      value: 0.737
      verified: true
    - name: Recall Weighted
      type: recall
      value: 0.737
      verified: true
    - name: F1 Macro
      type: f1
      value: 0.7366853496239583
      verified: true
    - name: F1 Micro
      type: f1
      value: 0.737
      verified: true
    - name: F1 Weighted
      type: f1
      value: 0.7366990292378379
      verified: true
    - name: loss
      type: loss
      value: 0.9349392056465149
      verified: true
---
# DeBERTa-v3-base-mnli-fever-anli

## Model description

This model was trained on the MultiNLI, Fever-NLI and Adversarial-NLI (ANLI) datasets, which together comprise 763,913 NLI premise-hypothesis pairs. Despite being a base-size model, it outperforms almost all large models on the [ANLI benchmark](https://github.com/facebookresearch/anli).

The underlying pre-trained model is [DeBERTa-v3-base from Microsoft](https://huggingface.co/microsoft/deberta-v3-base). The v3 variant of DeBERTa substantially outperforms previous versions of the model thanks to a different pre-training objective; see annex 11 of the original [DeBERTa paper](https://arxiv.org/pdf/2006.03654.pdf).

For the highest performance (at the cost of lower speed), I recommend using https://huggingface.co/MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli.
### How to use the model

#### Simple zero-shot classification pipeline

```python
from transformers import pipeline

classifier = pipeline("zero-shot-classification", model="MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli")

sequence_to_classify = "Angela Merkel is a politician in Germany and leader of the CDU"
candidate_labels = ["politics", "economy", "entertainment", "environment"]
output = classifier(sequence_to_classify, candidate_labels, multi_label=False)
print(output)
```
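
Under the hood, the pipeline turns every candidate label into an NLI hypothesis (via the `hypothesis_template` argument, whose default is `"This example is {}."`) and scores it with the input text as premise. The sketch below illustrates, with made-up logits, how per-label NLI outputs can be aggregated into label scores; it follows our reading of the `transformers` zero-shot pipeline and is not its actual code. The label indices and all function names are assumptions for illustration; for the real model, read the indices from `model.config.label2id`.

```python
import math
from typing import Dict, List, Sequence

# Assumed label indices, matching the order used in the NLI example of this card.
ENTAILMENT_ID = 0
CONTRADICTION_ID = 2


def softmax(xs: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a list of logits."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def build_hypotheses(labels: Sequence[str], template: str = "This example is {}.") -> List[str]:
    """Turn each candidate label into an NLI hypothesis string."""
    return [template.format(label) for label in labels]


def zero_shot_scores(
    nli_logits: Sequence[Sequence[float]], labels: Sequence[str], multi_label: bool = False
) -> Dict[str, float]:
    """Aggregate one row of NLI logits per label into a score per label (hypothetical helper)."""
    if multi_label:
        # Labels are scored independently: softmax over (contradiction, entailment) only.
        scores = [softmax([row[CONTRADICTION_ID], row[ENTAILMENT_ID]])[1] for row in nli_logits]
    else:
        # Labels compete with each other: softmax over the entailment logits of all labels.
        scores = softmax([row[ENTAILMENT_ID] for row in nli_logits])
    return dict(zip(labels, scores))


# Made-up logits ([entailment, neutral, contradiction]) for two candidate labels.
labels = ["politics", "economy"]
logits = [[3.0, 0.0, -2.0], [-1.0, 0.5, 2.0]]
print(build_hypotheses(labels))
print(zero_shot_scores(logits, labels))
print(zero_shot_scores(logits, labels, multi_label=True))
```

With `multi_label=False` the scores sum to 1 across labels, which suits mutually exclusive classes; with `multi_label=True` each label gets its own probability, so several labels can score high at once.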
#### NLI use case

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model_name = "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# The model and its inputs must live on the same device.
model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)

premise = "I first thought that I liked the movie, but upon second thought it was actually disappointing."
hypothesis = "The movie was good."

# Encode premise and hypothesis as one sentence pair and move all tensors to the device.
inputs = tokenizer(premise, hypothesis, truncation=True, return_tensors="pt").to(device)
with torch.no_grad():  # inference only, no gradients needed
    output = model(**inputs)

prediction = torch.softmax(output["logits"][0], -1).tolist()
label_names = ["entailment", "neutral", "contradiction"]
prediction = {name: round(float(pred) * 100, 1) for pred, name in zip(prediction, label_names)}
print(prediction)
```
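
When scoring many pairs, it can help to factor the post-processing into a small function. A minimal sketch, assuming the label order `["entailment", "neutral", "contradiction"]` used above; the helper name `summarize_nli_logits` is hypothetical and works on plain Python lists, so it can be fed `output["logits"][0].tolist()`.

```python
import math
from typing import Dict, Sequence, Tuple

LABEL_NAMES = ("entailment", "neutral", "contradiction")  # order as in the example above


def summarize_nli_logits(
    logits: Sequence[float], label_names: Sequence[str] = LABEL_NAMES
) -> Tuple[str, Dict[str, float]]:
    """Convert raw NLI logits into rounded percentages and the most likely label."""
    if len(logits) != len(label_names):
        raise ValueError("expected exactly one logit per label")
    m = max(logits)  # subtract the max before exponentiating for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    percentages = {name: round(e / total * 100, 1) for name, e in zip(label_names, exps)}
    top_label = label_names[exps.index(max(exps))]  # decided on unrounded values
    return top_label, percentages


# Made-up logits for illustration only.
label, percentages = summarize_nli_logits([-2.0, 0.5, 3.0])
print(label, percentages)
```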
### Training data

DeBERTa-v3-base-mnli-fever-anli was trained on the MultiNLI, Fever-NLI and Adversarial-NLI (ANLI) datasets, which together comprise 763,913 NLI premise-hypothesis pairs.

### Training procedure

DeBERTa-v3-base-mnli-fever-anli was trained with the Hugging Face `Trainer` using the following hyperparameters.
```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",          # placeholder; the original output directory is not stated
    num_train_epochs=3,              # total number of training epochs
    learning_rate=2e-05,
    per_device_train_batch_size=32,  # batch size per device during training
    per_device_eval_batch_size=32,   # batch size for evaluation
    warmup_ratio=0.1,                # fraction of total training steps used for learning-rate warmup
    weight_decay=0.06,               # strength of weight decay
    fp16=True,                       # mixed precision training
)
```
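
The card does not say how many GPUs were used. Assuming a single device and no gradient accumulation (both assumptions), the hyperparameters above translate into the schedule computed below. As far as we know, the `Trainer`'s default scheduler (`lr_scheduler_type="linear"`) rounds the warmup steps up, ramps the learning rate linearly from 0 to its peak and then decays it linearly to 0; the sketch reproduces that arithmetic in plain Python and is not the library's own code.

```python
import math
from typing import Tuple

# Values from this card; single device and no gradient accumulation are assumptions.
NUM_EXAMPLES = 763_913
BATCH_SIZE = 32
EPOCHS = 3
WARMUP_RATIO = 0.1
PEAK_LR = 2e-05


def schedule_lengths(num_examples: int, batch_size: int, epochs: int, warmup_ratio: float) -> Tuple[int, int, int]:
    """Return (steps per epoch, total optimizer steps, warmup steps)."""
    steps_per_epoch = math.ceil(num_examples / batch_size)  # the last partial batch is kept
    total_steps = steps_per_epoch * epochs
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return steps_per_epoch, total_steps, warmup_steps


def linear_lr(step: int, total_steps: int, warmup_steps: int, peak_lr: float) -> float:
    """Linear warmup from 0 to peak_lr, then linear decay back to 0."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    return peak_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


steps_per_epoch, total_steps, warmup_steps = schedule_lengths(NUM_EXAMPLES, BATCH_SIZE, EPOCHS, WARMUP_RATIO)
print(steps_per_epoch, total_steps, warmup_steps)  # 23873 71619 7162
```

On more than one GPU the effective batch size grows proportionally, and the step counts shrink accordingly.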
### Eval results

The model was evaluated using the test sets for MultiNLI and ANLI and the dev set for Fever-NLI. The metric used is accuracy.

| mnli-m | mnli-mm | fever-nli | anli-all | anli-r3 |
|--------|---------|-----------|----------|---------|
| 0.903  | 0.903   | 0.777     | 0.579    | 0.495   |
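
In the metadata block at the top of this card, micro-averaged precision, recall and F1 are identical to accuracy on every split (e.g. 0.495 on ANLI `test_r3`). This is expected for single-label multi-class classification: every wrong prediction counts once as a false positive (for the predicted class) and once as a false negative (for the true class), so pooled precision and pooled recall both reduce to accuracy. Macro averages weight each class equally and can differ. The toy labels below are hypothetical and only demonstrate the relationship.

```python
from typing import Sequence, Tuple


def confusion_counts(y_true: Sequence[str], y_pred: Sequence[str], label: str) -> Tuple[int, int, int]:
    """True positives, false positives and false negatives for one class."""
    tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
    fp = sum(t != label and p == label for t, p in zip(y_true, y_pred))
    fn = sum(t == label and p != label for t, p in zip(y_true, y_pred))
    return tp, fp, fn


def f1_from_counts(tp: int, fp: int, fn: int) -> float:
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return 2 * precision * recall / (precision + recall) if precision + recall else 0.0


def accuracy(y_true: Sequence[str], y_pred: Sequence[str]) -> float:
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def macro_f1(y_true: Sequence[str], y_pred: Sequence[str], labels: Sequence[str]) -> float:
    """Unweighted mean of the per-class F1 scores."""
    return sum(f1_from_counts(*confusion_counts(y_true, y_pred, lab)) for lab in labels) / len(labels)


def micro_f1(y_true: Sequence[str], y_pred: Sequence[str], labels: Sequence[str]) -> float:
    """F1 computed from counts pooled over all classes."""
    counts = [confusion_counts(y_true, y_pred, lab) for lab in labels]
    tp, fp, fn = (sum(c[i] for c in counts) for i in range(3))
    return f1_from_counts(tp, fp, fn)


# Hypothetical gold labels and predictions.
labels = ["entailment", "neutral", "contradiction"]
y_true = ["entailment", "neutral", "contradiction", "entailment", "neutral", "contradiction"]
y_pred = ["entailment", "contradiction", "contradiction", "neutral", "neutral", "contradiction"]
print(accuracy(y_true, y_pred), micro_f1(y_true, y_pred, labels), macro_f1(y_true, y_pred, labels))
```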
## Limitations and bias

Please consult the original DeBERTa paper and the literature on the different NLI datasets for potential biases.

## Citation

If you use this model, please cite: Laurer, Moritz, Wouter van Atteveldt, Andreu Salleras Casas, and Kasper Welbers. 2022. ‘Less Annotating, More Classifying – Addressing the Data Scarcity Issue of Supervised Machine Learning with Deep Transfer Learning and BERT-NLI’. Preprint, June. Open Science Framework. https://osf.io/74b8k.

### Ideas for cooperation or questions?

If you have questions or ideas for cooperation, contact me at m{dot}laurer{at}vu{dot}nl or on [LinkedIn](https://www.linkedin.com/in/moritz-laurer/).
### Debugging and issues

Note that DeBERTa-v3 was released on 6 December 2021, and older versions of HF Transformers seem to have issues running the model (e.g. problems with the tokenizer). Using Transformers>=4.13 might solve some of these issues.
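
To check the installed version before loading the model, a small helper like the one below can be used. It is a sketch of our own (the function names are hypothetical) that compares only the major and minor version numbers; `packaging.version` would be a more complete alternative.

```python
import re
from typing import Tuple

MINIMUM = (4, 13)  # minimum Transformers version suggested in this card


def parse_major_minor(version: str) -> Tuple[int, int]:
    """Extract (major, minor) from strings such as '4.13.0' or '4.13.0.dev0'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def transformers_is_recent_enough(version: str, minimum: Tuple[int, int] = MINIMUM) -> bool:
    """True if the given version is at least the minimum (tuple comparison)."""
    return parse_major_minor(version) >= minimum
```

Usage: `import transformers` and then call `transformers_is_recent_enough(transformers.__version__)`; if it returns `False`, upgrade with `pip install -U transformers`.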