# tlem/tasks.py
from dataclasses import dataclass, field
from datasets import load_dataset, Dataset
from functools import cached_property
from tqdm.auto import tqdm
from typing import Any, Optional, Protocol, Iterable, Callable
import logging
import pandas as pd
from functools import partial
from datasets.utils.logging import disable_progress_bar
from .utils import *
from evaluate import load
from collections import defaultdict
import sys
# if sys.version_info >= (3, 9):
# from functools import cache
# else:
# from functools import lru_cache as cache
disable_progress_bar()
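# Example usage (a minimal sketch): `pipeline` is assumed to be any callable
# that maps a list of prompt strings to a list of model responses; it is not
# defined in this module.
#
#     suite = MMLU.suite(chat=True)
#     for task in suite["STEM"]:
#         result = task.run(pipeline)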
def mt_bench_prompt(example):
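    """Build an MT-Bench judging conversation for a single example.

    Returns a list of chat messages: a system prompt describing the judging
    criteria followed by the (possibly reference-augmented) conversation.
    """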
judge_prompt = "You are ChatGPT, a large language model trained by OpenAI. Please act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below. The Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of the response."
judge_prompt = "You are ChatGPT, a large language model trained by OpenAI. Your task is to act as an impartial judge and evaluate the quality of the responses provided by an 'assistant' role in the displayed conversation. Your evaluation should focus on the helpfulness, relevance, accuracy, depth, creativity, language fluency, clarity, and level of detail in the assistant's responses. Please note that the evaluation should not consider the user's questions or the overall conversation, but solely the quality of the assistant's replies."
multi_prompt = "You evaluation should focus on the assistant's answer to the second user question."
ref_prompt = "In the conversation, you will encounter system messages labeled 'Reference Answer' followed by the assistant's response. Your task is to evaluate the quality of the assistant's response by comparing it with the reference answer."
json_prompt = 'You must rate the response on a scale of 1 to 10 in JSON format, for example: {"rating": 5}.'
prompt_list = [judge_prompt]
conversations = example["conversation"]
if example["turn"] == 2:
prompt_list.append(multi_prompt)
if example["reference"] is not None:
conversations = []
        questions = filter(lambda e: e["role"] == "user", example["conversation"])
        answers = filter(lambda e: e["role"] == "assistant", example["conversation"])
        for q, a, r in zip(questions, answers, example["reference"]):
conversations.append(q)
conversations.append(
{"role": "system", "content": "Reference Answer: " + r}
)
conversations.append(a)
prompt_list.append(ref_prompt)
prompt_list.append(json_prompt)
messages = [{"role": "system", "content": " ".join(prompt_list)}] + conversations
return messages
@dataclass
class Task:
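    """A single evaluation task: a dataset split, an optional prompt template,
    optional few-shot examples, and a metric used to score pipeline outputs."""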
dataset_name: str | tuple[str, str] = ("gsm8k", "main")
split: str = "test"
# metrics: list[str] = field(default_factory=list)
metric_name: str | tuple[str, str] = ("sustech/tlem", "gsm8k")
input_column: str = "question"
label_column: str = ""
prompt: Optional[Callable | str] = None
few_shot: int = 0
few_shot_from: Optional[str] = None
# results: dict[str, Any] = field(default_factory=dict)
def __post_init__(self):
names = (
[self.dataset_name]
if isinstance(self.dataset_name, str)
else list(self.dataset_name)
)
names[0] = names[0].split("/")[-1]
self.name = "-".join(names) + f"-{self.split}"
        if isinstance(self.prompt, str):
            # Bind the template string first: referring to `self.prompt` inside
            # the lambda would resolve to the lambda itself after reassignment.
            template = self.prompt
            self.prompt = lambda example: {
                self.input_column: template.format(
                    input_column=example[self.input_column]
                )
            }
        self.label_column = self.label_column or self.input_column
@cached_property
def samples(self):
return self.dataset[self.input_column]
@cached_property
def dataset(self):
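        """Load the dataset, apply the prompt, and prepend few-shot examples.

        Few-shot examples are drawn from `few_shot_from` (or the first available
        train/validation/val/dev split) and concatenated in front of each test prompt.
        """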
ds = load_dataset(
*self.dataset_name
if isinstance(self.dataset_name, tuple)
else self.dataset_name,
# split=self.split,
)
test_ds = ds[self.split]
if self.prompt is not None:
test_ds = test_ds.map(self.prompt)
if self.few_shot:
if self.few_shot_from is None:
for name in ["train", "validation", "val", "dev"]:
if name in ds:
self.few_shot_from = name
break
            assert self.few_shot_from != self.split, "few-shot examples must come from a different split"
shots = ds[self.few_shot_from].select(range(self.few_shot))
if self.prompt is not None:
shots = shots.map(self.prompt)
shots = shots.map(
lambda example: {
self.input_column: example[self.input_column]
+ example[self.label_column],
}
)[self.input_column]
few_shot_prompts = "\n\n".join(shots)
test_ds = test_ds.map(
lambda example: {
self.input_column: few_shot_prompts
+ "\n\n"
+ example[self.input_column],
}
)
return test_ds
@cached_property
def metric(self):
metric = (
load(self.metric_name)
if isinstance(self.metric_name, str)
else load(*self.metric_name)
)
return metric
# @cache
def run(
self,
pipeline,
):
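        """Run `pipeline` over the task samples and score its outputs.

        `pipeline` is any callable mapping a list of prompts to a list of
        responses; the raw outputs are kept on `self.outputs` for inspection.
        """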
if (outputs := pipeline(self.samples)) is None:
logging.warning("pipeline returns None")
return
self.outputs = outputs
        try:
            # Call the metric's internal `_compute` directly; fall back to the
            # public `compute` API if that fails.
            result = self.metric._compute(
                responses=outputs, references=self.dataset[self.label_column]
            )
        except Exception:
            result = self.metric.compute(
                responses=outputs, references=self.dataset[self.label_column]
            )
# if log:
# name = name or pipeline.__name__
# self.results[name] = result
return result
def multichoice(responses: Any, references: list[str]):
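    # Normalize free-text responses to a choice letter; non-string responses are
    # assumed to be per-choice scores and are decoded instead.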
if isinstance(responses[0], str):
responses = [extract_choice(response) for response in responses]
else:
responses = decode_choice(responses)
return responses, references
def multichoice_zh(responses: Any, references: list[str]):
if isinstance(responses[0], str):
responses = [extract_choice_zh(response) for response in responses]
else:
responses = decode_choice(responses)
return responses, references
class Metrics:
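    """Per-task answer-extraction and scoring helpers referenced by the metric."""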
cmmlu = multichoice_zh
mmlu = multichoice
def gsm8k(responses: list[str], answers: list[str | int]):
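        # Extract the final numeric answer from each response and reference;
        # exact-match comparison is left to the metric.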
# scores = []
# for response, answer in zip(responses, answers):
# pred = extract_numeric(response)
# gold = extract_numeric(answer) if isinstance(answer, str) else str(answer)
# scores.append(1.0 * (pred == gold))
responses = [extract_numeric(response) for response in responses]
answers = [
extract_numeric(answer) if isinstance(answer, str) else str(answer)
for answer in answers
]
return responses, answers
def MATH(responses: list[str], answers: list[str]):
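        # Compare the expression inside the last $...$ span of the response
        # with the gold answer parsed from the reference solution.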
scores = []
for response, answer in zip(responses, answers):
indices = [pos for pos, char in enumerate(response) if char == "$"]
if len(indices) <= 2:
scores.append(0)
continue
else:
result = response[indices[-2] + 1 : indices[-1]]
gold = get_answer(answer)
scores.append(1.0 * is_equiv(result, gold))
return scores
def math23k(responses: list[str], answers: list[str]):
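        # Chinese math word problems: compare numeric answers extracted with a
        # Chinese-aware numeric pattern.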
scores = []
for response, answer in zip(responses, answers):
pred = extract_numeric(response, pattern=NUMERIC_IN_ZH)
gold = extract_numeric(answer, pattern=NUMERIC_IN_ZH)
scores.append(1.0 * (pred == gold))
return scores
class CMMLU:
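    """Task-suite builder for CMMLU (Chinese multi-subject multiple choice)."""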
input_column = "prompt"
label_column = "Answer"
def prompt_cmmlu(example, chat=False):
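        # Prefix (zh): "The following is a multiple-choice question; choose the most
        # suitable answer among A, B, C and D." in chat mode, "Question:" otherwise.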
prefix = "以下是一道多项选择题,请从A、B、C和D中选择最合适的答案作为这个问题的答案。\n\n" if chat else "问题:"
prompt = prefix + example["Question"]
for choice in list("ABCD"):
prompt += f"\n{choice}. {example[choice]}"
prompt += "\n答案:"
return {"prompt": prompt}
subcategories = {
"agronomy": ["other"],
"anatomy": ["biology"],
"ancient_chinese": ["linguistics", "china specific"],
"arts": ["arts"],
"astronomy": ["physics"],
"business_ethics": ["business"],
"chinese_civil_service_exam": ["politics", "china specific"],
"chinese_driving_rule": ["other", "china specific"],
"chinese_food_culture": ["culture", "china specific"],
"chinese_foreign_policy": ["politics", "china specific"],
"chinese_history": ["history", "china specific"],
"chinese_literature": ["literature", "china specific"],
"chinese_teacher_qualification": ["education", "china specific"],
"college_actuarial_science": ["math"],
"college_education": ["education"],
"college_engineering_hydrology": ["engineering"],
"college_law": ["law"],
"college_mathematics": ["math"],
"college_medical_statistics": ["statistics"],
"clinical_knowledge": ["other"],
"college_medicine": ["other"],
"computer_science": ["computer science"],
"computer_security": ["other"],
"conceptual_physics": ["physics"],
"construction_project_management": ["other", "china specific"],
"economics": ["economics"],
"education": ["education"],
"elementary_chinese": ["linguistics", "china specific"],
"elementary_commonsense": ["other", "china specific"],
"elementary_information_and_technology": ["other"],
"electrical_engineering": ["engineering"],
"elementary_mathematics": ["math"],
"ethnology": ["culture", "china specific"],
"food_science": ["other"],
"genetics": ["biology"],
"global_facts": ["global"],
"high_school_biology": ["biology"],
"high_school_chemistry": ["chemistry"],
"high_school_geography": ["geography"],
"high_school_mathematics": ["math"],
"high_school_physics": ["physics"],
"high_school_politics": ["politics", "china specific"],
"human_sexuality": ["other"],
"international_law": ["law"],
"journalism": ["sociology"],
"jurisprudence": ["law"],
"legal_and_moral_basis": ["other"],
"logical": ["philosophy"],
"machine_learning": ["computer science"],
"management": ["business"],
"marketing": ["business"],
"marxist_theory": ["philosophy"],
"modern_chinese": ["linguistics", "china specific"],
"nutrition": ["other"],
"philosophy": ["philosophy"],
"professional_accounting": ["business"],
"professional_law": ["law"],
"professional_medicine": ["other"],
"professional_psychology": ["psychology"],
"public_relations": ["politics"],
"security_study": ["politics"],
"sociology": ["culture"],
"sports_science": ["other"],
"traditional_chinese_medicine": ["other", "china specific"],
"virology": ["biology"],
"world_history": ["history"],
"world_religions": ["global"],
}
categories = {
"STEM": [
"physics",
"chemistry",
"biology",
"computer science",
"math",
"engineering",
"statistics",
],
"Humanities": ["history", "philosophy", "law", "arts", "literature", "global"],
"Social Science": [
"linguistics",
"business",
"politics",
"culture",
"economics",
"geography",
"psychology",
"education",
"sociology",
],
"Other": ["other"],
"China specific": ["china specific"],
"Test": ["computer science"],
}
@classmethod
def suite(cls, chat=False):
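        """Group the CMMLU subjects into category-keyed lists of `Task` objects."""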
finer_categories = (
pd.Series(cls.subcategories) # noqa # type: ignore
.explode()
.reset_index()
.set_index(0)
.groupby(0)
.agg(list)["index"]
.to_dict()
)
suite = defaultdict(list)
cls.categories["all"] = list(finer_categories.keys())
for k, v in cls.categories.items():
for subject in v:
suite[k].extend(
[
Task(
("haonan-li/cmmlu", subcategories),
metric_name=("sustech/tlem", "cmmlu"),
input_column=cls.input_column,
label_column=cls.label_column,
prompt=partial(cls.prompt_cmmlu, chat=chat),
few_shot=0 if chat else 5,
few_shot_from="dev",
)
                        for subcategory in finer_categories[subject]
]
)
return suite
class MMLU:
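    """Task-suite builder for MMLU (English multi-subject multiple choice)."""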
input_column = "prompt"
label_column = "target"
@classmethod
def prompt_mmlu(cls, example, chat=False):
prefix = (
"The following is a multiple-choice question. Please choose the most suitable one among A, B, C and D as the answer to this question.\n\n"
if chat
else "Question: "
)
prompt = prefix + example["input"]
for choice in list("ABCD"):
prompt += f"\n{choice}. {example[choice]}"
prompt += "\nAnswer:"
return {"prompt": prompt}
subcategories = {
"abstract_algebra": ["math"],
"anatomy": ["health"],
"astronomy": ["physics"],
"business_ethics": ["business"],
"clinical_knowledge": ["health"],
"college_biology": ["biology"],
"college_chemistry": ["chemistry"],
"college_computer_science": ["computer science"],
"college_mathematics": ["math"],
"college_medicine": ["health"],
"college_physics": ["physics"],
"computer_security": ["computer science"],
"conceptual_physics": ["physics"],
"econometrics": ["economics"],
"electrical_engineering": ["engineering"],
"elementary_mathematics": ["math"],
"formal_logic": ["philosophy"],
"global_facts": ["other"],
"high_school_biology": ["biology"],
"high_school_chemistry": ["chemistry"],
"high_school_computer_science": ["computer science"],
"high_school_european_history": ["history"],
"high_school_geography": ["geography"],
"high_school_government_and_politics": ["politics"],
"high_school_macroeconomics": ["economics"],
"high_school_mathematics": ["math"],
"high_school_microeconomics": ["economics"],
"high_school_physics": ["physics"],
"high_school_psychology": ["psychology"],
"high_school_statistics": ["math"],
"high_school_us_history": ["history"],
"high_school_world_history": ["history"],
"human_aging": ["health"],
"human_sexuality": ["culture"],
"international_law": ["law"],
"jurisprudence": ["law"],
"logical_fallacies": ["philosophy"],
"machine_learning": ["computer science"],
"management": ["business"],
"marketing": ["business"],
"medical_genetics": ["health"],
"miscellaneous": ["other"],
"moral_disputes": ["philosophy"],
"moral_scenarios": ["philosophy"],
"nutrition": ["health"],
"philosophy": ["philosophy"],
"prehistory": ["history"],
"professional_accounting": ["other"],
"professional_law": ["law"],
"professional_medicine": ["health"],
"professional_psychology": ["psychology"],
"public_relations": ["politics"],
"security_studies": ["politics"],
"sociology": ["culture"],
"us_foreign_policy": ["politics"],
"virology": ["health"],
"world_religions": ["philosophy"],
}
categories = {
"STEM": [
"physics",
"chemistry",
"biology",
"computer science",
"math",
"engineering",
],
"humanities": ["history", "philosophy", "law"],
"social sciences": [
"politics",
"culture",
"economics",
"geography",
"psychology",
],
"other": ["other", "business", "health"],
}
@classmethod
def suite(cls, chat=False):
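        """Group the MMLU subjects into category-keyed lists of `Task` objects."""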
finer_categories = (
pd.Series(cls.subcategories) # noqa # type: ignore
.explode()
.reset_index()
.set_index(0)
.groupby(0)
.agg(list)["index"]
.to_dict()
)
suite = defaultdict(list)
cls.categories["all"] = list(finer_categories.keys())
for k, v in cls.categories.items():
for subject in v:
suite[k].extend(
[
Task(
("lukaemon/mmlu", subcategories),
metric_name=("sustech/tlem", "mmlu"),
input_column=cls.input_column,
label_column=cls.label_column,
prompt=partial(cls.prompt_mmlu, chat=chat),
few_shot=0 if chat else 5,
few_shot_from="validation",
)
                        for subcategory in finer_categories[subject]
]
)
return suite