"""Benchmark task definitions: prompt builders, the ``Task`` runner and the
CMMLU / MMLU task suites."""

from dataclasses import dataclass, field
from datasets import load_dataset, Dataset
from functools import cached_property
from tqdm.auto import tqdm
from typing import Any, Optional, Protocol, Iterable, Callable
import logging
import pandas as pd
from functools import partial
from datasets.utils.logging import disable_progress_bar

from .utils import *

from evaluate import load
from collections import defaultdict
import sys

disable_progress_bar()


def mt_bench_prompt(example):
    """Build the judge conversation for one MT-Bench style example.

    Args:
        example: Mapping with the keys ``"conversation"`` (list of
            ``{"role", "content"}`` messages), ``"turn"`` and ``"reference"``
            (``None`` or an iterable of reference-answer strings).

    Returns:
        A list of chat messages: one system message holding the judging
        instructions, followed by the conversation. When references are
        present, each assistant answer is preceded by a system message
        carrying the matching reference answer.
    """
    judge_prompt = "You are ChatGPT, a large language model trained by OpenAI. Your task is to act as an impartial judge and evaluate the quality of the responses provided by an 'assistant' role in the displayed conversation. Your evaluation should focus on the helpfulness, relevance, accuracy, depth, creativity, language fluency, clarity, and level of detail in the assistant's responses. Please note that the evaluation should not consider the user's questions or the overall conversation, but solely the quality of the assistant's replies."
    multi_prompt = "You evaluation should focus on the assistant's answer to the second user question."
    ref_prompt = "In the conversation, you will encounter system messages labeled 'Reference Answer' followed by the assistant's response. Your task is to evaluate the quality of the assistant's response by comparing it with the reference answer."
    json_prompt = 'You must rate the response on a scale of 1 to 10 in JSON format, for example: {"rating": 5}.'

    prompt_list = [judge_prompt]
    conversations = example["conversation"]
    if example["turn"] == 2:
        prompt_list.append(multi_prompt)

    if example["reference"] is not None:
        questions = [m for m in example["conversation"] if m["role"] == "user"]
        answers = [m for m in example["conversation"] if m["role"] == "assistant"]
        # Interleave: question, reference answer (as a system message), answer.
        conversations = []
        for question, answer, reference in zip(
            questions, answers, example["reference"]
        ):
            conversations.append(question)
            conversations.append(
                {"role": "system", "content": "Reference Answer: " + reference}
            )
            conversations.append(answer)
        prompt_list.append(ref_prompt)
    prompt_list.append(json_prompt)

    messages = [{"role": "system", "content": " ".join(prompt_list)}] + conversations
    return messages


