Commit 7de8833
Parent(s): d0c106c

Update README.md

README.md CHANGED
````diff
@@ -42,16 +42,21 @@ As of December 2021, mDeBERTa-base is the best performing multilingual transform
 ```python
 from transformers import AutoTokenizer, AutoModelForSequenceClassification
 import torch
+
 model_name = "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForSequenceClassification.from_pretrained(model_name)
+
 premise = "Angela Merkel ist eine Politikerin in Deutschland und Vorsitzende der CDU"
 hypothesis = "Emmanuel Macron is the President of France"
+
 input = tokenizer(premise, hypothesis, truncation=True, return_tensors="pt")
 output = model(input["input_ids"].to(device)) # device = "cuda:0" or "cpu"
 prediction = torch.softmax(output["logits"][0], -1).tolist()
+
 label_names = ["entailment", "neutral", "contradiction"]
 prediction = {name: round(float(pred) * 100, 1) for pred, name in zip(prediction, label_names)}
+
 print(prediction)
 ```
````
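For anyone who wants to run the committed snippet directly: it leaves `device` undefined (it only appears in a comment). The sketch below keeps the model name, tokenizer calls, and labels exactly as in the diff above, and only adds the standard PyTorch pattern of picking `"cuda:0"` when a GPU is available and moving the model and inputs there; that device handling is an assumption on my part, not part of the commit.

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Assumption: define the device that the committed snippet only mentions in a comment.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

model_name = "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)

# Cross-lingual NLI pair: German premise ("Angela Merkel is a politician in
# Germany and leader of the CDU"), English hypothesis.
premise = "Angela Merkel ist eine Politikerin in Deutschland und Vorsitzende der CDU"
hypothesis = "Emmanuel Macron is the President of France"

inputs = tokenizer(premise, hypothesis, truncation=True, return_tensors="pt")
output = model(inputs["input_ids"].to(device))

# Convert the three NLI logits into percentage scores per label.
prediction = torch.softmax(output["logits"][0], -1).tolist()
label_names = ["entailment", "neutral", "contradiction"]
prediction = {name: round(float(pred) * 100, 1) for pred, name in zip(prediction, label_names)}
print(prediction)
```

The output is a dictionary of percentage scores per label; for this unrelated premise/hypothesis pair, `neutral` would be expected to receive the highest score.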