File size: 4,336 Bytes
4c7982b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from dataclasses import dataclass, field
from datasets import load_dataset, Dataset
from functools import cached_property
from tqdm.auto import tqdm
from typing import Any, Optional, Protocol, Iterable, Callable

from utils import (
    NUMERIC_IN_ZH,
    extract_choice_ans,
    extract_numeric,
    get_answer,
    is_equiv,
)

from evaluate import load

# Signature of a text-generation backend: takes an iterable of prompt strings
# and returns a list of generated strings (presumably one per prompt, in order,
# as fake_pipeline does — TODO confirm for real backends).
TextGenerationPipeline = Callable[[Iterable[str]], list[str]]


def fake_pipeline(prompts: Iterable[str]) -> list[str]:
    """Return every prompt unchanged, iterating under a tqdm progress bar.

    This is the default pipeline of ``Task.run``; it performs no generation and
    simply echoes its inputs, in order, as a list.
    """
    return list(tqdm(prompts))


@dataclass
class Task:
    """One evaluation task: a dataset split, an optional prompt, and a metric.

    ``run`` sends the (optionally prompted) ``input_column`` through a
    text-generation pipeline and scores the outputs against ``label_column``
    with the configured metric. Dataset and metric are loaded lazily on first
    access and cached on the instance (``cached_property``).
    """

    # Dataset identifier: a single name, or a (path, config) pair that is
    # passed positionally to ``load_dataset``.
    dataset_name: str | tuple[str, str] = ("gsm8k", "main")
    # Split to load, e.g. "test".
    split: str = "test"
    # Metric identifier: a single name, or a (path, config) pair that is
    # passed positionally to ``evaluate.load``.
    metric_name: str | tuple[str, str] = ("sustech/tlem", "gsm8k")
    # Column fed to the pipeline.
    input_column: str = "question"
    # Column passed to the metric as ``references``.
    label_column: str = "answer"
    # Optional prompt. A str is a template whose ``{input_column}`` placeholder
    # is filled with the example's input; a callable receives the whole example
    # and its return value is used as the ``Dataset.map`` output.
    prompt: Optional[Callable | str] = None

    @staticmethod
    def _load_args(name: str | tuple[str, ...]) -> tuple[str, ...]:
        """Turn a bare name or a (path, config) sequence into positional args."""
        # A bare string must be wrapped, never star-unpacked: ``*"gsm8k"``
        # would splat it into single characters.
        return (name,) if isinstance(name, str) else tuple(name)

    @cached_property
    def name(self) -> str:
        """Task id: the dataset path (first element of a pair) plus ``-<split>``."""
        return f"{self._load_args(self.dataset_name)[0]}-{self.split}"

    @cached_property
    def samples(self):
        """Pipeline inputs: the ``input_column`` of the (prompted) dataset."""
        return self.dataset[self.input_column]

    @cached_property
    def dataset(self):
        """Load the split and, if a prompt is configured, apply it to each example."""
        ds = load_dataset(*self._load_args(self.dataset_name), split=self.split)
        if self.prompt is not None:
            ds = ds.map(
                lambda example: (
                    {
                        self.input_column: self.prompt.format(
                            input_column=example[self.input_column]
                        )
                    }
                    if isinstance(self.prompt, str)
                    else self.prompt(example)
                )
            )
        return ds

    @cached_property
    def metric(self):
        """Load the configured metric via ``evaluate.load``."""
        return load(*self._load_args(self.metric_name))

    def run(self, pipeline: TextGenerationPipeline = fake_pipeline):
        """Generate outputs for all samples and score them with the metric.

        Args:
            pipeline: maps an iterable of prompts to a list of generated
                strings. Defaults to ``fake_pipeline``, which echoes the
                prompts back unchanged.

        Returns:
            Whatever ``self.metric.compute`` returns for these outputs.
        """
        outputs = pipeline(self.samples)
        return self.metric.compute(
            responses=outputs, references=self.dataset[self.label_column]
        )


class Metrics:
    def gsm8k(responses: list[str], answers: list[str | int]):
        scores = []
        for response, answer in zip(responses, answers):
            pred = extract_numeric(response)
            gold = extract_numeric(answer) if isinstance(answer, str) else str(answer)
            scores.append(1.0 * (pred == gold))
        return scores

    def MATH(responses: list[str], answers: list[str]):
        scores = []

        for response, answer in zip(responses, answers):
            indices = [pos for pos, char in enumerate(response) if char == "$"]
            if len(indices) <= 2:
                scores.append(0)
                continue
            else:
                result = response[indices[-2] + 1 : indices[-1]]
                gold = get_answer(answer)
                scores.append(1.0 * is_equiv(result, gold))

        return scores

    def math23k(responses: list[str], answers: list[str]):
        scores = []
        for response, answer in zip(responses, answers):
            pred = extract_numeric(response, pattern=NUMERIC_IN_ZH)
            gold = extract_numeric(answer, pattern=NUMERIC_IN_ZH)
            scores.append(1.0 * (pred == gold))
        return scores

    def gsm8k_zh(responses: list[str], answers: list[str]):
        scores = []
        for response, answer in zip(responses, answers):
            pred = extract_numeric(response, pattern=NUMERIC_IN_ZH)
            gold = extract_numeric(answer)
            scores.append(1.0 * (pred == gold))
        return scores

    def svamp(responses: list[float], answers: list[str]):
        scores = []
        for response, answer in zip(responses, answers):
            pred = extract_numeric(response, pattern=NUMERIC_IN_ZH)
            gold = answer
            scores.append(1.0 * (float(pred) == gold))
        return scores

    def mmlu(responses, answers):
        scores = []
        for response, answer in zip(responses, answers):
            pred = extract_choice_ans(response)
            gold = answer.lower()
            scores.append(1.0 * (pred == gold))
        return scores