@dataclass
class Task:
    """A benchmark task: a dataset split, an optional prompt and a metric.

    Attributes are documented inline. ``name`` is derived in
    ``__post_init__`` from the dataset name(s) and the split.
    """

    # Positional arguments for ``load_dataset``: a path, or (path, config).
    dataset_name: str | tuple[str, str] = ("gsm8k", "main")
    # Split evaluated by ``run``.
    split: str = "test"
    # Positional arguments for ``evaluate.load``.
    metric_name: str | tuple[str, str] = ("sustech/tlem", "gsm8k")
    # Column fed to the pipeline.
    input_column: str = "question"
    # Column holding references; falls back to ``input_column`` when empty.
    label_column: str = ""
    # Either a callable mapped over the dataset, or a format string with an
    # ``{input_column}`` placeholder.
    prompt: Optional[Callable | str] = None
    # Number of few-shot examples prepended to every input (0 disables).
    few_shot: int = 0
    # Split the few-shot examples are drawn from; auto-detected when None.
    few_shot_from: Optional[str] = None

    def __post_init__(self):
        names = (
            [self.dataset_name]
            if isinstance(self.dataset_name, str)
            else list(self.dataset_name)
        )
        names[0] = names[0].split("/")[-1]
        self.name = "-".join(names) + f"-{self.split}"

        if isinstance(self.prompt, str):
            # Capture the template locally: ``self.prompt`` is rebound to the
            # callable below, so reading it lazily inside the lambda would
            # call ``.format`` on the lambda itself.
            template = self.prompt
            self.prompt = lambda example: {
                self.input_column: template.format(
                    input_column=example[self.input_column]
                )
            }
        self.label_column = self.label_column or self.input_column

    @cached_property
    def samples(self):
        """Return the ``input_column`` values of the prepared dataset."""
        return self.dataset[self.input_column]

    @cached_property
    def dataset(self):
        """Load the dataset and return the prompted evaluation split.

        The prompt (if any) is mapped over the split. When ``few_shot`` is
        non-zero, the first ``few_shot`` rows of ``few_shot_from`` are
        rendered as ``input + label``, joined by blank lines and prepended to
        every input.

        Raises:
            ValueError: If no few-shot split can be found, or if it is the
                same split as the one being evaluated.
        """
        # Build the positional args explicitly: ``*a if c else b`` parses as
        # ``*(a if c else b)`` and would unpack a plain string name
        # character by character.
        load_args = (
            (self.dataset_name,)
            if isinstance(self.dataset_name, str)
            else tuple(self.dataset_name)
        )
        ds = load_dataset(*load_args)
        test_ds = ds[self.split]
        if self.prompt is not None:
            test_ds = test_ds.map(self.prompt)

        if self.few_shot:
            if self.few_shot_from is None:
                for name in ["train", "validation", "val", "dev"]:
                    if name in ds:
                        self.few_shot_from = name
                        break
            if self.few_shot_from is None:
                raise ValueError(
                    f"no split available to draw few-shot examples from for {self.name}"
                )
            if self.few_shot_from == self.split:
                raise ValueError(
                    f"few-shot split {self.few_shot_from!r} must differ from "
                    "the evaluated split"
                )

            shots = ds[self.few_shot_from].select(range(self.few_shot))
            if self.prompt is not None:
                shots = shots.map(self.prompt)

            shots = shots.map(
                lambda example: {
                    self.input_column: example[self.input_column]
                    + example[self.label_column],
                }
            )[self.input_column]
            few_shot_prompts = "\n\n".join(shots)

            test_ds = test_ds.map(
                lambda example: {
                    self.input_column: few_shot_prompts
                    + "\n\n"
                    + example[self.input_column],
                }
            )

        return test_ds

    @cached_property
    def metric(self):
        """Load the metric named by ``metric_name`` via ``evaluate.load``."""
        if isinstance(self.metric_name, str):
            return load(self.metric_name)
        return load(*self.metric_name)

    def run(
        self,
        pipeline,
    ):
        """Run ``pipeline`` on the task samples and score its outputs.

        Args:
            pipeline: Callable taking the list of samples and returning the
                responses, or ``None``.

        Returns:
            The metric result, or ``None`` when the pipeline returned
            ``None``. The raw outputs are stored on ``self.outputs``.
        """
        if (outputs := pipeline(self.samples)) is None:
            logging.warning("pipeline returns None")
            return
        self.outputs = outputs
        references = self.dataset[self.label_column]
        try:
            result = self.metric._compute(responses=outputs, references=references)
        except Exception as e:
            # Fall back to the public entry point when the direct call fails.
            logging.debug("metric._compute failed (%r); retrying with compute", e)
            result = self.metric.compute(responses=outputs, references=references)
        # The previous ``finally: result = outputs`` discarded the metric
        # result unconditionally; return the computed score instead.
        return result


def multichoice(responses: Any, references: list[str]):
    """Normalise multiple-choice responses and pair them with references.

    String responses go through ``extract_choice``; anything else is passed
    to ``decode_choice`` as a whole. Both helpers are not defined in this
    file (presumably provided by the ``.utils`` star import).

    Returns:
        A ``(responses, references)`` tuple; ``references`` is unchanged.
    """
    if isinstance(responses[0], str):
        responses = [extract_choice(response) for response in responses]
    else:
        responses = decode_choice(responses)

    return responses, references


def multichoice_zh(responses: Any, references: list[str]):
    """Same as ``multichoice`` but string responses use ``extract_choice_zh``."""
    if isinstance(responses[0], str):
        responses = [extract_choice_zh(response) for response in responses]
    else:
        responses = decode_choice(responses)

    return responses, references


