jessicayjm
commited on
Commit
•
6334b7a
1
Parent(s):
1b7fa3a
Add model usage code
Browse files
README.md
CHANGED
@@ -14,7 +14,7 @@ The model classifies an *appraisal* given a sentence and is trained on [ALOE](ht
|
|
14 |
|
15 |
**Output:** logits (in order of labels)
|
16 |
|
17 |
-
**Model architecture**:
|
18 |
|
19 |
**Developed by:** Jiamin Yang
|
20 |
|
@@ -48,8 +48,12 @@ from openprompt.plms import load_plm
|
|
48 |
from openprompt.prompts import ManualTemplate
|
49 |
from openprompt.prompts import ManualVerbalizer
|
50 |
from openprompt import PromptForClassification
|
|
|
|
|
51 |
|
52 |
-
|
|
|
|
|
53 |
|
54 |
plm, tokenizer, model_config, WrapperClass = load_plm('roberta', 'roberta-large')
|
55 |
template_text = 'The sentence {"placeholder":"text_a"} has the label {"mask"}.'
|
@@ -67,5 +71,34 @@ state_dict = checkpoint['model_state_dict']
|
|
67 |
del state_dict['prompt_model.plm.roberta.embeddings.position_ids']
|
68 |
|
69 |
prompt_model.load_state_dict(state_dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
```
|
71 |
|
|
|
14 |
|
15 |
**Output:** logits (in order of labels)
|
16 |
|
17 |
+
**Model architecture**: OpenPrompt_+RoBERTa
|
18 |
|
19 |
**Developed by:** Jiamin Yang
|
20 |
|
|
|
48 |
from openprompt.prompts import ManualTemplate
|
49 |
from openprompt.prompts import ManualVerbalizer
|
50 |
from openprompt import PromptForClassification
|
51 |
+
from openprompt.data_utils import InputExample
|
52 |
+
from openprompt import PromptDataLoader
|
53 |
|
54 |
+
|
55 |
+
torch.cuda.set_device(1)
|
56 |
+
checkpoint_file = 'upload_version/empathy-appraisal-span.pt'
|
57 |
|
58 |
plm, tokenizer, model_config, WrapperClass = load_plm('roberta', 'roberta-large')
|
59 |
template_text = 'The sentence {"placeholder":"text_a"} has the label {"mask"}.'
|
|
|
71 |
del state_dict['prompt_model.plm.roberta.embeddings.position_ids']
|
72 |
|
73 |
prompt_model.load_state_dict(state_dict)
|
74 |
+
|
75 |
+
# use the model
|
76 |
+
dataset = [
|
77 |
+
InputExample(
|
78 |
+
guid = 0,
|
79 |
+
text_a = "I am sorry for your loss",
|
80 |
+
),
|
81 |
+
InputExample(
|
82 |
+
guid = 1,
|
83 |
+
text_a = "It's not your fault",
|
84 |
+
),
|
85 |
+
]
|
86 |
+
|
87 |
+
data_loader = PromptDataLoader(dataset=dataset,
|
88 |
+
template=template,
|
89 |
+
tokenizer=tokenizer,
|
90 |
+
tokenizer_wrapper_class=WrapperClass,
|
91 |
+
max_seq_length=512,
|
92 |
+
batch_size=2,
|
93 |
+
shuffle=False,
|
94 |
+
teacher_forcing=False,
|
95 |
+
predict_eos_token=False,
|
96 |
+
truncate_method='head')
|
97 |
+
prompt_model.eval()
|
98 |
+
with torch.no_grad():
|
99 |
+
for batch in data_loader:
|
100 |
+
logits = prompt_model(batch.to('cuda'))
|
101 |
+
preds = torch.argmax(logits, dim = -1)
|
102 |
+
print(preds) #[8, 5]
|
103 |
```
|
104 |
|