alisawuffles committed
Commit 34b857f
Parent: 5328f6b

Update README.md

Files changed (1): README.md (+3, -3)
README.md CHANGED
@@ -1,7 +1,7 @@
  ---
  widget:
- - text: "I almost forgot to eat lunch. </s></s> I didn't forget to eat lunch."
- - text: "I believe I will get into UW. </s></s> I will get into UW."
+ - text: "I almost forgot to eat lunch.</s></s>I didn't forget to eat lunch."
+ - text: "I almost forgot to eat lunch.</s></s>I forgot to eat lunch."
  ---
 
  This is an off-the-shelf roberta-large model finetuned on WANLI, the Worker-AI Collaborative NLI dataset ([Liu et al., 2022](https://arxiv.org/abs/2201.05955)). It outperforms the `roberta-large-mnli` model on seven out-of-domain test sets, including by 11% on HANS and 9% on Adversarial NLI.
@@ -13,7 +13,7 @@ from transformers import RobertaTokenizer, RobertaForSequenceClassification
  model = RobertaForSequenceClassification.from_pretrained('alisawuffles/roberta-large-wanli')
  tokenizer = RobertaTokenizer.from_pretrained('alisawuffles/roberta-large-wanli')
 
- x = tokenizer("I believe I will get into UW.", "I will get into UW.", return_tensors='pt', max_length=128, truncation=True)
+ x = tokenizer("I almost forgot to eat lunch.", "I didn't forget to eat lunch.", return_tensors='pt', max_length=128, truncation=True)
  logits = model(**x).logits
  probs = logits.softmax(dim=1).squeeze(0)
  label_id = torch.argmax(probs).item()
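
The widget change removes the spaces around the `</s></s>` separator, presumably so that the widget string tokenizes exactly like a true premise/hypothesis pair: RoBERTa's BPE folds a leading space into the token that follows it, so spaces around the separator yield different input ids than pair encoding does. A minimal sketch of that check, assuming the default `RobertaTokenizer`:

```python
from transformers import RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained('alisawuffles/roberta-large-wanli')

premise = "I almost forgot to eat lunch."
hypothesis = "I didn't forget to eat lunch."

# Pair encoding produces <s> premise </s></s> hypothesis </s>
pair_ids = tokenizer(premise, hypothesis)['input_ids']

# Widget-style single string with the literal separator and no surrounding spaces
widget_ids = tokenizer(f"{premise}</s></s>{hypothesis}")['input_ids']

print(pair_ids == widget_ids)  # expected True; the spaced variant adds stray space tokens
```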
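
For reference, here is the post-commit snippet completed into a self-contained script: the imports follow the `@@ -13,7 +13,7 @@` hunk header above, `import torch` is added because the snippet calls `torch.argmax`, and the final `id2label` lookup (a standard `transformers` config field, not shown in this diff) turns the class index into a label name:

```python
import torch
from transformers import RobertaTokenizer, RobertaForSequenceClassification

model = RobertaForSequenceClassification.from_pretrained('alisawuffles/roberta-large-wanli')
tokenizer = RobertaTokenizer.from_pretrained('alisawuffles/roberta-large-wanli')

# Encode the premise/hypothesis pair and classify it
x = tokenizer("I almost forgot to eat lunch.", "I didn't forget to eat lunch.",
              return_tensors='pt', max_length=128, truncation=True)
logits = model(**x).logits
probs = logits.softmax(dim=1).squeeze(0)
label_id = torch.argmax(probs).item()

# Map the predicted class index to its name (e.g. entailment)
print(model.config.id2label[label_id])
```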