class Metrics:
    """Namespace of response post-processing / scoring functions per task."""

    cmmlu = staticmethod(multichoice_zh)
    mmlu = staticmethod(multichoice)

    @staticmethod
    def gsm8k(responses: list[str], answers: list[str | int]):
        """Reduce responses and answers to comparable numeric strings.

        Returns:
            ``(responses, answers)`` where each response went through
            ``extract_numeric`` and each answer is either extracted (when a
            string) or converted with ``str``.
        """
        responses = [extract_numeric(response) for response in responses]
        answers = [
            extract_numeric(answer) if isinstance(answer, str) else str(answer)
            for answer in answers
        ]

        return responses, answers

    @staticmethod
    def MATH(responses: list[str], answers: list[str]):
        """Score each response by the text between its last two ``$`` signs.

        Returns:
            A list with one score per pair: ``0`` when the response has at
            most two ``$`` characters, otherwise ``1.0 * is_equiv(...)``
            against ``get_answer(answer)``.
        """
        scores = []

        for response, answer in zip(responses, answers):
            indices = [pos for pos, char in enumerate(response) if char == "$"]
            # NOTE(review): ``<= 2`` also rejects a response with exactly one
            # ``$...$`` pair; kept as-is to avoid changing scores -- confirm
            # whether ``< 2`` was intended.
            if len(indices) <= 2:
                scores.append(0)
                continue
            result = response[indices[-2] + 1 : indices[-1]]
            gold = get_answer(answer)
            scores.append(1.0 * is_equiv(result, gold))

        return scores

    @staticmethod
    def math23k(responses: list[str], answers: list[str]):
        """Score by exact match of numbers extracted with ``NUMERIC_IN_ZH``."""
        scores = []
        for response, answer in zip(responses, answers):
            pred = extract_numeric(response, pattern=NUMERIC_IN_ZH)
            gold = extract_numeric(answer, pattern=NUMERIC_IN_ZH)
            scores.append(1.0 * (pred == gold))
        return scores


def _group_subjects_by_category(subcategories):
    """Invert ``{subject: [tag, ...]}`` into ``{tag: [subject, ...]}``.

    Tags are returned in sorted order and subjects keep their definition
    order within each tag.
    """
    grouped = defaultdict(list)
    for subject, tags in subcategories.items():
        for tag in tags:
            grouped[tag].append(subject)
    return {tag: grouped[tag] for tag in sorted(grouped)}


