Collection: ENLI (Explainable/Interpretable Natural Language Inference)
This is a fine-tuned version of backpack-gpt2 with an NLI classification head, trained on the e-SNLI (`esnli`) dataset. Example usage:
```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# The model uses the GPT-2 tokenizer. GPT-2 has no padding token, so EOS is reused for padding.
tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

def tokenize_function(examples):
    # Join each pair as "<premise>. ^ <hypothesis>." (leading/trailing periods are stripped first).
    concatenated_sentences = [
        f'{premise.strip(".")}. ^ {hypothesis.strip(".")}.'
        for premise, hypothesis in zip(examples['premise'], examples['hypothesis'])
    ]
    tokenized_inputs = tokenizer(
        concatenated_sentences,
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    return tokenized_inputs

# trust_remote_code=True is needed because the model class is defined in custom code in the model repository.
model = AutoModelForSequenceClassification.from_pretrained('ErfanMoosaviMonazzah/backpack-gpt2-nli', trust_remote_code=True)
model.eval()

tokenized_sent = tokenize_function({
    'premise': ['A boy is jumping on skateboard in the middle of a red bridge.',
                'Two women who just had lunch hugging and saying goodbye.',
                'Children smiling and waving at camera'],
    'hypothesis': ['The boy does a skateboarding trick.',
                   'The friends have just met for the first time in 20 years, and have had a great time catching up.',
                   'The kids are frowning'],
})

# predict() is not part of the standard transformers model API; it is provided by the repository's custom code.
with torch.no_grad():
    predictions = model.predict(input_ids=tokenized_sent['input_ids'], attention_mask=tokenized_sent['attention_mask'])
print(predictions)
```
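The input format and a way to read off a label from class scores can be sketched without loading the model. This is a minimal sketch: `format_pair` and `label_from_scores` are hypothetical helpers (not part of the model repository), and the label order is an assumption borrowed from the e-SNLI dataset's `entailment`/`neutral`/`contradiction` ordering. The model's own mapping should be checked, for example in `model.config.id2label`.

```python
# Hypothetical helpers; they mirror the card's input format but are not part of the model repo.
# ASSUMPTION: label order follows e-SNLI (0=entailment, 1=neutral, 2=contradiction);
# verify against model.config.id2label before relying on it.
ESNLI_LABELS = ["entailment", "neutral", "contradiction"]


def format_pair(premise: str, hypothesis: str) -> str:
    """Join a premise and hypothesis the same way tokenize_function does."""
    return f'{premise.strip(".")}. ^ {hypothesis.strip(".")}.'


def label_from_scores(scores):
    """Return the label whose score (logit or probability) is highest."""
    best = max(range(len(scores)), key=lambda i: scores[i])
    return ESNLI_LABELS[best]


if __name__ == "__main__":
    print(format_pair("Children smiling and waving at camera", "The kids are frowning"))
    # Made-up scores, for illustration only (not real model output).
    print(label_from_scores([-1.2, 0.3, 2.5]))
```

Because `strip(".")` removes periods at both ends, a premise with or without a final period produces the same joined string.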
Training results (metrics logged every 512 steps):
Step | Training Loss | Validation Loss | Precision | Recall | F1 | Accuracy |
---|---|---|---|---|---|---|
512 | 0.614900 | 0.463713 | 0.826792 | 0.824639 | 0.825133 | 0.824731 |
1024 | 0.503300 | 0.431796 | 0.844831 | 0.839414 | 0.839980 | 0.839565 |
1536 | 0.475600 | 0.400771 | 0.848741 | 0.847009 | 0.846287 | 0.847795 |
2048 | 0.455900 | 0.375981 | 0.859064 | 0.857357 | 0.857749 | 0.857448 |
2560 | 0.440400 | 0.365537 | 0.862000 | 0.862078 | 0.861917 | 0.862426 |
3072 | 0.433100 | 0.365180 | 0.864717 | 0.859693 | 0.860237 | 0.859785 |
3584 | 0.425100 | 0.346340 | 0.872312 | 0.870635 | 0.870865 | 0.870961 |
4096 | 0.413300 | 0.343761 | 0.873606 | 0.873046 | 0.873174 | 0.873298 |
4608 | 0.412000 | 0.344890 | 0.882609 | 0.882120 | 0.882255 | 0.882341 |
5120 | 0.402600 | 0.336744 | 0.876463 | 0.875629 | 0.875827 | 0.875737 |
5632 | 0.390600 | 0.323248 | 0.882598 | 0.880779 | 0.881129 | 0.880817 |
6144 | 0.388300 | 0.338029 | 0.877255 | 0.877041 | 0.877126 | 0.877261 |
6656 | 0.390800 | 0.333301 | 0.876357 | 0.876362 | 0.875965 | 0.876753 |
7168 | 0.383800 | 0.328297 | 0.883593 | 0.883675 | 0.883629 | 0.883967 |
7680 | 0.380800 | 0.331854 | 0.882362 | 0.880373 | 0.880764 | 0.880512 |
8192 | 0.368400 | 0.323076 | 0.881730 | 0.881378 | 0.881419 | 0.881528 |
8704 | 0.367000 | 0.313959 | 0.889204 | 0.889047 | 0.889053 | 0.889352 |
9216 | 0.315600 | 0.333637 | 0.885518 | 0.883965 | 0.884266 | 0.883967 |
9728 | 0.303100 | 0.319416 | 0.888667 | 0.888092 | 0.888256 | 0.888234 |
10240 | 0.307200 | 0.317827 | 0.887575 | 0.887647 | 0.887418 | 0.888031 |
10752 | 0.300100 | 0.311810 | 0.890908 | 0.890827 | 0.890747 | 0.891181 |
11264 | 0.303400 | 0.311010 | 0.889871 | 0.887939 | 0.888309 | 0.887929 |
11776 | 0.300500 | 0.309282 | 0.891041 | 0.889819 | 0.890077 | 0.889860 |
12288 | 0.303600 | 0.326918 | 0.891272 | 0.891250 | 0.890942 | 0.891689 |
12800 | 0.300300 | 0.301688 | 0.894516 | 0.894619 | 0.894481 | 0.894940 |
13312 | 0.302200 | 0.302173 | 0.896441 | 0.896527 | 0.896462 | 0.896769 |
13824 | 0.299800 | 0.293489 | 0.895047 | 0.895172 | 0.895084 | 0.895448 |
14336 | 0.294600 | 0.297645 | 0.895865 | 0.896012 | 0.895886 | 0.896261 |
14848 | 0.296700 | 0.300751 | 0.895277 | 0.895401 | 0.895304 | 0.895651 |
15360 | 0.293100 | 0.293049 | 0.896855 | 0.896705 | 0.896757 | 0.896871 |
15872 | 0.293600 | 0.294201 | 0.895933 | 0.895557 | 0.895624 | 0.895651 |
16384 | 0.290100 | 0.289367 | 0.897847 | 0.897889 | 0.897840 | 0.898090 |
16896 | 0.293600 | 0.283990 | 0.898889 | 0.898724 | 0.898789 | 0.898903 |
17408 | 0.285800 | 0.308257 | 0.898250 | 0.898102 | 0.898162 | 0.898293 |
17920 | 0.252400 | 0.327164 | 0.898860 | 0.898807 | 0.898831 | 0.899004 |
18432 | 0.219500 | 0.315286 | 0.898877 | 0.898835 | 0.898831 | 0.899004 |
18944 | 0.217900 | 0.312738 | 0.898857 | 0.898958 | 0.898886 | 0.899207 |
19456 | 0.186400 | 0.320669 | 0.899252 | 0.899166 | 0.899194 | 0.899411 |
19968 | 0.199000 | 0.316840 | 0.900458 | 0.900455 | 0.900426 | 0.900630 |
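When reading the table, one might compare the step with the lowest validation loss against the step with the highest accuracy. Below is a minimal sketch that parses rows in the table's pipe-separated format; `parse_rows` and `best_step` are hypothetical helpers, and the embedded rows are an excerpt copied from the table above.

```python
# Excerpt of rows copied from the training-results table (same column order).
TABLE = """\
16384 | 0.290100 | 0.289367 | 0.897847 | 0.897889 | 0.897840 | 0.898090 |
16896 | 0.293600 | 0.283990 | 0.898889 | 0.898724 | 0.898789 | 0.898903 |
17408 | 0.285800 | 0.308257 | 0.898250 | 0.898102 | 0.898162 | 0.898293 |
19456 | 0.186400 | 0.320669 | 0.899252 | 0.899166 | 0.899194 | 0.899411 |
19968 | 0.199000 | 0.316840 | 0.900458 | 0.900455 | 0.900426 | 0.900630 |
"""

COLUMNS = ["step", "train_loss", "val_loss", "precision", "recall", "f1", "accuracy"]


def parse_rows(text):
    """Turn pipe-separated table rows into a list of dicts keyed by COLUMNS."""
    rows = []
    for line in text.strip().splitlines():
        # Drop the empty cell produced by the trailing "|".
        cells = [c.strip() for c in line.split("|") if c.strip()]
        values = [int(cells[0])] + [float(c) for c in cells[1:]]
        rows.append(dict(zip(COLUMNS, values)))
    return rows


def best_step(rows, metric, mode):
    """Return the step with the minimum (mode="min") or maximum (mode="max") value of metric."""
    pick = min if mode == "min" else max
    return pick(rows, key=lambda r: r[metric])["step"]


rows = parse_rows(TABLE)
print(best_step(rows, "val_loss", "min"))  # 16896
print(best_step(rows, "accuracy", "max"))  # 19968
```

On the full table, step 16896 has the lowest validation loss (0.283990), while accuracy peaks at the last logged step, 19968 (0.900630). After step 17408, training loss keeps falling (to about 0.19) while validation loss stays above 0.31, which may indicate mild overfitting in the final steps.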