facat committed on
Commit
4c7982b
1 Parent(s): c250b54
Files changed (2)
  1. tasks.py +137 -0
  2. tlem.py +5 -151
tasks.py ADDED
@@ -0,0 +1,137 @@
+ from dataclasses import dataclass, field
+ from datasets import load_dataset, Dataset
+ from functools import cached_property
+ from tqdm.auto import tqdm
+ from typing import Any, Optional, Protocol, Iterable, Callable
+
+ from utils import (
+     NUMERIC_IN_ZH,
+     extract_choice_ans,
+     extract_numeric,
+     get_answer,
+     is_equiv,
+ )
+
+ from evaluate import load
+
+ TextGenerationPipeline = Callable[[Iterable[str]], list[str]]
+
+
+ def fake_pipeline(prompts: Iterable[str]) -> list[str]:
+     return [prompt for prompt in tqdm(prompts)]
+
+
+ @dataclass
+ class Task:
+     dataset_name: str | tuple[str, str] = ("gsm8k", "main")
+     split: str = "test"
+     # metrics: list[str] = field(default_factory=list)
+     metric_name: str | tuple[str, str] = ("sustech/tlem", "gsm8k")
+     input_column: str = "question"
+     label_column: str = "answer"
+     prompt: Optional[Callable | str] = None
+
+     @cached_property
+     def name(self):
+         return (
+             self.dataset_name
+             if isinstance(self.dataset_name, str)
+             else self.dataset_name[0]
+         ) + f"-{self.split}"
+
+     @cached_property
+     def samples(self):
+         return self.dataset[self.input_column]
+
+     @cached_property
+     def dataset(self):
+         ds = load_dataset(
+             *(self.dataset_name
+               if isinstance(self.dataset_name, tuple)
+               else (self.dataset_name,)),
+             split=self.split,
+         )
+         if self.prompt is not None:
+             ds = ds.map(
+                 lambda example: {
+                     self.input_column: self.prompt.format(
+                         input_column=example[self.input_column]
+                     )
+                 }
+                 if isinstance(self.prompt, str)
+                 else self.prompt(example),
+             )
+
+         return ds
+
+     @cached_property
+     def metric(self):
+         metric = (
+             load(self.metric_name)
+             if isinstance(self.metric_name, str)
+             else load(*self.metric_name)
+         )
+         return metric
+
+     def run(self, pipeline: TextGenerationPipeline = fake_pipeline):
+         outputs = pipeline(self.samples)
+         return self.metric.compute(
+             responses=outputs, references=self.dataset[self.label_column]
+         )
+
+
+ class Metrics:
+     def gsm8k(responses: list[str], answers: list[str | int]):
+         scores = []
+         for response, answer in zip(responses, answers):
+             pred = extract_numeric(response)
+             gold = extract_numeric(answer) if isinstance(answer, str) else str(answer)
+             scores.append(1.0 * (pred == gold))
+         return scores
+
+     def MATH(responses: list[str], answers: list[str]):
+         scores = []
+
+         for response, answer in zip(responses, answers):
+             indices = [pos for pos, char in enumerate(response) if char == "$"]
+             if len(indices) <= 2:
+                 scores.append(0)
+                 continue
+             else:
+                 result = response[indices[-2] + 1 : indices[-1]]
+                 gold = get_answer(answer)
+                 scores.append(1.0 * is_equiv(result, gold))
+
+         return scores
+
+     def math23k(responses: list[str], answers: list[str]):
+         scores = []
+         for response, answer in zip(responses, answers):
+             pred = extract_numeric(response, pattern=NUMERIC_IN_ZH)
+             gold = extract_numeric(answer, pattern=NUMERIC_IN_ZH)
+             scores.append(1.0 * (pred == gold))
+         return scores
+
+     def gsm8k_zh(responses: list[str], answers: list[str]):
+         scores = []
+         for response, answer in zip(responses, answers):
+             pred = extract_numeric(response, pattern=NUMERIC_IN_ZH)
+             gold = extract_numeric(answer)
+             scores.append(1.0 * (pred == gold))
+         return scores
+
+     def svamp(responses: list[str], answers: list[float]):
+         scores = []
+         for response, answer in zip(responses, answers):
+             pred = extract_numeric(response, pattern=NUMERIC_IN_ZH)
+             gold = answer
+             scores.append(1.0 * (float(pred) == gold))
+         return scores
+
+     def mmlu(responses, answers):
+         scores = []
+         for response, answer in zip(responses, answers):
+             pred = extract_choice_ans(response)
+             gold = answer.lower()
+             scores.append(1.0 * (pred == gold))
+         return scores
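
For reference, a minimal usage sketch of the new tasks.py module (illustrative only, not part of this commit). It assumes utils.py is importable, that load_dataset and the "sustech/tlem" metric can be fetched from the Hub, and that extract_numeric pulls the final number out of a GSM8K-style answer; the sample strings below are made up.

# Hypothetical usage of tasks.py -- not committed code.
from tasks import Task, Metrics, fake_pipeline

# Scoring responses directly with a Metrics function (per-item 0/1 scores).
scores = Metrics.gsm8k(
    ["... so the answer is 42.", "I think it is 7."],  # model responses (made up)
    ["The answer is #### 42", 8],                      # gold answers (made up)
)
print(sum(scores) / len(scores))

# Running a whole task; fake_pipeline just echoes the prompts, so this only
# exercises the dataset/metric plumbing rather than a real model.
task = Task(
    dataset_name=("gsm8k", "main"),
    split="test",
    prompt="Question: {input_column}\nAnswer:",  # str prompts are .format()-ed per example
)
print(task.name)                   # "gsm8k-test"
result = task.run(fake_pipeline)   # pass any Callable[[Iterable[str]], list[str]]
print(result)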
tlem.py CHANGED
@@ -8,6 +8,11 @@ except Exception as e:
from typing import Any, Optional, Protocol, Iterable, Callable
from tqdm.auto import tqdm
from evaluate.evaluation_suite import EvaluationSuite
+ import evaluate
+ import numpy as np
+ import datasets
+ from tasks import Task, Metrics, fake_pipeline
+ from utils import is_equiv

# %%

@@ -15,150 +20,6 @@ from evaluate.evaluation_suite import EvaluationSuite

# %load_ext ipytorch
# %ls
- from utils import (
-     NUMERIC_IN_ZH,
-     extract_choice_ans,
-     extract_numeric,
-     get_answer,
-     is_equiv,
- )
-
-
- from dataclasses import dataclass, field
- from datasets import load_dataset, Dataset
- from functools import cached_property
-
-
- TextGenerationPipeline = Callable[[Iterable[str]], list[str]]
-
-
- from evaluate import load
-
-
- def fake_pipeline(prompts: Iterable[str]) -> list[str]:
-     return [prompt for prompt in tqdm(prompts)]
-
-
- @dataclass
- class Task:
-     dataset_name: str | tuple[str, str] = ("gsm8k", "main")
-     split: str = "test"
-     # metrics: list[str] = field(default_factory=list)
-     metric_name: str | tuple[str, str] = ("sustech/tlem", "gsm8k")
-     input_column: str = "question"
-     label_column: str = "answer"
-     prompt: Optional[Callable | str] = None
-
-     @cached_property
-     def name(self):
-         return (
-             self.dataset_name
-             if isinstance(self.dataset_name, str)
-             else self.dataset_name[0]
-         ) + f"-{self.split}"
-
-     @cached_property
-     def samples(self):
-         return self.dataset[self.input_column]
-
-     @cached_property
-     def dataset(self):
-         ds = load_dataset(
-             *self.dataset_name
-             if isinstance(self.dataset_name, tuple)
-             else self.dataset_name,
-             split=self.split,
-         )
-         if self.prompt is not None:
-             ds = ds.map(
-                 lambda example: {
-                     self.input_column: self.prompt.format(
-                         input_column=example[self.input_column]
-                     )
-                 }
-                 if isinstance(self.prompt, str)
-                 else self.prompt(example),
-             )
-
-         return ds
-
-     @cached_property
-     def metric(self):
-         metric = (
-             load(self.metric_name)
-             if isinstance(self.metric_name, str)
-             else load(*self.metric_name)
-         )
-         return metric
-
-     def run(self, pipeline: TextGenerationPipeline = fake_pipeline):
-         outputs = pipeline(self.samples)
-         return self.metric.compute(
-             responses=outputs, references=self.dataset[self.label_column]
-         )
-
-
- class Metrics:
-     def gsm8k(responses: list[str], answers: list[str | int]):
-         scores = []
-         for response, answer in zip(responses, answers):
-             pred = extract_numeric(response)
-             gold = extract_numeric(answer) if isinstance(answer, str) else str(answer)
-             scores.append(1.0 * (pred == gold))
-         return scores
-
-     def MATH(responses: list[str], answers: list[str]):
-         scores = []
-
-         for response, answer in zip(responses, answers):
-             indices = [pos for pos, char in enumerate(response) if char == "$"]
-             if len(indices) <= 2:
-                 scores.append(0)
-                 continue
-             else:
-                 result = response[indices[-2] + 1 : indices[-1]]
-                 gold = get_answer(answer)
-                 scores.append(1.0 * is_equiv(result, gold))
-
-         return scores
-
-     def math23k(responses: list[str], answers: list[str]):
-         scores = []
-         for response, answer in zip(responses, answers):
-             pred = extract_numeric(response, pattern=NUMERIC_IN_ZH)
-             gold = extract_numeric(answer, pattern=NUMERIC_IN_ZH)
-             scores.append(1.0 * (pred == gold))
-         return scores
-
-     def gsm8k_zh(responses: list[str], answers: list[str]):
-         scores = []
-         for response, answer in zip(responses, answers):
-             pred = extract_numeric(response, pattern=NUMERIC_IN_ZH)
-             gold = extract_numeric(answer)
-             scores.append(1.0 * (pred == gold))
-         return scores
-
-     def svamp(responses: list[float], answers: list[str]):
-         scores = []
-         for response, answer in zip(responses, answers):
-             pred = extract_numeric(response, pattern=NUMERIC_IN_ZH)
-             gold = answer
-             scores.append(1.0 * (float(pred) == gold))
-         return scores
-
-     def mmlu(responses, answers):
-         scores = []
-         for response, answer in zip(responses, answers):
-             pred = extract_choice_ans(response)
-             gold = answer.lower()
-             scores.append(1.0 * (pred == gold))
-         return scores
-
-
- import evaluate
- import numpy as np
-
- import datasets


# TODO: Add BibTeX citation
@@ -276,10 +137,3 @@ class Suite(EvaluationSuite):

# %%

- if __name__ == "__main__":
-     # metric = load("sustech/tlem", "gsm8k")
-     # output = metric.compute(responses=["answer is 2", "1+2"], references=["2", "3"])
-     # logging.info(output)
-     suite = EvaluationSuite.load("sustech/tlem")
-     suite.run(fake_pipeline)
- # %%
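
The removed __main__ block shows the intended entry point; after this refactor the same flow would presumably be driven from user code instead. A minimal sketch, assuming the "sustech/tlem" suite still loads from the Hub and its run() accepts a text-generation callable as before:

# Hypothetical driver script -- mirrors the deleted __main__ block above.
from evaluate.evaluation_suite import EvaluationSuite
from tasks import fake_pipeline

suite = EvaluationSuite.load("sustech/tlem")
# fake_pipeline just echoes prompts; substitute any Callable[[Iterable[str]], list[str]]
# wrapping a real model to get meaningful scores.
results = suite.run(fake_pipeline)
print(results)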