# %%

try:
    from ipytorch import logging
except Exception:
    # Fall back to the standard-library logger when ipytorch is unavailable.
    import logging

from typing import Any

import datasets
import evaluate
import pandas as pd
from evaluate.evaluation_suite import EvaluationSuite
from tqdm.auto import tqdm

from .tasks import *
from .utils import is_equiv


class ReasoningMetric(evaluate.Metric):
    """Exact-match accuracy over task-specific extracted answers.

    The extractor is looked up on ``Metrics`` (from ``.tasks``) by this
    metric's ``config_name``, e.g. ``"gsm8k"``.
    """

    def _info(self):
        features = datasets.Features(
            {
                "responses": datasets.Value("string"),
                "references": datasets.Value("string"),
            }
        )

        return evaluate.EvaluationModuleInfo(
            description="Exact-match accuracy over task-specific extracted answers.",
            citation="",
            inputs_description="",
            # This defines the format of each prediction and reference.
            features=features,
        )

    def _compute(self, responses, references, verbose=False):
        # Extract comparable answers with the extractor registered on
        # Metrics under this metric's config name.
        extract_responses, extract_references = getattr(Metrics, self.config_name)(
            responses, references
        )
        df = pd.DataFrame(
            {
                "responses": responses,
                "references": references,
            }
        )
        df["extract_responses"] = extract_responses
        df["extract_references"] = extract_references
        results = {
            "Accuracy": (df["extract_references"] == df["extract_responses"])
            .astype(int)
            .mean(),
        }
        logging.info(results)
        if verbose:
            results["df"] = df
        return results
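
# Usage sketch (assumptions: this metric is loadable from the Hub as
# "sustech/tlem", the same identifier the gsm8k Task below references, and a
# "gsm8k" extractor exists on Metrics; the example strings are illustrative):
#
#   metric = evaluate.load("sustech/tlem", "gsm8k")
#   result = metric.compute(responses=["... so the answer is 42."],
#                           references=["#### 42"])
#   print(result["Accuracy"])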


class Suite(EvaluationSuite):
    """A named collection of evaluation Tasks with per-task result caching."""

    task_class = Task

    def run(
        self,
        model_or_pipeline: Any,
    ) -> dict[str, Any]:
        self.assert_suite_nonempty()

        def run_tasks(tasks):
            for task in (bar := tqdm(tasks, leave=False)):
                bar.desc = f"Running {task.name}."
                if task.name not in self.cached_result:
                    self.cached_result[task.name] = task.run(model_or_pipeline)
            results = [self.cached_result[task.name] for task in tasks]
            # Average each metric across the tasks in this group.
            return pd.DataFrame(results).mean().to_dict()

        if isinstance(self.suite, dict):
            for category, tasks in (bar := tqdm(self.suite.items())):
                bar.desc = f"Running {category}."
                logging.warning(f"Combined results for {category}: {run_tasks(tasks)}")
        else:
            logging.warning(f"Combined results: {run_tasks(self.suite)}")

        return self.cached_result

    def add(self, name):
        # Note: load() assigns self.suite, so the most recently added name wins.
        self.load(name)

    def load(self, name):
        chat = "chat" in name
        match name:
            case _ if name.startswith("mmlu"):
                suite = MMLU.suite(chat=chat)
            case _ if name.startswith("cmmlu"):
                suite = CMMLU.suite(chat=chat)
            case "gsm8k":
                suite = Task(
                    dataset_name=("gsm8k", "main"),
                    metric_name=("sustech/tlem", "gsm8k"),
                    input_column="question",
                    label_column="answer",
                )
            case "bbh":
                suite = BBH.suite()
            case "arc":
                suite = ARC.suite()
            case "hellaswag":
                suite = HellaSwag.suite()
            case "drop":
                suite = DROP.suite()
            case "winogrande":
                suite = Winogrande.suite()
            case _ if name.startswith("ceval"):
                suite = CEVAL.suite(chat=chat)
            case "mt_bench":
                suite = Task(
                    dataset_name="SUSTech/mt_bench_judge",
                    split="train",
                    prompt=mt_bench_prompt,
                )
            case _:
                raise ValueError(f"Unknown task or suite: {name!r}")

        if "test" in name:
            suite = suite["Test"]

        self.suite = [suite] if isinstance(suite, Task) else suite

    def __init__(self, name="tlem"):
        super().__init__(name)
        self.cached_result = {}
        self.suite = []
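

# Usage sketch (assumptions: the package is importable as ``tlem`` and
# Task.run accepts the model identifier or pipeline your tasks expect;
# "gpt2" below is only a placeholder):
#
#   from tlem import Suite
#
#   suite = Suite()
#   suite.add("gsm8k")
#   results = suite.run("gpt2")  # maps each task name to its metric dict
#   print(results)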