class CMMLU:
    """Prompt builder and task suite for the ``haonan-li/cmmlu`` dataset."""

    input_column = "prompt"
    label_column = "Answer"

    @staticmethod
    def prompt_cmmlu(example, chat=False):
        """Render one CMMLU row as a prompt.

        Args:
            example: Row with ``"Question"`` and the option columns
                ``"A"``-``"D"``.
            chat: Use the instruction-style prefix instead of the short one.

        Returns:
            ``{"prompt": <rendered text>}``.
        """
        prefix = "以下是一道多项选择题,请从A、B、C和D中选择最合适的答案作为这个问题的答案。\n\n" if chat else "问题:"
        prompt = prefix + example["Question"]
        for choice in "ABCD":
            prompt += f"\n{choice}. {example[choice]}"

        prompt += "\n答案:"
        return {"prompt": prompt}

    # Subject (dataset config name) -> fine-grained category tags.
    subcategories = {
        "agronomy": ["other"],
        "anatomy": ["biology"],
        "ancient_chinese": ["linguistics", "china specific"],
        "arts": ["arts"],
        "astronomy": ["physics"],
        "business_ethics": ["business"],
        "chinese_civil_service_exam": ["politics", "china specific"],
        "chinese_driving_rule": ["other", "china specific"],
        "chinese_food_culture": ["culture", "china specific"],
        "chinese_foreign_policy": ["politics", "china specific"],
        "chinese_history": ["history", "china specific"],
        "chinese_literature": ["literature", "china specific"],
        "chinese_teacher_qualification": ["education", "china specific"],
        "college_actuarial_science": ["math"],
        "college_education": ["education"],
        "college_engineering_hydrology": ["engineering"],
        "college_law": ["law"],
        "college_mathematics": ["math"],
        "college_medical_statistics": ["statistics"],
        "clinical_knowledge": ["other"],
        "college_medicine": ["other"],
        "computer_science": ["computer science"],
        "computer_security": ["other"],
        "conceptual_physics": ["physics"],
        "construction_project_management": ["other", "china specific"],
        "economics": ["economics"],
        "education": ["education"],
        "elementary_chinese": ["linguistics", "china specific"],
        "elementary_commonsense": ["other", "china specific"],
        "elementary_information_and_technology": ["other"],
        "electrical_engineering": ["engineering"],
        "elementary_mathematics": ["math"],
        "ethnology": ["culture", "china specific"],
        "food_science": ["other"],
        "genetics": ["biology"],
        "global_facts": ["global"],
        "high_school_biology": ["biology"],
        "high_school_chemistry": ["chemistry"],
        "high_school_geography": ["geography"],
        "high_school_mathematics": ["math"],
        "high_school_physics": ["physics"],
        "high_school_politics": ["politics", "china specific"],
        "human_sexuality": ["other"],
        "international_law": ["law"],
        "journalism": ["sociology"],
        "jurisprudence": ["law"],
        "legal_and_moral_basis": ["other"],
        "logical": ["philosophy"],
        "machine_learning": ["computer science"],
        "management": ["business"],
        "marketing": ["business"],
        "marxist_theory": ["philosophy"],
        "modern_chinese": ["linguistics", "china specific"],
        "nutrition": ["other"],
        "philosophy": ["philosophy"],
        "professional_accounting": ["business"],
        "professional_law": ["law"],
        "professional_medicine": ["other"],
        "professional_psychology": ["psychology"],
        "public_relations": ["politics"],
        "security_study": ["politics"],
        "sociology": ["culture"],
        "sports_science": ["other"],
        "traditional_chinese_medicine": ["other", "china specific"],
        "virology": ["biology"],
        "world_history": ["history"],
        "world_religions": ["global"],
    }
    # Suite name -> fine-grained category tags it aggregates.
    categories = {
        "STEM": [
            "physics",
            "chemistry",
            "biology",
            "computer science",
            "math",
            "engineering",
            "statistics",
        ],
        "Humanities": ["history", "philosophy", "law", "arts", "literature", "global"],
        "Social Science": [
            "linguistics",
            "business",
            "politics",
            "culture",
            "economics",
            "geography",
            "psychology",
            "education",
            "sociology",
        ],
        "Other": ["other"],
        "China specific": ["china specific"],
        "Test": ["computer science"],
    }

    @classmethod
    def suite(cls, chat=False):
        """Build the CMMLU tasks grouped by suite name.

        Args:
            chat: Use the chat prompt and zero-shot; otherwise 5-shot.

        Returns:
            A ``defaultdict(list)`` mapping every key of ``categories``
            (plus ``"all"``) to a list of ``Task`` objects, one per subject
            under each of the suite's tags.
        """
        finer_categories = _group_subjects_by_category(cls.subcategories)
        suite = defaultdict(list)
        # Kept as a class-level update so ``categories["all"]`` stays visible
        # to any caller that reads it after ``suite()``.
        cls.categories["all"] = list(finer_categories.keys())
        for suite_name, tags in cls.categories.items():
            for tag in tags:
                suite[suite_name].extend(
                    [
                        Task(
                            ("haonan-li/cmmlu", subject),
                            metric_name=("sustech/tlem", "cmmlu"),
                            input_column=cls.input_column,
                            label_column=cls.label_column,
                            prompt=partial(cls.prompt_cmmlu, chat=chat),
                            few_shot=0 if chat else 5,
                            few_shot_from="dev",
                        )
                        for subject in finer_categories[tag]
                    ]
                )
        return suite


