facat committed
Commit a6d7b1c
1 Parent(s): e01a5f6
Files changed (1):
  1. tlem.py +56 -9
tlem.py CHANGED
@@ -6,6 +6,8 @@ except Exception as e:
     import logging
 
 from typing import Any, Optional, Protocol, Iterable, Callable
+from tqdm.auto import tqdm
+from evaluate.evaluation_suite import EvaluationSuite
 
 # %%
 
@@ -33,14 +35,18 @@ TextGenerationPipeline = Callable[[Iterable[str]], list[str]]
 from evaluate import load
 
 
+def fake_pipeline(prompts: Iterable[str]) -> list[str]:
+    return [prompt for prompt in tqdm(prompts)]
+
+
 @dataclass
 class Task:
-    dataset_name: str = "gsm8k"
-    dataset_params: dict = field(default_factory=dict)
+    dataset_name: str | tuple[str, str] = ("gsm8k", "main")
+    split: str = "test"
     # metrics: list[str] = field(default_factory=list)
-    metric_name: str | tuple[str, str] = "gsm8k"
+    metric_name: str | tuple[str, str] = ("sustech/tlem", "gsm8k")
     input_column: str = "question"
-    label_column: str = "reference"
+    label_column: str = "answer"
     prompt: Optional[Callable | str] = None
 
     @cached_property
@@ -49,7 +55,12 @@ class Task:
 
     @cached_property
     def dataset(self):
-        ds = load_dataset(self.dataset_name, **self.dataset_params)
+        ds = load_dataset(
+            *(self.dataset_name
+              if isinstance(self.dataset_name, tuple)
+              else (self.dataset_name,)),
+            split=self.split,
+        )
         if self.prompt is not None:
             ds = ds.map(
                 lambda example: {
@@ -72,9 +83,11 @@ class Task:
         )
         return metric
 
-    def run(self, pipeline: TextGenerationPipeline):
+    def run(self, pipeline: TextGenerationPipeline = fake_pipeline):
         outputs = pipeline(self.samples)
-        return self.metric.compute(outputs, self.dataset[self.label_column])
+        return self.metric.compute(
+            responses=outputs, references=self.dataset[self.label_column]
+        )
 
 
 class Metrics:
@@ -224,7 +237,41 @@ class ReasoningMetric(evaluate.Metric):
 
         return results
 
-# %%
 
-load("sustech/tlem", "gsm8k")
+class Suite(EvaluationSuite):
+    def run(
+        self, model_or_pipeline: Any, prompt: str = "{instruction}"
+    ) -> dict[str, float]:
+        self.assert_suite_nonempty()
+
+        results_all = {}
+        for task in tqdm(self.suite, desc="Running tasks"):
+            task_name = task.name
+            results = task.run(model_or_pipeline)
+            results_all[task_name] = results
+        return results_all
+
+    def __init__(self, name):
+        super().__init__(name)
+
+        self.suite = [
+            Task(
+                dataset_name=("gsm8k", "main"),
+                metric_name=("sustech/tlem", "gsm8k"),
+                input_column="question",
+                label_column="answer",
+            )
+            # TASK_REGISTRY["gsm8k"],
+            # TASK_REGISTRY["competition_math"],
+        ]
+
 
+# %%
+
+if __name__ == "__main__":
+    # metric = load("sustech/tlem", "gsm8k")
+    # output = metric.compute(responses=["answer is 2", "1+2"], references=["2", "3"])
+    # logging.info(output)
+    suite = EvaluationSuite.load("sustech/tlem")
+    suite.run(fake_pipeline)
+    # %%
 
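Note on the dataset change: dataset_name now accepts either a (path, config) tuple or a bare string, and the old dataset_params dict is replaced by an explicit split field. A minimal sketch of the two load_dataset calls the cached property reduces to; only the gsm8k identifiers come from the diff, the single-string example uses a hypothetical dataset name:

from datasets import load_dataset

# Tuple form, the new default Task(dataset_name=("gsm8k", "main")):
ds = load_dataset("gsm8k", "main", split="test")

# String form, for a dataset that needs no config name
# (hypothetical identifier, shown for illustration only):
# ds = load_dataset("some_dataset", split="test")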
 
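Task.run now defaults to fake_pipeline, but any Callable[[Iterable[str]], list[str]] satisfies the TextGenerationPipeline protocol. A sketch of adapting a transformers text-generation pipeline to that shape; the adapter name, model choice, and generation settings are assumptions, not part of this commit:

from transformers import pipeline

generator = pipeline("text-generation", model="gpt2")

def hf_pipeline(prompts):
    # transformers returns a list of candidate dicts per prompt;
    # keep the first candidate's text so the result is list[str]
    outputs = generator(list(prompts), max_new_tokens=64)
    return [candidates[0]["generated_text"] for candidates in outputs]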
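Assuming the Suite registered here is what EvaluationSuite.load("sustech/tlem") resolves to, as the __main__ block suggests, an end-to-end run might look like this sketch:

from evaluate.evaluation_suite import EvaluationSuite

suite = EvaluationSuite.load("sustech/tlem")

# Any prompts -> responses callable can stand in for a real model;
# this echo pipeline is a trivial baseline, not a useful solver.
results = suite.run(lambda prompts: list(prompts))
print(results)  # maps each task name to its metric output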