pminervini committed
Commit 10f3d68
1 Parent(s): 06827ad
update
Browse files
- cli/halueval-cli.py +1 -1
- requirements.txt +1 -0
- src/backend/tasks/xsum/task.py +15 -1
cli/halueval-cli.py
CHANGED
@@ -46,7 +46,7 @@ def main():
 
     for task in TASKS_HARNESS:
         print(f"Selected Tasks: [{task}]")
-        results = evaluator.simple_evaluate(model="hf", model_args=eval_request.get_model_args(), tasks=[task.benchmark], num_fewshot=
+        results = evaluator.simple_evaluate(model="hf", model_args=eval_request.get_model_args(), tasks=[task.benchmark], num_fewshot=4,
                                             batch_size=1, device="mps", use_cache=None, limit=10, write_out=True)
         print('AAA', results["results"])
 
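For reproducing this call outside the CLI, here is a minimal sketch of the same lm-eval-harness invocation, assuming a locally installed lm_eval and placeholder model and task names (the CLI instead builds model_args from eval_request and iterates over TASKS_HARNESS):

# Minimal sketch of the evaluator call the CLI now makes (num_fewshot=4).
# "pretrained=gpt2", "halueval_qa", and device="cpu" are placeholder assumptions.
from lm_eval import evaluator

results = evaluator.simple_evaluate(
    model="hf",
    model_args="pretrained=gpt2",
    tasks=["halueval_qa"],   # assumed task name; the CLI uses task.benchmark
    num_fewshot=4,           # 4 in-context examples are prepended per test instance
    batch_size=1,
    device="cpu",
    use_cache=None,
    limit=10,                # evaluate only the first 10 documents
    write_out=True,
)
print(results["results"])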
requirements.txt
CHANGED
@@ -25,3 +25,4 @@ sacrebleu
 cchardet
 rouge_score
 bert-score
+evaluate
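The new evaluate dependency is what the XSum task below uses to load the BERTScore metric. A minimal sketch of that usage, with placeholder strings in place of benchmark data:

# Sketch of the BERTScore usage enabled by the new "evaluate" dependency.
# The two strings are placeholder examples, not data from the benchmark.
from evaluate import load

bertscore = load("bertscore")
res = bertscore.compute(
    predictions=["The cat sat on the mat."],       # model-generated summary
    references=["A cat was sitting on the mat."],  # gold reference summary
    lang="en",
)
print(res["precision"][0], res["recall"][0], res["f1"][0])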
src/backend/tasks/xsum/task.py
CHANGED
@@ -61,6 +61,7 @@ class XSum(Task):
         super().__init__(data_dir=data_dir, cache_dir=cache_dir, download_mode=download_mode, config=config)
         self.factkb_tokenizer = None
         self.factkb_model = None
+        self.bert_score = None
 
     def maybe_init_factkb(self):
         if self.factkb_tokenizer is None or self.factkb_model is None:
@@ -68,6 +69,11 @@ class XSum(Task):
             self.factkb_tokenizer = AutoTokenizer.from_pretrained("roberta-base", padding="max_length", truncation=True)
             self.factkb_model = AutoModelForSequenceClassification.from_pretrained("bunsenfeng/FactKB", num_labels=2, device_map="auto")
 
+    def maybe_init_bertscore(self):
+        if self.bert_score is None:
+            from evaluate import load
+            self.bert_score = load("bertscore")
+
     def has_training_docs(self):
         return True
 
@@ -126,6 +132,8 @@ class XSum(Task):
         completion = results[0]
 
         document = doc["document"]
+        gold_summary = doc["summary"]
+
         true_refs = [doc["summary"]]
         all_refs = true_refs
 
@@ -144,11 +152,17 @@ class XSum(Task):
         factkb_logits = self.factkb_model(**factkb_tokens).logits
         factkb_res = torch.softmax(factkb_logits, dim=1)
 
+        self.maybe_init_factkb()
+        bert_score_res = self.bert_score.compute(predictions=[completion], references=[gold_summary], lang="en")
+
         res = {
             "rouge1": rouge1_scores[0],
             "rouge2": rouge2_scores[0],
             "rougeL": rougeL_scores[0],
-            "factKB": float(factkb_res[0][1])
+            "factKB": float(factkb_res[0][1]),
+            "bertscore_precision": float(bert_score_res["precision"][0]),
+            "bertscore_recall": float(bert_score_res["recall"][0]),
+            "bertscore_f1": float(bert_score_res["f1"][0]),
         }
 
         # breakpoint()
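Put together, process_results now reports ROUGE, FactKB, and BERTScore for each generated summary. Below is a self-contained sketch of the FactKB and BERTScore portion; how factkb_tokens is built is not visible in these hunks, so tokenizing the pair as (summary, document) is an assumption for illustration, and the input strings are placeholders:

# Hedged sketch of the FactKB + BERTScore scoring seen in XSum.process_results.
import torch
from evaluate import load
from transformers import AutoModelForSequenceClassification, AutoTokenizer

completion = "A generated summary of the article."   # placeholder model output
document = "The full source article text ..."        # placeholder input document
gold_summary = "The reference summary."              # placeholder gold summary

# FactKB: probability that the summary is factually consistent with the document.
# Passing the pair as (summary, document) is an assumption; the diff does not show it.
factkb_tokenizer = AutoTokenizer.from_pretrained("roberta-base", padding="max_length", truncation=True)
factkb_model = AutoModelForSequenceClassification.from_pretrained("bunsenfeng/FactKB", num_labels=2)
factkb_tokens = factkb_tokenizer(completion, document, return_tensors="pt", truncation=True)
factkb_res = torch.softmax(factkb_model(**factkb_tokens).logits, dim=1)

# BERTScore: precision/recall/F1 between the completion and the gold summary.
bert_score = load("bertscore")
bert_score_res = bert_score.compute(predictions=[completion], references=[gold_summary], lang="en")

res = {
    "factKB": float(factkb_res[0][1]),
    "bertscore_precision": float(bert_score_res["precision"][0]),
    "bertscore_recall": float(bert_score_res["recall"][0]),
    "bertscore_f1": float(bert_score_res["f1"][0]),
}
print(res)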