class MMLU:
    """Prompt builder and task suite for the ``lukaemon/mmlu`` dataset."""

    input_column = "prompt"
    label_column = "target"

    @classmethod
    def prompt_mmlu(cls, example, chat=False):
        """Render one MMLU row as a prompt.

        Args:
            example: Row with ``"input"`` and the option columns
                ``"A"``-``"D"``.
            chat: Use the instruction-style prefix instead of the short one.

        Returns:
            ``{"prompt": <rendered text>}``.
        """
        prefix = (
            "The following is a multiple-choice question. Please choose the most suitable one among A, B, C and D as the answer to this question.\n\n"
            if chat
            else "Question: "
        )
        prompt = prefix + example["input"]
        for choice in "ABCD":
            prompt += f"\n{choice}. {example[choice]}"

        prompt += "\nAnswer:"
        return {"prompt": prompt}

    # Subject (dataset config name) -> fine-grained category tags.
    subcategories = {
        "abstract_algebra": ["math"],
        "anatomy": ["health"],
        "astronomy": ["physics"],
        "business_ethics": ["business"],
        "clinical_knowledge": ["health"],
        "college_biology": ["biology"],
        "college_chemistry": ["chemistry"],
        "college_computer_science": ["computer science"],
        "college_mathematics": ["math"],
        "college_medicine": ["health"],
        "college_physics": ["physics"],
        "computer_security": ["computer science"],
        "conceptual_physics": ["physics"],
        "econometrics": ["economics"],
        "electrical_engineering": ["engineering"],
        "elementary_mathematics": ["math"],
        "formal_logic": ["philosophy"],
        "global_facts": ["other"],
        "high_school_biology": ["biology"],
        "high_school_chemistry": ["chemistry"],
        "high_school_computer_science": ["computer science"],
        "high_school_european_history": ["history"],
        "high_school_geography": ["geography"],
        "high_school_government_and_politics": ["politics"],
        "high_school_macroeconomics": ["economics"],
        "high_school_mathematics": ["math"],
        "high_school_microeconomics": ["economics"],
        "high_school_physics": ["physics"],
        "high_school_psychology": ["psychology"],
        "high_school_statistics": ["math"],
        "high_school_us_history": ["history"],
        "high_school_world_history": ["history"],
        "human_aging": ["health"],
        "human_sexuality": ["culture"],
        "international_law": ["law"],
        "jurisprudence": ["law"],
        "logical_fallacies": ["philosophy"],
        "machine_learning": ["computer science"],
        "management": ["business"],
        "marketing": ["business"],
        "medical_genetics": ["health"],
        "miscellaneous": ["other"],
        "moral_disputes": ["philosophy"],
        "moral_scenarios": ["philosophy"],
        "nutrition": ["health"],
        "philosophy": ["philosophy"],
        "prehistory": ["history"],
        "professional_accounting": ["other"],
        "professional_law": ["law"],
        "professional_medicine": ["health"],
        "professional_psychology": ["psychology"],
        "public_relations": ["politics"],
        "security_studies": ["politics"],
        "sociology": ["culture"],
        "us_foreign_policy": ["politics"],
        "virology": ["health"],
        "world_religions": ["philosophy"],
    }

    # Suite name -> fine-grained category tags it aggregates.
    categories = {
        "STEM": [
            "physics",
            "chemistry",
            "biology",
            "computer science",
            "math",
            "engineering",
        ],
        "humanities": ["history", "philosophy", "law"],
        "social sciences": [
            "politics",
            "culture",
            "economics",
            "geography",
            "psychology",
        ],
        "other": ["other", "business", "health"],
    }

    @classmethod
    def suite(cls, chat=False):
        """Build the MMLU tasks grouped by suite name.

        Args:
            chat: Use the chat prompt and zero-shot; otherwise 5-shot.

        Returns:
            A ``defaultdict(list)`` mapping every key of ``categories``
            (plus ``"all"``) to a list of ``Task`` objects, one per subject
            under each of the suite's tags.
        """
        finer_categories = _group_subjects_by_category(cls.subcategories)
        suite = defaultdict(list)
        # Kept as a class-level update so ``categories["all"]`` stays visible
        # to any caller that reads it after ``suite()``.
        cls.categories["all"] = list(finer_categories.keys())
        for suite_name, tags in cls.categories.items():
            for tag in tags:
                suite[suite_name].extend(
                    [
                        Task(
                            ("lukaemon/mmlu", subject),
                            metric_name=("sustech/tlem", "mmlu"),
                            input_column=cls.input_column,
                            label_column=cls.label_column,
                            prompt=partial(cls.prompt_mmlu, chat=chat),
                            few_shot=0 if chat else 5,
                            few_shot_from="validation",
                        )
                        for subject in finer_categories[tag]
                    ]
                )
        return suite