from dataclasses import dataclass, field
from datasets import load_dataset, Dataset
from functools import cached_property, partial
from tqdm.auto import tqdm
from typing import Any, Optional, Protocol, Iterable, Callable
import logging
import pandas as pd

from .utils import *
from evaluate import load
from collections import defaultdict


def fake_pipeline(prompts: Iterable[str]) -> list[str]:
    # Identity pipeline that echoes prompts back; handy for smoke-testing tasks.
    return [prompt for prompt in tqdm(prompts)]


@dataclass
class Task:
    dataset_name: str | tuple[str, str] = ("gsm8k", "main")
    split: str = "test"
    # metrics: list[str] = field(default_factory=list)
    metric_name: str | tuple[str, str] = ("sustech/tlem", "gsm8k")
    input_column: str = "question"
    label_column: str = "answer"
    prompt: Optional[Callable | str] = None
    few_shot: int = 0
    few_shot_from: Optional[str] = None
    # results: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        names = (
            [self.dataset_name]
            if isinstance(self.dataset_name, str)
            else list(self.dataset_name)
        )
        names[0] = names[0].split("/")[-1]
        self.name = "-".join(names) + f"-{self.split}"
        if isinstance(self.prompt, str):
            # Capture the template string before rebinding self.prompt,
            # otherwise the lambda would refer to itself and fail at call time.
            template = self.prompt
            self.prompt = lambda example: {
                self.input_column: template.format(
                    input_column=example[self.input_column]
                )
            }

    @cached_property
    def samples(self):
        return self.dataset[self.input_column]

    @cached_property
    def dataset(self):
        # Unpack (name, config) tuples; wrap a bare name so the splat is safe.
        args = (
            self.dataset_name
            if isinstance(self.dataset_name, tuple)
            else (self.dataset_name,)
        )
        ds = load_dataset(
            *args,
            # split=self.split,
        )
        test_ds = ds[self.split]
        if self.prompt is not None:
            test_ds = test_ds.map(self.prompt)
        if self.few_shot:
            if self.few_shot_from is None:
                # Fall back to the first non-test split that exists.
                for name in ["train", "validation", "val", "dev"]:
                    if name in ds:
                        self.few_shot_from = name
                        break
            shots = ds[self.few_shot_from].select(range(self.few_shot))
            if self.prompt is not None:
                shots = shots.map(self.prompt)
            shots = shots.map(
                lambda example: {
                    self.input_column: example[self.input_column]
                    + example[self.label_column],
                }
            )[self.input_column]
            few_shot_prompts = "\n\n".join(shots)
            test_ds = test_ds.map(
                lambda example: {
                    self.input_column: few_shot_prompts
                    + "\n\n"
                    + example[self.input_column],
                }
            )
        return test_ds

    @cached_property
    def metric(self):
        metric = (
            load(self.metric_name)
            if isinstance(self.metric_name, str)
            else load(*self.metric_name)
        )
        return metric

    def run(
        self,
        pipeline,
    ):
        if (outputs := pipeline(self.samples)) is None:
            logging.warning("pipeline returned None")
            return
        self.outputs = outputs
        try:
            # Call _compute directly to skip batching; fall back to the
            # public compute API if that fails.
            result = self.metric._compute(
                responses=outputs, references=self.dataset[self.label_column]
            )
        except Exception:
            result = self.metric.compute(
                responses=outputs, references=self.dataset[self.label_column]
            )
        # if log:
        #     name = name or pipeline.__name__
        #     self.results[name] = result
        return result
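

# A minimal smoke test for Task.run (a sketch; assumes the Hugging Face Hub is
# reachable for the default gsm8k dataset and the sustech/tlem metric).
# fake_pipeline just echoes prompts, so the score only exercises the plumbing:
#
#     task = Task()  # defaults: ("gsm8k", "main"), test split
#     result = task.run(fake_pipeline)
#     print(task.name, result)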


def multichoice(responses: Any, references: list[str]):
    if isinstance(responses[0], str):
        responses = [extract_choice(response) for response in responses]
    else:
        responses = decode_choice(responses)
    # return [
    #     int(response == reference) for reference, response in zip(references, responses)
    # ]
    return responses, references


class Metrics:
    cmmlu = multichoice
    mmlu = multichoice

    def gsm8k(responses: list[str], answers: list[str | int]):
        # scores = []
        # for response, answer in zip(responses, answers):
        #     pred = extract_numeric(response)
        #     gold = extract_numeric(answer) if isinstance(answer, str) else str(answer)
        #     scores.append(1.0 * (pred == gold))
        responses = [extract_numeric(response) for response in responses]
        answers = [
            extract_numeric(answer) if isinstance(answer, str) else str(answer)
            for answer in answers
        ]
        return responses, answers

    def MATH(responses: list[str], answers: list[str]):
        scores = []
        for response, answer in zip(responses, answers):
            # The final answer is expected between the last pair of $ delimiters.
            indices = [pos for pos, char in enumerate(response) if char == "$"]
            if len(indices) <= 2:
                scores.append(0)
                continue
            result = response[indices[-2] + 1 : indices[-1]]
            gold = get_answer(answer)
            scores.append(1.0 * is_equiv(result, gold))
        return scores

    def math23k(responses: list[str], answers: list[str]):
        scores = []
        for response, answer in zip(responses, answers):
            pred = extract_numeric(response, pattern=NUMERIC_IN_ZH)
            gold = extract_numeric(answer, pattern=NUMERIC_IN_ZH)
            scores.append(1.0 * (pred == gold))
        return scores


class CMMLU:
    input_column = "prompt"
    label_column = "Answer"

    @staticmethod
    def prompt_cmmlu(example, chat=False):
        # Chinese prompt; in English: "The following is a multiple-choice
        # question. Please choose the most suitable answer among A, B, C and
        # D." / "Question:"
        prefix = (
            "以下是一道多项选择题,请从A、B、C和D中选择最合适的答案作为这个问题的答案。\n\n"
            if chat
            else "问题:"
        )
        prompt = prefix + example["Question"]
        for choice in list("ABCD"):
            prompt += f"\n{choice}. {example[choice]}"
        prompt += "\n答案:"  # "Answer:"
        return {"prompt": prompt}

    subcategories = {
        "agronomy": ["other"],
        "anatomy": ["biology"],
        "ancient_chinese": ["linguistics", "china specific"],
        "arts": ["arts"],
        "astronomy": ["physics"],
        "business_ethics": ["business"],
        "chinese_civil_service_exam": ["politics", "china specific"],
        "chinese_driving_rule": ["other", "china specific"],
        "chinese_food_culture": ["culture", "china specific"],
        "chinese_foreign_policy": ["politics", "china specific"],
        "chinese_history": ["history", "china specific"],
        "chinese_literature": ["literature", "china specific"],
        "chinese_teacher_qualification": ["education", "china specific"],
        "college_actuarial_science": ["math"],
        "college_education": ["education"],
        "college_engineering_hydrology": ["engineering"],
        "college_law": ["law"],
        "college_mathematics": ["math"],
        "college_medical_statistics": ["statistics"],
        "clinical_knowledge": ["other"],
        "college_medicine": ["other"],
        "computer_science": ["computer science"],
        "computer_security": ["other"],
        "conceptual_physics": ["physics"],
        "construction_project_management": ["other", "china specific"],
        "economics": ["economics"],
        "education": ["education"],
        "elementary_chinese": ["linguistics", "china specific"],
        "elementary_commonsense": ["other", "china specific"],
        "elementary_information_and_technology": ["other"],
        "electrical_engineering": ["engineering"],
        "elementary_mathematics": ["math"],
        "ethnology": ["culture", "china specific"],
        "food_science": ["other"],
        "genetics": ["biology"],
        "global_facts": ["global"],
        "high_school_biology": ["biology"],
        "high_school_chemistry": ["chemistry"],
        "high_school_geography": ["geography"],
        "high_school_mathematics": ["math"],
        "high_school_physics": ["physics"],
        "high_school_politics": ["politics", "china specific"],
        "human_sexuality": ["other"],
        "international_law": ["law"],
        "journalism": ["sociology"],
        "jurisprudence": ["law"],
        "legal_and_moral_basis": ["other"],
        "logical": ["philosophy"],
        "machine_learning": ["computer science"],
        "management": ["business"],
        "marketing": ["business"],
        "marxist_theory": ["philosophy"],
        "modern_chinese": ["linguistics", "china specific"],
        "nutrition": ["other"],
        "philosophy": ["philosophy"],
        "professional_accounting": ["business"],
        "professional_law": ["law"],
        "professional_medicine": ["other"],
        "professional_psychology": ["psychology"],
        "public_relations": ["politics"],
        "security_study": ["politics"],
        "sociology": ["culture"],
        "sports_science": ["other"],
        "traditional_chinese_medicine": ["other", "china specific"],
        "virology": ["biology"],
        "world_history": ["history"],
        "world_religions": ["global"],
    }

    categories = {
        "STEM": [
            "physics",
            "chemistry",
            "biology",
            "computer science",
            "math",
            "engineering",
            "statistics",
        ],
        "Humanities": ["history", "philosophy", "law", "arts", "literature", "global"],
        "Social Science": [
            "linguistics",
            "business",
            "politics",
            "culture",
            "economics",
            "geography",
            "psychology",
            "education",
            "sociology",
        ],
        "Other": ["other"],
        "China specific": ["china specific"],
        "Test": ["computer science"],
    }

    @classmethod
    def suite(cls, chat=False):
        # Invert subcategories into {fine category: [dataset subset names]}.
        finer_categories = (
            pd.Series(cls.subcategories)  # noqa  # type: ignore
            .explode()
            .reset_index()
            .set_index(0)
            .groupby(0)
            .agg(list)["index"]
            .to_dict()
        )
        suite = defaultdict(list)
        for k, v in cls.categories.items():
            for subject in v:
                suite[k].extend(
                    [
                        Task(
                            ("haonan-li/cmmlu", subcategory),
                            metric_name=("sustech/tlem", "cmmlu"),
                            input_column=cls.input_column,
                            label_column=cls.label_column,
                            prompt=partial(cls.prompt_cmmlu, chat=chat),
                            few_shot=0 if chat else 5,
                            few_shot_from="dev",
                        )
                        for subcategory in finer_categories[subject]
                    ]
                )
        return suite
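

# Usage sketch for the CMMLU suite (assumes the haonan-li/cmmlu dataset and the
# sustech/tlem metric are downloadable). suite() maps each coarse category to a
# list of per-subset Tasks:
#
#     cmmlu_tasks = CMMLU.suite(chat=True)
#     print(len(cmmlu_tasks["STEM"]), cmmlu_tasks["STEM"][0].name)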


class MMLU:
    input_column = "prompt"
    label_column = "target"

    @classmethod
    def prompt_mmlu(cls, example, chat=False):
        prefix = (
            "The following is a multiple-choice question. Please choose the most suitable one among A, B, C and D as the answer to this question.\n\n"
            if chat
            else "Question: "
        )
        prompt = prefix + example["input"]
        for choice in list("ABCD"):
            prompt += f"\n{choice}. {example[choice]}"
        prompt += "\nAnswer:"
        return {"prompt": prompt}

    subcategories = {
        "abstract_algebra": ["math"],
        "anatomy": ["health"],
        "astronomy": ["physics"],
        "business_ethics": ["business"],
        "clinical_knowledge": ["health"],
        "college_biology": ["biology"],
        "college_chemistry": ["chemistry"],
        "college_computer_science": ["computer science"],
        "college_mathematics": ["math"],
        "college_medicine": ["health"],
        "college_physics": ["physics"],
        "computer_security": ["computer science"],
        "conceptual_physics": ["physics"],
        "econometrics": ["economics"],
        "electrical_engineering": ["engineering"],
        "elementary_mathematics": ["math"],
        "formal_logic": ["philosophy"],
        "global_facts": ["other"],
        "high_school_biology": ["biology"],
        "high_school_chemistry": ["chemistry"],
        "high_school_computer_science": ["computer science"],
        "high_school_european_history": ["history"],
        "high_school_geography": ["geography"],
        "high_school_government_and_politics": ["politics"],
        "high_school_macroeconomics": ["economics"],
        "high_school_mathematics": ["math"],
        "high_school_microeconomics": ["economics"],
        "high_school_physics": ["physics"],
        "high_school_psychology": ["psychology"],
        "high_school_statistics": ["math"],
        "high_school_us_history": ["history"],
        "high_school_world_history": ["history"],
        "human_aging": ["health"],
        "human_sexuality": ["culture"],
        "international_law": ["law"],
        "jurisprudence": ["law"],
        "logical_fallacies": ["philosophy"],
        "machine_learning": ["computer science"],
        "management": ["business"],
        "marketing": ["business"],
        "medical_genetics": ["health"],
        "miscellaneous": ["other"],
        "moral_disputes": ["philosophy"],
        "moral_scenarios": ["philosophy"],
        "nutrition": ["health"],
        "philosophy": ["philosophy"],
        "prehistory": ["history"],
        "professional_accounting": ["other"],
        "professional_law": ["law"],
        "professional_medicine": ["health"],
        "professional_psychology": ["psychology"],
        "public_relations": ["politics"],
        "security_studies": ["politics"],
        "sociology": ["culture"],
        "us_foreign_policy": ["politics"],
        "virology": ["health"],
        "world_religions": ["philosophy"],
    }

    categories = {
        "STEM": [
            "physics",
            "chemistry",
            "biology",
            "computer science",
            "math",
            "engineering",
        ],
        "humanities": ["history", "philosophy", "law"],
        "social sciences": [
            "politics",
            "culture",
            "economics",
            "geography",
            "psychology",
        ],
        "other": ["other", "business", "health"],
        "Test": ["culture"],
    }

    @classmethod
    def suite(cls, chat=False):
        # Invert subcategories into {fine category: [dataset subset names]}.
        finer_categories = (
            pd.Series(cls.subcategories)  # noqa  # type: ignore
            .explode()
            .reset_index()
            .set_index(0)
            .groupby(0)
            .agg(list)["index"]
            .to_dict()
        )
        suite = defaultdict(list)
        for k, v in cls.categories.items():
            for subject in v:
                suite[k].extend(
                    [
                        Task(
                            ("lukaemon/mmlu", subcategory),
                            metric_name=("sustech/tlem", "mmlu"),
                            input_column=cls.input_column,
                            label_column=cls.label_column,
                            prompt=partial(cls.prompt_mmlu, chat=chat),
                            few_shot=0 if chat else 5,
                            few_shot_from="validation",
                        )
                        for subcategory in finer_categories[subject]
                    ]
                )
        return suite
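

# Running a full suite is just iterating its tasks; a hedged sketch using the
# echo pipeline (swap in a real text-generation callable for actual scores):
#
#     suite = MMLU.suite(chat=False)
#     results = {
#         category: [task.run(fake_pipeline) for task in tasks]
#         for category, tasks in suite.items()
#     }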