facat committed
Commit be1543a
1 parent: 044ed98

add mmlu and cmmlu

Files changed (3)
  1. tasks.py +364 -49
  2. tlem.py +48 -48
  3. utils.py +37 -9
tasks.py CHANGED
@@ -3,19 +3,14 @@ from datasets import load_dataset, Dataset
 from functools import cached_property
 from tqdm.auto import tqdm
 from typing import Any, Optional, Protocol, Iterable, Callable
+import logging
+import pandas as pd
+from functools import partial
 
-from .utils import (
-    NUMERIC_IN_ZH,
-    extract_choice_ans,
-    extract_numeric,
-    get_answer,
-    is_equiv,
-)
+from .utils import *
 
 from evaluate import load
 
-TextGenerationPipeline = Callable[[Iterable[str]], list[str]]
-
 
 def fake_pipeline(prompts: Iterable[str]) -> list[str]:
     return [prompt for prompt in tqdm(prompts)]
@@ -30,14 +25,25 @@ class Task:
     input_column: str = "question"
     label_column: str = "answer"
     prompt: Optional[Callable | str] = None
+    few_shot: int = 0
+    few_shot_from: Optional[str] = None
+    # results: dict[str, Any] = field(default_factory=dict)
 
-    @cached_property
-    def name(self):
-        return (
-            self.dataset_name
+    def __post_init__(self):
+        names = (
+            [self.dataset_name]
             if isinstance(self.dataset_name, str)
-            else self.dataset_name[0]
-        ) + f"-{self.split}"
+            else list(self.dataset_name)
+        )
+        names[0] = names[0].split("/")[-1]
+
+        self.name = "-".join(names) + f"-{self.split}"
+        if isinstance(self.prompt, str):
+            self.prompt = lambda example: {
+                self.input_column: self.prompt.format(
+                    input_column=example[self.input_column]
+                )
+            }
 
     @cached_property
     def samples(self):
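Note on the hunk above: the lazy `name` property becomes eager naming in `__post_init__`, and multi-config datasets now keep their config in the task name. A minimal sketch of the resulting scheme (toy inputs, not part of the commit):

    # Sketch of the new Task naming, assuming split="test".
    def derive_name(dataset_name, split="test"):
        names = [dataset_name] if isinstance(dataset_name, str) else list(dataset_name)
        names[0] = names[0].split("/")[-1]  # strip the hub namespace, e.g. "haonan-li/"
        return "-".join(names) + f"-{split}"

    assert derive_name("gsm8k") == "gsm8k-test"
    assert derive_name(("haonan-li/cmmlu", "chinese_history")) == "cmmlu-chinese_history-test"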
@@ -49,20 +55,38 @@ class Task:
             *self.dataset_name
             if isinstance(self.dataset_name, tuple)
             else self.dataset_name,
-            split=self.split,
+            # split=self.split,
         )
+        test_ds = ds[self.split]
         if self.prompt is not None:
-            ds = ds.map(
+            test_ds = test_ds.map(self.prompt)
+
+        if self.few_shot:
+            if self.few_shot_from is None:
+                for name in ["train", "validation", "val", "dev"]:
+                    if name in ds:
+                        self.few_shot_from = name
+                        break
+
+            shots = ds[self.few_shot_from].select(range(self.few_shot))
+            if self.prompt is not None:
+                shots = shots.map(self.prompt)
+
+            shots = shots.map(
                 lambda example: {
-                    self.input_column: self.prompt.format(
-                        input_column=example[self.input_column]
-                    )
+                    self.input_column: example[self.input_column]
+                    + example[self.label_column],
+                }
+            )[self.input_column]
+            few_shot_prompts = "\n".join(shots)
+
+            test_ds = test_ds.map(
+                lambda example: {
+                    self.input_column: few_shot_prompts + example[self.input_column],
                 }
-                if isinstance(self.prompt, str)
-                else self.prompt(example),
             )
 
-        return ds
+        return test_ds
 
     @cached_property
     def metric(self):
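The new few-shot branch prepends `few_shot` solved examples, taken from the first of train/validation/val/dev that exists, to every test prompt. A toy illustration of the string it assembles (made-up rows, same concatenation as the code above):

    # Each shot is input_column + label_column; shots are joined with "\n"
    # and the whole block is prepended, unmodified, to each test prompt.
    shots = [
        "Question: 1+1=?\nAnswer: 2",
        "Question: 2+3=?\nAnswer: 5",
    ]
    few_shot_prompts = "\n".join(shots)
    test_prompt = few_shot_prompts + "Question: 4+4=?\nAnswer:"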
@@ -73,14 +97,44 @@ class Task:
         )
         return metric
 
-    def run(self, pipeline: TextGenerationPipeline = fake_pipeline):
-        outputs = pipeline(self.samples)
-        return self.metric.compute(
-            responses=outputs, references=self.dataset[self.label_column]
-        )
+    def run(
+        self,
+        pipeline,
+    ):
+        if (outputs := pipeline(self.samples)) is None:
+            logging.warning("pipeline returns None")
+            return
+        self.outputs = outputs
+        try:
+            result = self.metric._compute(
+                responses=outputs, references=self.dataset[self.label_column]
+            )
+        except Exception as e:
+            result = self.metric.compute(
+                responses=outputs, references=self.dataset[self.label_column]
+            )
+        # if log:
+        #     name = name or pipeline.__name__
+        #     self.results[name] = result
+
+        return result
+
+
+def multichoice(responses: Any, references: list[str]):
+    if isinstance(responses[0], str):
+        responses = [extract_choice(response) for response in responses]
+    else:
+        responses = decode_choice(responses)
+
+    return [
+        int(response == reference) for reference, response in zip(references, responses)
+    ]
 
 
 class Metrics:
+    cmmlu = multichoice
+    mmlu = multichoice
+
     def gsm8k(responses: list[str], answers: list[str | int]):
         scores = []
         for response, answer in zip(responses, answers):
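The new module-level `multichoice` scores either raw generations (via `extract_choice` from utils.py) or per-choice score arrays (via `decode_choice`). A runnable sketch of the score-based path, with made-up logits:

    import numpy as np

    # One array of per-choice scores per example, shape (num_choices,).
    responses = [np.array([0.1, 0.7, 0.1, 0.1]), np.array([0.9, 0.0, 0.05, 0.05])]
    choices = np.argmax(np.asarray(responses), axis=1)  # -> array([1, 0])
    letters = np.array(list("ABCD"))[choices]           # -> array(['B', 'A'])
    scores = [int(pred == ref) for pred, ref in zip(letters, ["B", "A"])]
    assert scores == [1, 1]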
@@ -112,26 +166,287 @@ class Metrics:
             scores.append(1.0 * (pred == gold))
         return scores
 
-    def gsm8k_zh(responses: list[str], answers: list[str]):
-        scores = []
-        for response, answer in zip(responses, answers):
-            pred = extract_numeric(response, pattern=NUMERIC_IN_ZH)
-            gold = extract_numeric(answer)
-            scores.append(1.0 * (pred == gold))
-        return scores
-
-    def svamp(responses: list[float], answers: list[str]):
-        scores = []
-        for response, answer in zip(responses, answers):
-            pred = extract_numeric(response, pattern=NUMERIC_IN_ZH)
-            gold = answer
-            scores.append(1.0 * (float(pred) == gold))
-        return scores
-
-    def mmlu(responses, answers):
-        scores = []
-        for response, answer in zip(responses, answers):
-            pred = extract_choice_ans(response)
-            gold = answer.lower()
-            scores.append(1.0 * (pred == gold))
-        return scores
+
+class CMMLU:
+    def prompt_cmmlu(example, chat=False):
+        prefix = "以下是一道多项选择题,请从A、B、C和D中选择最合适的答案作为这个问题的答案。\n\n" if chat else "问题:"
+        prompt = prefix + example["Question"]
+        for choice in list("ABCD"):
+            prompt += f"\n{choice}. {example[choice]}"
+
+        prompt += "\n答案:"
+        return {"prompt": prompt}
+
+    subcategories = {
+        "agronomy": ["other"],
+        "anatomy": ["biology"],
+        "ancient_chinese": ["linguistics", "china specific"],
+        "arts": ["arts"],
+        "astronomy": ["physics"],
+        "business_ethics": ["business"],
+        "chinese_civil_service_exam": ["politics", "china specific"],
+        "chinese_driving_rule": ["other", "china specific"],
+        "chinese_food_culture": ["culture", "china specific"],
+        "chinese_foreign_policy": ["politics", "china specific"],
+        "chinese_history": ["history", "china specific"],
+        "chinese_literature": ["literature", "china specific"],
+        "chinese_teacher_qualification": ["education", "china specific"],
+        "college_actuarial_science": ["math"],
+        "college_education": ["education"],
+        "college_engineering_hydrology": ["engineering"],
+        "college_law": ["law"],
+        "college_mathematics": ["math"],
+        "college_medical_statistics": ["statistics"],
+        "clinical_knowledge": ["other"],
+        "college_medicine": ["other"],
+        "computer_science": ["computer science"],
+        "computer_security": ["other"],
+        "conceptual_physics": ["physics"],
+        "construction_project_management": ["other", "china specific"],
+        "economics": ["economics"],
+        "education": ["education"],
+        "elementary_chinese": ["linguistics", "china specific"],
+        "elementary_commonsense": ["other", "china specific"],
+        "elementary_information_and_technology": ["other"],
+        "electrical_engineering": ["engineering"],
+        "elementary_mathematics": ["math"],
+        "ethnology": ["culture", "china specific"],
+        "food_science": ["other"],
+        "genetics": ["biology"],
+        "global_facts": ["global"],
+        "high_school_biology": ["biology"],
+        "high_school_chemistry": ["chemistry"],
+        "high_school_geography": ["geography"],
+        "high_school_mathematics": ["math"],
+        "high_school_physics": ["physics"],
+        "high_school_politics": ["politics", "china specific"],
+        "human_sexuality": ["other"],
+        "international_law": ["law"],
+        "journalism": ["sociology"],
+        "jurisprudence": ["law"],
+        "legal_and_moral_basis": ["other"],
+        "logical": ["philosophy"],
+        "machine_learning": ["computer science"],
+        "management": ["business"],
+        "marketing": ["business"],
+        "marxist_theory": ["philosophy"],
+        "modern_chinese": ["linguistics", "china specific"],
+        "nutrition": ["other"],
+        "philosophy": ["philosophy"],
+        "professional_accounting": ["business"],
+        "professional_law": ["law"],
+        "professional_medicine": ["other"],
+        "professional_psychology": ["psychology"],
+        "public_relations": ["politics"],
+        "security_study": ["politics"],
+        "sociology": ["culture"],
+        "sports_science": ["other"],
+        "traditional_chinese_medicine": ["other", "china specific"],
+        "virology": ["biology"],
+        "world_history": ["history"],
+        "world_religions": ["global"],
+    }
+
+    categories = {
+        "STEM": [
+            "physics",
+            "chemistry",
+            "biology",
+            "computer science",
+            "math",
+            "engineering",
+            "statistics",
+        ],
+        "Humanities": ["history", "philosophy", "law", "arts", "literature", "global"],
+        "Social Science": [
+            "linguistics",
+            "business",
+            "politics",
+            "culture",
+            "economics",
+            "geography",
+            "psychology",
+            "education",
+            "sociology",
+        ],
+        "Other": ["other"],
+        "China specific": ["china specific"],
+        "Test": ["computer science"],
+    }
+
+    finer_categories = (
+        pd.Series(subcategories)  # noqa # type: ignore
+        .explode()
+        .reset_index()
+        .set_index(0)
+        .groupby(0)
+        .agg(list)["index"]
+        .to_dict()
+    )
+
+    @classmethod
+    def suite(cls, chat=False):
+        suite = {}
+        for k, v in cls.categories.items():
+            for subject in v:
+                suite[k] = [
+                    Task(
+                        ("haonan-li/cmmlu", subcategories),
+                        metric_name=("sustech/tlem", "cmmlu"),
+                        input_column="prompt",
+                        label_column="Answer",
+                        prompt=partial(cls.prompt_cmmlu, chat=chat),
+                    )
+                    for subcategories in cls.finer_categories[subject]
+                ]
+        return suite
+
+
+class MMLU:
+    input_column = "prompt"
+    label_column = "target"
+
+    @classmethod
+    def prompt_mmlu(cls, example, chat=False):
+        prefix = (
+            "The following is a multiple-choice question. Please choose the most suitable one among A, B, C and D as the answer to this question.\n\n"
+            if chat
+            else "Question: "
+        )
+        prompt = prefix + example["input"]
+        for choice in list("ABCD"):
+            prompt += f"\n{choice}. {example[choice]}"
+
+        prompt += "\nAnswer:"
+        return {"prompt": prompt}
+
+    subcategories = {
+        "abstract_algebra": ["math"],
+        "anatomy": ["health"],
+        "astronomy": ["physics"],
+        "business_ethics": ["business"],
+        "clinical_knowledge": ["health"],
+        "college_biology": ["biology"],
+        "college_chemistry": ["chemistry"],
+        "college_computer_science": ["computer science"],
+        "college_mathematics": ["math"],
+        "college_medicine": ["health"],
+        "college_physics": ["physics"],
+        "computer_security": ["computer science"],
+        "conceptual_physics": ["physics"],
+        "econometrics": ["economics"],
+        "electrical_engineering": ["engineering"],
+        "elementary_mathematics": ["math"],
+        "formal_logic": ["philosophy"],
+        "global_facts": ["other"],
+        "high_school_biology": ["biology"],
+        "high_school_chemistry": ["chemistry"],
+        "high_school_computer_science": ["computer science"],
+        "high_school_european_history": ["history"],
+        "high_school_geography": ["geography"],
+        "high_school_government_and_politics": ["politics"],
+        "high_school_macroeconomics": ["economics"],
+        "high_school_mathematics": ["math"],
+        "high_school_microeconomics": ["economics"],
+        "high_school_physics": ["physics"],
+        "high_school_psychology": ["psychology"],
+        "high_school_statistics": ["math"],
+        "high_school_us_history": ["history"],
+        "high_school_world_history": ["history"],
+        "human_aging": ["health"],
+        "human_sexuality": ["culture"],
+        "international_law": ["law"],
+        "jurisprudence": ["law"],
+        "logical_fallacies": ["philosophy"],
+        "machine_learning": ["computer science"],
+        "management": ["business"],
+        "marketing": ["business"],
+        "medical_genetics": ["health"],
+        "miscellaneous": ["other"],
+        "moral_disputes": ["philosophy"],
+        "moral_scenarios": ["philosophy"],
+        "nutrition": ["health"],
+        "philosophy": ["philosophy"],
+        "prehistory": ["history"],
+        "professional_accounting": ["other"],
+        "professional_law": ["law"],
+        "professional_medicine": ["health"],
+        "professional_psychology": ["psychology"],
+        "public_relations": ["politics"],
+        "security_studies": ["politics"],
+        "sociology": ["culture"],
+        "us_foreign_policy": ["politics"],
+        "virology": ["health"],
+        "world_religions": ["philosophy"],
+    }
+
+    categories = {
+        "Math": [
+            "math",
+        ],
+        "STEM": [
+            "physics",
+            "chemistry",
+            "biology",
+            "computer science",
+            "math",
+            "engineering",
+        ],
+        "humanities": ["history", "philosophy", "law"],
+        "social sciences": [
+            "politics",
+            "culture",
+            "economics",
+            "geography",
+            "psychology",
+        ],
+        "Other": ["other", "business", "health"],
+        "All": [
+            "physics",
+            "chemistry",
+            "biology",
+            "computer science",
+            "math",
+            "engineering",
+            "history",
+            "philosophy",
+            "law",
+            "politics",
+            "culture",
+            "economics",
+            "geography",
+            "psychology",
+            "other",
+            "business",
+            "health",
+        ],
+        "Test": ["culture"],
+    }
+
+    @classmethod
+    def suite(cls, chat=False):
+        finer_categories = (
+            pd.Series(cls.subcategories)  # noqa # type: ignore
+            .explode()
+            .reset_index()
+            .set_index(0)
+            .groupby(0)
+            .agg(list)["index"]
+            .to_dict()
+        )
+        suite = {}
+        for k, v in cls.categories.items():
+            for subject in v:
+                suite[k] = [
+                    Task(
+                        ("lukaemon/mmlu", subcategories),
+                        metric_name=("sustech/tlem", "mmlu"),
+                        input_column=cls.input_column,
+                        label_column=cls.label_column,
+                        prompt=partial(cls.prompt_mmlu, chat=chat),
+                        few_shot=0 if chat else 5,
+                        few_shot_from="validation",
+                    )
+                    for subcategories in finer_categories[subject]
+                ]
+        return suite
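The CMMLU and MMLU suite builders above invert their subject-to-category tables with the same pandas chain (`finer_categories`). The idiom on a toy mapping, for reference (not from the commit):

    import pandas as pd

    subcategories = {"anatomy": ["biology"], "genetics": ["biology"], "astronomy": ["physics"]}
    finer = (
        pd.Series(subcategories)
        .explode()           # one row per (subject, category) pair
        .reset_index()       # columns: "index" = subject, 0 = category
        .set_index(0)
        .groupby(0)          # group by category (the index)
        .agg(list)["index"]  # category -> list of subjects
        .to_dict()
    )
    assert finer == {"biology": ["anatomy", "genetics"], "physics": ["astronomy"]}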
tlem.py CHANGED
@@ -11,7 +11,8 @@ from evaluate.evaluation_suite import EvaluationSuite
 import evaluate
 import numpy as np
 import datasets
-from .tasks import Task, Metrics
+import pandas as pd
+from .tasks import *
 from .utils import is_equiv
 
 # %%
@@ -24,56 +25,35 @@ from .utils import is_equiv
 
 # TODO: Add BibTeX citation
 _CITATION = """\
-@InProceedings{huggingface:module,
-title = {A great new module},
-authors={huggingface, Inc.},
-year={2020}
-}
 """
 
 # TODO: Add description of the module here
 _DESCRIPTION = """\
-A simple measurement that returns the number of elements in dataset.
 """
 
 
 # TODO: Add description of the arguments of the module here
 _KWARGS_DESCRIPTION = """
-Calculates number of elements in dataset
-Args:
-    data: list of elements.
-Returns:
-    element_count: number of elements in dataset,
-Examples:
-    >>> measure = evaluate.load("lvwerra/element_count")
-    >>> measure.compute(["a", "b", "c")
-    {"element_count": 3}
 """
 
 # TODO: Define external resources urls if needed
 BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"
 
 
-@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
+# @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
 class ReasoningMetric(evaluate.Metric):
     """TODO: Short description of my evaluation module."""
 
     def _info(self):
+        # if self.config_name in ["cmmlu"]:
         features = datasets.Features(
             {
                 "responses": datasets.Value("string"),
+                # "responses": datasets.Sequence(datasets.Value("float")),
                 "references": datasets.Value("string"),
             }
         )
 
-        if self.config_name == "svamp":
-            features = datasets.Features(
-                {
-                    "responses": datasets.Value("string"),
-                    "references": datasets.Value("float"),
-                }
-            )
-
         # TODO: Specifies the evaluate.EvaluationModuleInfo object
         return evaluate.EvaluationModuleInfo(
             # This is the description that will appear on the modules page.
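With the svamp special case dropped, `_info` now always declares string responses and references. Assuming the module stays published under the `sustech/tlem` id that the tasks above pass as `metric_name`, loading it would look roughly like:

    import evaluate

    # config_name picks the scorer from Metrics, e.g. "mmlu" or "cmmlu".
    metric = evaluate.load("sustech/tlem", "mmlu")
    print(metric.compute(responses=["The answer is B."], references=["B"]))
    # -> {"Accuracy": 1.0}, per the _compute in the next hunk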
@@ -90,38 +70,59 @@ class ReasoningMetric(evaluate.Metric):
             reference_urls=["http://path.to.reference.url/new_module"],
         )
 
-    def _compute(self, responses, references, verbose=False):
-        results = {}
+    def _compute(self, responses, references):
         scores = getattr(Metrics, self.config_name)(responses, references)
-        acc = np.asarray(scores).mean()
-        results = {
-            "accuracy": acc,
-            "scores": scores,
-        }
-
-        if verbose:
-            results["references"] = references
-            results["answers"] = responses
-            # results["scores"] = scores
-
+        results = {"Accuracy": np.nanmean(scores)}
+        logging.info(results)
         return results
 
 
 class Suite(EvaluationSuite):
     def run(
-        self, model_or_pipeline: Any, prompt: str = "{instruction}"
+        self,
+        model_or_pipeline: Any,
+        name="tlem",
     ) -> dict[str, float]:
         self.assert_suite_nonempty()
 
-        results_all = {}
-        for task in tqdm(self.suite, desc="Running tasks"):
-            task_name = task.name
-            results = task.run(model_or_pipeline)
-            results_all[task_name] = results
-        return results_all
-
-    def __init__(self, name):
+        def run_tasks(tasks):
+            for task in tqdm(tasks):
+                if task.name not in self.cached_result:
+                    self.cached_result[task.name] = task.run(model_or_pipeline)
+            results = [self.cached_result[task.name] for task in tasks]
+            return pd.DataFrame(results).mean().to_dict()
+
+        if isinstance(self.suite, dict):
+            for category, tasks in tqdm(self.suite.items()):
+                logging.warning(f"Combined results: {category}:{run_tasks(tasks)}")
+        else:
+            logging.warning(f"Combined results: {run_tasks(self.suite)}")
+
+        return self.cached_result
+
+    def add(self, name):
+        chat = False
+        match name:
+            case _ if "chat" in name:
+                chat = True
+        match name:
+            case _ if name.startswith("mmlu"):
+                suite = MMLU.suite(chat=chat)
+            case _ if name.startswith("cmmlu"):
+                suite = CMMLU.suite(chat=chat)
+        match name:
+            case _ if "test" in name:
+                suite = suite["Test"]
+
+        self.suite = suite
+
+    def __init__(self, name="tlem"):
         super().__init__(name)
+        self.cached_result = {}
+
+        match self.name:
+            case "cmmlu":
+                pass
 
         self.suite = [
             Task(
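Together, `add` and the memoizing `run` give `Suite` a small imperative API. A sketch of the intended call pattern (the import path and pipeline are placeholders, not confirmed by the commit):

    from tlem import Suite, fake_pipeline  # hypothetical import path

    suite = Suite()
    suite.add("mmlu-chat-test")         # "chat" selects chat prompts, "test" the small Test slice
    results = suite.run(fake_pipeline)  # repeat runs skip tasks already in cached_result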
@@ -136,4 +137,3 @@ class Suite(EvaluationSuite):
 
 
 # %%
-
utils.py CHANGED
@@ -1,5 +1,7 @@
 import logging
 import re
+import numpy as np
+from typing import Any
 
 NUMERIC_IN_EN = r"(?:[\s=+-/<>($:\.\*\\])(?=\S)((?:0|(?:\d{1,3}(?:,\d{3})+(?=\D|$))|(?:\d+))(?:\.\d+)?%?)(?:(?![^\s=+-/>)$:\.\*\\])|(?=, ))"
 NUMERIC_IN_ZH = (
@@ -7,17 +9,43 @@ NUMERIC_IN_ZH = (
 )
 
 
-def extract_choice_ans(text):
-    pattern1 = r"\b[ABCDabcd]\b"
-    pattern2 = r"\([ABCDabcd]\)"
-    matches1 = re.findall(pattern1, text)
-    matches2 = re.findall(pattern2, text)
-    matches = matches1 + matches2
-
-    def standardize(ans):
-        return ans if len(ans) == 1 else ans[1]
-
-    return standardize(matches[-1]).lower() if matches else "_"
+def extract_choice(gen):
+    # answer is A | choice is A | choose A
+    res = re.search(
+        r"(?:(?:[Cc]hoose)|(?:(?:[Aa]nswer|[Cc]hoice)(?![^ABCD]{0,20}?(?:n't|not))[^ABCD]{0,10}?\b(?:|is|:|be))\b)[^ABCD]{0,20}?\b(A|B|C|D)\b",
+        gen,
+    )
+
+    # A is correct | A is right
+    if res is None:
+        res = re.search(
+            r"\b(A|B|C|D)\b(?![^ABCD]{0,8}?(?:n't|not)[^ABCD]{0,5}?(?:correct|right))[^ABCD]{0,10}?\b(?:correct|right)\b",
+            gen,
+        )
+
+    # straight answer: A
+    if res is None:
+        res = re.search(r"^(A|B|C|D)(?:\.|,|:|$)", gen)
+
+    # simply extract the first appeared letter
+    if res is None:
+        res = re.search(r"(?<![a-zA-Z])(A|B|C|D)(?![a-zA-Z=])", gen)
+
+    if res is None:
+        res = "A"
+
+    if isinstance(res, str):
+        return res
+
+    return res.group(1)
+
+
+def decode_choice(responses: list[Any]):
+    num_choices = responses[0].shape[0]
+    choices = np.argmax(np.asarray(responses), axis=1)
+    responses = np.array(list("ABCDEFGHIJKL"[:num_choices]))[choices]
+    # return (responses == np.array(references)).mean()
+    return responses
 
 
 def extract_numeric(string, pattern=NUMERIC_IN_EN) -> str:
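The `extract_choice` cascade above can be spot-checked directly; each assertion below exercises one pattern plus the final default, and should hold for the regexes as committed:

    assert extract_choice("I would choose (C) here") == "C"  # "choose C" pattern
    assert extract_choice("The answer is B.") == "B"         # "answer is B" pattern
    assert extract_choice("D is correct.") == "D"            # "D is correct" pattern
    assert extract_choice("B. Because ...") == "B"           # straight "B." answer
    assert extract_choice("no capital letter here") == "A"   # fallback default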