JeffYang52415
committed on
feat: add bbh/mmlu parser
Browse files
- llmdataparser/bbh_parser.py +110 -0
- llmdataparser/mmlu_parser.py +394 -53
- llmdataparser/prompts.py +114 -3
- pyproject.toml +4 -0
- tests/test_bbh_parser.py +160 -0
- tests/test_mmlu_parser.py +220 -0
llmdataparser/bbh_parser.py
ADDED
@@ -0,0 +1,110 @@
from dataclasses import dataclass
from typing import Any, ClassVar

from llmdataparser.base_parser import HuggingFaceDatasetParser, HuggingFaceParseEntry
from llmdataparser.prompts import BBH_SYSTEM_PROMPT


@dataclass(frozen=True, kw_only=True, slots=True)
class BBHParseEntry(HuggingFaceParseEntry):
    """Custom entry class for BBH (Big Bench Hard), with fields specific to this dataset."""

    @classmethod
    def create(
        cls,
        prompt: str,
        answer: str,
        raw_question: str,
        raw_answer: str,
        task_name: str,
    ) -> "BBHParseEntry":
        return cls(
            prompt=prompt,
            answer=answer,
            raw_question=raw_question,
            raw_answer=raw_answer,
            task_name=task_name,
        )


class BBHDatasetParser(HuggingFaceDatasetParser[BBHParseEntry]):
    """Parser for the Big Bench Hard dataset."""

    _data_source: ClassVar[str] = "lukaemon/bbh"
    _task_names: ClassVar[list[str]] = [
        "boolean_expressions",
        "causal_judgement",
        "date_understanding",
        "disambiguation_qa",
        "dyck_languages",
        "formal_fallacies",
        "geometric_shapes",
        "hyperbaton",
        "logical_deduction_five_objects",
        "logical_deduction_seven_objects",
        "logical_deduction_three_objects",
        "movie_recommendation",
        "multistep_arithmetic_two",
        "navigate",
        "object_counting",
        "penguins_in_a_table",
        "reasoning_about_colored_objects",
        "ruin_names",
        "salient_translation_error_detection",
        "snarks",
        "sports_understanding",
        "temporal_sequences",
        "tracking_shuffled_objects_five_objects",
        "tracking_shuffled_objects_seven_objects",
        "tracking_shuffled_objects_three_objects",
        "web_of_lies",
        "word_sorting",
    ]
    _default_task: ClassVar[str] = "reasoning_about_colored_objects"
    _default_system_prompt: ClassVar[str] = BBH_SYSTEM_PROMPT

    def process_entry(
        self, row: dict[str, Any], task_name: str | None = None, **kwargs: Any
    ) -> BBHParseEntry:
        """Process a single BBH entry."""
        raw_question = row["input"]
        raw_answer = row["target"]

        # Remove parentheses from the answer
        clean_answer = raw_answer.strip("()")

        # Combine system prompt with the question
        prompt = f"{self._system_prompt}\n\n{raw_question}"

        # Use task_name if provided, otherwise fall back to the current task
        task = task_name or self._get_current_task(row)

        return BBHParseEntry.create(
            prompt=prompt,
            answer=clean_answer,
            raw_question=raw_question,
            raw_answer=raw_answer,
            task_name=task,
        )


if __name__ == "__main__":
    # Example usage
    parser = BBHDatasetParser()

    # Load the dataset with a specific task
    parser.load(task_name="reasoning_about_colored_objects")

    # Parse all splits
    parser.parse()

    # Get parsed data
    parsed_data = parser.get_parsed_data

    # Print example entry
    if parsed_data:
        example = parsed_data[0]
        print("\nExample parsed entry:")
        print(f"Task: {example.task_name}")
        print(f"Question: {example.raw_question}")
        print(f"Answer: {example.answer}")
llmdataparser/mmlu_parser.py
CHANGED
@@ -1,81 +1,422 @@
  from dataclasses import dataclass
- from typing import Any

- from llmdataparser.base_parser import HuggingFaceDatasetParser,
- from llmdataparser.prompts import MMLU_SYSTEM_PROMPT

- @dataclass(frozen=True)
- class MMLUParseEntry(ParseEntry):
-     """
-     Custom entry class for MMLU, with fields specific to this dataset parser.
-     """

      @classmethod
-     def create(
          raise ValueError(
-             f"Invalid answer_letter '{
          )

  class MMLUDatasetParser(HuggingFaceDatasetParser[MMLUParseEntry]):
      _data_source = "cais/mmlu"

-         self.parsed_data: list[MMLUParseEntry] = []
-         self.task_names: list[str] = []
-         self.subject_list: set[str] = set()
-         self.system_prompt: str = system_prompt
-         super().__init__()

-     def parse(self, split_names: str | list[str] | None = None, **kwargs: Any) -> None:
-         self.parsed_data.clear()
-         if self.raw_data is None:
-             raise ValueError("No data loaded. Please load the dataset first.")

-         if split_names is None:
-             split_names = self.task_names
-         elif isinstance(split_names, str):
-             split_names = [split_names]

-         for split_name in split_names:
-             if split_name not in self.task_names:
-                 raise ValueError(f"Task '{split_name}' not found in the dataset.")

-             dataset_split = self.raw_data[split_name]
-             for index, entry in enumerate(dataset_split, start=1):
-                 data_entry = self.process_entry(entry, **kwargs)
-                 self._parsed_data.append(data_entry)
-                 self.subject_list.add(entry.get("subject", "Unknown"))
-             print(f"Parsed {index} data points from task '{split_name}'.")

-         print(
-             f"Number of subjects: {len(self.subject_list)}. "
-             "For more details, please check the `self.subject_list` attribute."
          )

      """
      Generate a prompt and expected answer from the given row.

      Args:
-         row (dict[str, Any]): A data point to be formatted

      Returns:
          MMLUParseEntry: The formatted entry object.
      """
      choices = "\n".join(
-         f"{chr(65 + i)}. {choice}" for i, choice in enumerate(
      )
-     answer_letter = chr(65 + row["answer"])  # Convert index to 'A', 'B', 'C', 'D'
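
The rewritten module, reproduced in full below, replaces the removed implementation's mutable parser state and hand-rolled parse() loop with frozen entry dataclasses that validate answers at creation time, and splits the single parser into variant subclasses for base MMLU, MMLU Redux, TMMLU+, and MMLU Pro.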
from dataclasses import dataclass
from typing import Any, Final

from llmdataparser.base_parser import HuggingFaceDatasetParser, HuggingFaceParseEntry
from llmdataparser.prompts import MMLU_PRO_SYSTEM_PROMPT, MMLU_SYSTEM_PROMPT

MMLU_VALID_ANSWERS: Final[set[str]] = {"A", "B", "C", "D"}
MMLU_PRO_VALID_ANSWERS: Final[set[str]] = {
    "A",
    "B",
    "C",
    "D",
    "E",
    "F",
    "G",
    "H",
    "I",
    "J",
}
MMLU_VALID_ANSWER_STR: Final[str] = ", ".join(sorted(MMLU_VALID_ANSWERS))
MMLU_PRO_VALID_ANSWER_STR: Final[str] = ", ".join(sorted(MMLU_PRO_VALID_ANSWERS))


@dataclass(frozen=True, kw_only=True, slots=True)
class MMLUParseEntry(HuggingFaceParseEntry):
    """Custom entry class for MMLU, with fields specific to this dataset parser."""

    raw_choices: list[str]
    task_name: str

    @classmethod
    def create(
        cls,
        prompt: str,
        answer: str,
        raw_question: str,
        raw_choices: list[str],
        raw_answer: str,
        task_name: str,
    ) -> "MMLUParseEntry":
        if answer not in MMLU_VALID_ANSWERS:
            raise ValueError(
                f"Invalid answer_letter '{answer}'; must be one of {MMLU_VALID_ANSWER_STR}"
            )
        if not task_name:
            raise ValueError("Task name cannot be empty")
        return cls(
            prompt=prompt,
            answer=answer,
            raw_question=raw_question,
            raw_answer=raw_answer,
            raw_choices=raw_choices,
            task_name=task_name,
        )


@dataclass(frozen=True, kw_only=True, slots=True)
class MMLUProParseEntry(HuggingFaceParseEntry):
    """Custom entry class for MMLU Pro, with fields specific to this dataset parser."""

    raw_choices: list[str]
    task_name: str

    @classmethod
    def create(
        cls,
        prompt: str,
        answer: str,
        raw_question: str,
        raw_choices: list[str],
        raw_answer: str,
        task_name: str,
    ) -> "MMLUProParseEntry":
        if answer not in MMLU_PRO_VALID_ANSWERS:
            raise ValueError(
                f"Invalid answer_letter '{answer}'; must be one of {MMLU_PRO_VALID_ANSWER_STR}"
            )
        if not task_name:
            raise ValueError("Task name cannot be empty")
        return cls(
            prompt=prompt,
            answer=answer,
            raw_question=raw_question,
            raw_choices=raw_choices,
            raw_answer=raw_answer,
            task_name=task_name,
        )


class MMLUDatasetParser(HuggingFaceDatasetParser[MMLUParseEntry]):
    """Base class for MMLU dataset parsers with common functionality."""

    _default_system_prompt = MMLU_SYSTEM_PROMPT

    def _get_task_from_entry(self, data_entry: dict[str, Any]) -> str:
        """Get the task name from the data entry or default task name."""
        task_name = data_entry.get("subject")
        return task_name if task_name else (self._current_task or self._default_task)

    def process_entry(
        self, row: dict[str, Any], task_name: str | None = None, **kwargs: Any
    ) -> MMLUParseEntry:
        """
        Generate a prompt and expected answer from the given row.

        Args:
            row: A data point to be formatted.
            task_name: Optional task name for the entry.
            **kwargs: Additional keyword arguments.

        Returns:
            MMLUParseEntry: The formatted entry object.
        """
        task = task_name or self._get_current_task(row)
        # Ensure task is not None
        final_task = task or self._default_task

        choices = "\n".join(
            f"{chr(65 + i)}. {choice}" for i, choice in enumerate(row["choices"])
        )
        raw_question = row["question"]
        raw_choices = row["choices"]
        raw_answer = str(row["answer"])  # Ensure raw_answer is a string

        prompt = f"{self._system_prompt}\nQuestion: {raw_question}\n{choices}\nAnswer:"
        answer_letter = chr(65 + int(raw_answer))  # Convert index to 'A', 'B', 'C', 'D'

        return MMLUParseEntry.create(
            prompt=prompt,
            answer=answer_letter,
            raw_question=raw_question,
            raw_choices=raw_choices,
            raw_answer=raw_answer,
            task_name=final_task,
        )


class BaseMMLUDatasetParser(MMLUDatasetParser):
    """Parser for the original MMLU dataset."""

    _data_source = "cais/mmlu"
    _default_task = "all"
    _task_names = [
        "abstract_algebra",
        "anatomy",
        "astronomy",
        "business_ethics",
        "clinical_knowledge",
        "college_biology",
        "college_chemistry",
        "college_computer_science",
        "college_mathematics",
        "college_medicine",
        "college_physics",
        "computer_security",
        "conceptual_physics",
        "econometrics",
        "electrical_engineering",
        "elementary_mathematics",
        "formal_logic",
        "global_facts",
        "high_school_biology",
        "high_school_chemistry",
        "high_school_computer_science",
        "high_school_european_history",
        "high_school_geography",
        "high_school_government_and_politics",
        "high_school_macroeconomics",
        "high_school_mathematics",
        "high_school_microeconomics",
        "high_school_physics",
        "high_school_psychology",
        "high_school_statistics",
        "high_school_us_history",
        "high_school_world_history",
        "human_aging",
        "human_sexuality",
        "international_law",
        "jurisprudence",
        "logical_fallacies",
        "machine_learning",
        "management",
        "marketing",
        "medical_genetics",
        "miscellaneous",
        "moral_disputes",
        "moral_scenarios",
        "nutrition",
        "philosophy",
        "prehistory",
        "professional_accounting",
        "professional_law",
        "professional_medicine",
        "professional_psychology",
        "public_relations",
        "security_studies",
        "sociology",
        "us_foreign_policy",
        "virology",
        "world_religions",
    ]


class MMLUReduxDatasetParser(MMLUDatasetParser):
    """Parser for the MMLU Redux dataset."""

    _data_source = "edinburgh-dawg/mmlu-redux"
    _default_task = "anatomy"
    _task_names = [
        "anatomy",
        "astronomy",
        "business_ethics",
        "clinical_knowledge",
        "college_chemistry",
        "college_computer_science",
        "college_mathematics",
        "college_medicine",
        "college_physics",
        "conceptual_physics",
        "econometrics",
        "electrical_engineering",
        "formal_logic",
        "global_facts",
        "high_school_chemistry",
        "high_school_geography",
        "high_school_macroeconomics",
        "high_school_mathematics",
        "high_school_physics",
        "high_school_statistics",
        "high_school_us_history",
        "human_aging",
        "logical_fallacies",
        "machine_learning",
        "miscellaneous",
        "philosophy",
        "professional_accounting",
        "professional_law",
        "public_relations",
        "virology",
    ]


class TMMLUPlusDatasetParser(MMLUDatasetParser):
    """Parser for the TMMLU+ dataset."""

    _data_source = "ikala/tmmluplus"
    _default_task = "taiwanese_hokkien"
    _task_names = [
        "engineering_math",
        "dentistry",
        "traditional_chinese_medicine_clinical_medicine",
        "clinical_psychology",
        "technical",
        "culinary_skills",
        "mechanical",
        "logic_reasoning",
        "real_estate",
        "general_principles_of_law",
        "finance_banking",
        "anti_money_laundering",
        "ttqav2",
        "marketing_management",
        "business_management",
        "organic_chemistry",
        "advance_chemistry",
        "physics",
        "secondary_physics",
        "human_behavior",
        "national_protection",
        "jce_humanities",
        "politic_science",
        "agriculture",
        "official_document_management",
        "financial_analysis",
        "pharmacy",
        "educational_psychology",
        "statistics_and_machine_learning",
        "management_accounting",
        "introduction_to_law",
        "computer_science",
        "veterinary_pathology",
        "accounting",
        "fire_science",
        "optometry",
        "insurance_studies",
        "pharmacology",
        "taxation",
        "trust_practice",
        "geography_of_taiwan",
        "physical_education",
        "auditing",
        "administrative_law",
        "education_(profession_level)",
        "economics",
        "veterinary_pharmacology",
        "nautical_science",
        "occupational_therapy_for_psychological_disorders",
        "basic_medical_science",
        "macroeconomics",
        "trade",
        "chinese_language_and_literature",
        "tve_design",
        "junior_science_exam",
        "junior_math_exam",
        "junior_chinese_exam",
        "junior_social_studies",
        "tve_mathematics",
        "tve_chinese_language",
        "tve_natural_sciences",
        "junior_chemistry",
        "music",
        "education",
        "three_principles_of_people",
        "taiwanese_hokkien",
    ]

    def process_entry(
        self, row: dict[str, Any], task_name: str | None = None, **kwargs: Any
    ) -> MMLUParseEntry:
        """Process a single TMMLU+ entry."""
        # Extract choices in order
        raw_choices = [row["A"], row["B"], row["C"], row["D"]]
        choices = "\n".join(
            f"{chr(65 + i)}. {choice}" for i, choice in enumerate(raw_choices)
        )
        raw_question = row["question"]
        raw_answer = row["answer"]

        prompt = f"{self._system_prompt}\nQuestion: {raw_question}\n{choices}\nAnswer:"
        task = task_name or self._get_current_task(row)

        return MMLUParseEntry.create(
            prompt, raw_answer, raw_question, raw_choices, raw_answer, task
        )


class MMLUProDatasetParser(HuggingFaceDatasetParser[MMLUProParseEntry]):
    """Parser for the MMLU Pro dataset."""

    _data_source = "TIGER-Lab/MMLU-Pro"
    _default_task = "default"
    _task_names = [
        "math",
        "physics",
        "chemistry",
        "law",
        "engineering",
        "other",
        "economics",
        "health",
        "psychology",
        "business",
        "biology",
        "philosophy",
        "computer_science",
        "history",
    ]
    _default_system_prompt = MMLU_PRO_SYSTEM_PROMPT

    def _get_task_from_entry(self, data_entry: dict[str, Any]) -> str:
        """Get the task name from the data entry or default task name."""
        if data_entry is not None:
            task_name = data_entry.get("category")
            if task_name:
                return task_name
        return self._current_task or self._default_task

    def process_entry(
        self, row: dict[str, Any], task_name: str | None = None, **kwargs: Any
    ) -> MMLUProParseEntry:
        """
        Generate a prompt and expected answer from the given row.

        Args:
            row (dict[str, Any]): A data point to be formatted with MMLU Pro specific structure
                containing 'question', 'options', 'answer', and 'answer_index' keys.

        Returns:
            MMLUProParseEntry: The formatted entry object.
        """
        task = task_name or self._get_current_task(row)
        # Ensure task is not None
        final_task = task or self._default_task

        # Extract choices in order
        raw_choices = row["options"]
        choices = "\n".join(
            f"{chr(65 + i)}. {choice}" for i, choice in enumerate(raw_choices)
        )
        raw_question = row["question"]
        raw_answer = row["answer"]
        answer_index = row["answer_index"]

        prompt = f"{self._system_prompt}\nQuestion: {raw_question}\n{choices}\nAnswer:"
        answer_letter = chr(
            65 + answer_index
        )  # Convert index to 'A', 'B', 'C', 'D', etc.

        return MMLUProParseEntry.create(
            prompt, answer_letter, raw_question, raw_choices, raw_answer, final_task
        )


if __name__ == "__main__":
    # Example usage of MMLU Pro parser
    parser = MMLUProDatasetParser()
    parser.load()
    parser.parse()

    # Get parsed data with correct type
    parsed_data = parser.get_parsed_data

    # Print example entry
    if parsed_data:
        example = parsed_data[0]
        print("\nExample parsed entry:")
        print(f"Task: {example.task_name}")
        print(f"Question: {example.raw_question}")
        print("Choices:")
        for i, choice in enumerate(example.raw_choices):
            print(f"{chr(65 + i)}. {choice}")
        print(f"Correct Answer: {example.answer}")
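
As a minimal sketch of the two row schemas these parsers normalize (made-up rows, assuming the module above is importable): cais/mmlu rows carry a choices list plus an integer answer index, TMMLU+ rows carry one column per letter plus a letter answer, and invalid letters are rejected by the frozen entry class itself.

```python
from llmdataparser.mmlu_parser import (
    BaseMMLUDatasetParser,
    MMLUParseEntry,
    TMMLUPlusDatasetParser,
)

# cais/mmlu style: list of choices, integer answer index
mmlu_row = {
    "question": "2 + 2 = ?",
    "choices": ["3", "4", "5", "6"],
    "answer": 1,
    "subject": "elementary_mathematics",
}
entry = BaseMMLUDatasetParser().process_entry(
    mmlu_row, task_name="elementary_mathematics"
)
assert entry.answer == "B"  # index 1 -> letter B

# TMMLU+ style: one column per letter, answer already a letter
tmmlu_row = {
    "question": "2 + 2 = ?",
    "A": "3",
    "B": "4",
    "C": "5",
    "D": "6",
    "answer": "B",
    "subject": "junior_math_exam",
}
entry = TMMLUPlusDatasetParser().process_entry(tmmlu_row, task_name="junior_math_exam")
assert entry.answer == "B"

# Validation happens at entry creation, not in the parser
try:
    MMLUParseEntry.create(
        prompt="p",
        answer="E",  # outside A-D
        raw_question="q",
        raw_choices=["w", "x", "y", "z"],
        raw_answer="4",
        task_name="t",
    )
except ValueError as err:
    print(err)  # Invalid answer_letter 'E'; must be one of A, B, C, D
```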
llmdataparser/prompts.py
CHANGED
@@ -3,10 +3,121 @@ from typing import Final
  MMLU_SYSTEM_PROMPT: Final[str] = textwrap.dedent(
      """\
-     You are

      Instructions:
-     1.
-     2.
      """
  )
MMLU_SYSTEM_PROMPT: Final[str] = textwrap.dedent(
    """\
    You are a highly knowledgeable expert tasked with answering multiple-choice questions across various academic and professional fields. Each question has four options (A, B, C, D). Your goal is to select the single most accurate answer based on factual knowledge.

    Instructions:
    1. Carefully analyze the question and all answer options
    2. Consider only verified, factual information
    3. Select the most precise and accurate option
    4. Respond with ONLY the letter (A, B, C, or D) - no explanations or additional text
    """
)

MMLU_PRO_SYSTEM_PROMPT: Final[str] = textwrap.dedent(
    """\
    You are a highly knowledgeable expert tasked with answering multiple-choice questions across various academic and professional fields. Each question has ten options (A through J). Your goal is to select the single most accurate answer based on factual knowledge.

    Instructions:
    1. Carefully analyze the question and all answer options
    2. Consider only verified, factual information
    3. Select the most precise and accurate option
    4. Respond with ONLY the letter (A through J) - no explanations or additional text
    """
)

GSM8K_SYSTEM_PROMPT: Final[str] = textwrap.dedent(
    """\
    You are an expert mathematics tutor. Your task is to solve math word problems by breaking them down into clear, logical steps.

    Instructions:
    1. Read the problem carefully
    2. Show your step-by-step reasoning
    3. Ensure each step is clear and mathematically sound
    4. End with the final numerical answer
    5. Format your response as:
       Let's solve this step by step:
       1) [First step]
       2) [Second step]
       ...
       Therefore, the answer is [number]
    """
)


HUMANEVAL_SYSTEM_PROMPT: Final[str] = textwrap.dedent(
    """\
    You are an expert Python programmer tasked with implementing Python functions. Your goal is to write clean, efficient, and correct code that meets the specifications.

    Instructions:
    1. Read the function signature and docstring carefully
    2. Implement only the function body, not the signature or docstring
    3. Follow Python best practices and PEP 8 style guidelines
    4. Write clear, readable code with appropriate variable names
    5. Handle edge cases and input validation where necessary
    6. Use type hints and ensure type safety
    7. Optimize for both readability and performance
    8. Add comments for complex logic or non-obvious implementations
    9. Include appropriate error handling with specific exception types
    10. Consider writing code that would be easy to test
    11. Return only the implementation code, no additional text

    Example of good implementation:
    ```python
    # Handle edge case of empty input
    if not numbers:
        raise ValueError("Input list cannot be empty")

    # Use descriptive variable names and type hints
    result: list[int] = sorted(numbers)
    return result[len(result) // 2]  # Return median value
    ```
    """
)

MGSM_SYSTEM_PROMPT: Final[str] = textwrap.dedent(
    """\
    You are an expert mathematics tutor who can explain solutions in multiple languages. Your task is to solve math word problems by breaking them down into clear, logical steps.

    Instructions:
    1. Read the problem carefully
    2. Show your step-by-step reasoning
    3. Ensure each step is clear and mathematically sound
    4. Use appropriate number formatting for the target language (e.g., decimal points vs. commas)
    5. End with the final numerical answer
    6. Format your response as:
       Let's solve this step by step:
       1) [First step]
       2) [Second step]
       ...
       Therefore, the answer is [number]
    """
)


IFEVAL_SYSTEM_PROMPT: Final[str] = textwrap.dedent(
    """\
    You are a precise instruction follower. Your task is to generate responses that exactly match given requirements and constraints.

    Instructions:
    1. Read all requirements carefully
    2. Follow formatting rules exactly
    3. Meet all length requirements
    4. Include all required elements
    5. Avoid forbidden elements
    6. Provide ONLY the requested output
    """
)

BBH_SYSTEM_PROMPT: Final[str] = textwrap.dedent(
    """\
    You are a highly intelligent expert tasked with solving complex reasoning problems. These problems test various cognitive abilities including logical deduction, causal reasoning, mathematical thinking, and spatial understanding.

    Instructions:
    1. Read the entire problem carefully, including all given conditions and rules
    2. Pay attention to the specific type of reasoning required (logical, temporal, spatial, etc.)
    3. Consider all relationships and constraints mentioned in the problem
    4. Apply structured thinking to reach a valid conclusion
    5. Choose the answer that logically follows from the given information
    6. Respond with ONLY the letter (A, B, C, etc.) or "True"/"False" or "Yes"/"No" - no explanations or additional text
    """
)
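
All of these constants rely on the same textwrap.dedent pattern, so a new prompt can follow it mechanically. A small self-contained check of that pattern (EXAMPLE_PROMPT is a hypothetical name, not part of the module):

```python
import textwrap
from typing import Final

# The backslash after the opening quotes suppresses the leading newline,
# and dedent() strips the common indentation, so the stored string starts
# at column zero exactly as written.
EXAMPLE_PROMPT: Final[str] = textwrap.dedent(
    """\
    You are a careful assistant.

    Instructions:
    1. Answer concisely
    """
)

assert EXAMPLE_PROMPT.startswith("You are")
assert "\n    " not in EXAMPLE_PROMPT  # indentation removed
```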
pyproject.toml
CHANGED
@@ -67,3 +67,7 @@ build-backend = "poetry.core.masonry.api"
 markers = [
     "integration: marks tests as integration tests (deselect with '-m \"not integration\"')"
 ]
+
+[tool.bandit]
+exclude_dirs = ["tests"]
+skips = ["B101"]
tests/test_bbh_parser.py
ADDED
@@ -0,0 +1,160 @@
import pytest

from llmdataparser.bbh_parser import BBHDatasetParser, BBHParseEntry


@pytest.fixture
def bbh_parser():
    """Create a BBH parser instance for testing."""
    return BBHDatasetParser()


@pytest.fixture
def loaded_bbh_parser(bbh_parser):
    """Create and load a BBH parser instance for testing."""
    bbh_parser.load(task_name="reasoning_about_colored_objects", split="test")
    return bbh_parser


@pytest.fixture
def sample_row():
    """Create a sample BBH data row for testing."""
    return {
        "input": "What color is the sky on a clear day?\nA) Blue\nB) Green\nC) Red\nD) Yellow",
        "target": "(A)",
    }


def test_bbh_parse_entry_creation_valid():
    """Test valid creation of BBHParseEntry."""
    entry = BBHParseEntry.create(
        prompt="Test prompt",
        answer="A",
        raw_question="Test question",
        raw_answer="(A)",
        task_name="reasoning_about_colored_objects",
    )
    assert isinstance(entry, BBHParseEntry)
    assert entry.prompt == "Test prompt"
    assert entry.answer == "A"
    assert entry.raw_question == "Test question"
    assert entry.raw_answer == "(A)"
    assert entry.task_name == "reasoning_about_colored_objects"


def test_bbh_parser_initialization(bbh_parser):
    """Test BBH parser initialization."""
    assert bbh_parser._data_source == "lukaemon/bbh"
    assert bbh_parser._default_task == "reasoning_about_colored_objects"
    assert "boolean_expressions" in bbh_parser._task_names
    assert "word_sorting" in bbh_parser._task_names
    assert (
        bbh_parser.get_huggingface_link
        == "https://huggingface.co/datasets/lukaemon/bbh"
    )


def test_load_dataset(loaded_bbh_parser):
    """Test loading the dataset."""
    assert loaded_bbh_parser.raw_data is not None
    assert loaded_bbh_parser.split_names == ["test"]
    assert loaded_bbh_parser._current_task == "reasoning_about_colored_objects"


@pytest.mark.integration
def test_full_parse_workflow(loaded_bbh_parser):
    """Test the complete workflow of loading and parsing data."""
    # Parse the test split
    loaded_bbh_parser.parse(split_names="test", force=True)
    parsed_data = loaded_bbh_parser.get_parsed_data

    # Basic checks
    assert len(parsed_data) > 0

    # Check first entry structure
    first_entry = parsed_data[0]
    assert isinstance(first_entry, BBHParseEntry)
    assert first_entry.task_name == "reasoning_about_colored_objects"
    assert first_entry.answer.strip("()").isalpha()  # Should be a single letter
    assert first_entry.prompt.startswith(loaded_bbh_parser._system_prompt)


def test_process_entry(bbh_parser, sample_row):
    """Test processing of a single BBH entry."""
    entry = bbh_parser.process_entry(
        sample_row, task_name="reasoning_about_colored_objects"
    )

    assert isinstance(entry, BBHParseEntry)
    assert entry.answer == "A"  # Stripped from "(A)"
    assert "What color is the sky" in entry.raw_question
    assert entry.raw_answer == "(A)"
    assert bbh_parser._system_prompt in entry.prompt
    assert entry.task_name == "reasoning_about_colored_objects"


@pytest.mark.parametrize("split_name", ["invalid_split", "wrong_split"])
def test_parse_with_invalid_split(bbh_parser, split_name):
    """Test parsing with invalid split names."""
    bbh_parser.raw_data = {"train": [], "test": []}  # Mock data

    with pytest.raises(
        ValueError, match=f"Split '{split_name}' not found in the dataset"
    ):
        bbh_parser.parse(split_name)


def test_parse_without_loaded_data(bbh_parser):
    """Test parsing without loading data first."""
    with pytest.raises(
        ValueError, match="No data loaded. Please load the dataset first"
    ):
        bbh_parser.parse()


@pytest.mark.parametrize(
    "test_case",
    [
        {"input": "Test question", "target": "(A)"},
        {"input": "Test question", "target": "(B)"},
        {"input": "Test question", "target": "(C)"},
    ],
)
def test_answer_stripping(bbh_parser, test_case):
    """Test stripping of parentheses from answers."""
    entry = bbh_parser.process_entry(
        test_case, task_name="reasoning_about_colored_objects"
    )
    assert entry.answer == test_case["target"].strip("()")
    assert entry.raw_answer == test_case["target"]


def test_parser_properties(bbh_parser):
    """Test parser property getters."""
    assert len(bbh_parser.task_names) > 0
    assert bbh_parser.total_tasks == len(bbh_parser._task_names)
    assert all(isinstance(task, str) for task in bbh_parser.task_names)


def test_parser_string_representation(loaded_bbh_parser):
    """Test string representation of parser."""
    repr_str = str(loaded_bbh_parser)
    assert "BBHDatasetParser" in repr_str
    assert "lukaemon/bbh" in repr_str
    assert "reasoning_about_colored_objects" in repr_str
    assert "loaded" in repr_str


@pytest.mark.integration
@pytest.mark.parametrize(
    "task_name", ["boolean_expressions", "causal_judgement", "date_understanding"]
)
def test_different_tasks_parsing(bbh_parser, task_name):
    """Test parsing different tasks of the dataset."""
    bbh_parser.load(task_name=task_name, split="test")
    bbh_parser.parse(split_names="test", force=True)
    parsed_data = bbh_parser.get_parsed_data

    assert len(parsed_data) > 0
    assert all(entry.task_name == task_name for entry in parsed_data)
    assert all(isinstance(entry.answer, str) for entry in parsed_data)
tests/test_mmlu_parser.py
ADDED
@@ -0,0 +1,220 @@
import pytest

from llmdataparser.mmlu_parser import (
    BaseMMLUDatasetParser,
    MMLUParseEntry,
    MMLUProDatasetParser,
    MMLUProParseEntry,
    MMLUReduxDatasetParser,
    TMMLUPlusDatasetParser,
)


@pytest.fixture
def base_parser():
    """Create a base MMLU parser instance."""
    return BaseMMLUDatasetParser()


@pytest.fixture
def redux_parser():
    """Create a MMLU Redux parser instance."""
    return MMLUReduxDatasetParser()


@pytest.fixture
def tmmlu_parser():
    """Create a TMMLU+ parser instance."""
    return TMMLUPlusDatasetParser()


@pytest.fixture
def mmlu_pro_parser():
    """Create a MMLU Pro parser instance."""
    return MMLUProDatasetParser()


@pytest.fixture
def sample_mmlu_entries():
    """Create sample MMLU dataset entries for testing."""
    return [
        {
            "question": "What is the capital of France?",
            "choices": ["London", "Paris", "Berlin", "Madrid"],
            "answer": 1,  # Paris
            "subject": "geography",
        },
        {
            "question": "Which of these is a primary color?",
            "choices": ["Green", "Purple", "Blue", "Orange"],
            "answer": 2,  # Blue
            "subject": "art",
        },
    ]


@pytest.fixture
def sample_mmlu_pro_entries():
    """Create sample MMLU Pro dataset entries for testing."""
    return [
        {
            "question": "What is the time complexity of quicksort?",
            "options": ["O(n)", "O(n log n)", "O(n²)", "O(2ⁿ)", "O(n!)", "O(1)"],
            "answer": "The average time complexity of quicksort is O(n log n)",
            "answer_index": 1,
            "category": "computer_science",
        }
    ]


def test_mmlu_parse_entry_creation_valid():
    """Test valid creation of MMLUParseEntry."""
    entry = MMLUParseEntry.create(
        prompt="Test prompt",
        answer="A",
        raw_question="Test question",
        raw_choices=["choice1", "choice2", "choice3", "choice4"],
        raw_answer="0",
        task_name="test_task",
    )
    assert isinstance(entry, MMLUParseEntry)
    assert entry.prompt == "Test prompt"
    assert entry.answer == "A"
    assert entry.raw_choices == ["choice1", "choice2", "choice3", "choice4"]
    assert entry.task_name == "test_task"


@pytest.mark.parametrize("invalid_answer", ["E", "F", "1", "", None])
def test_mmlu_parse_entry_creation_invalid(invalid_answer):
    """Test invalid answer handling in MMLUParseEntry creation."""
    with pytest.raises(
        ValueError, match="Invalid answer_letter.*must be one of A, B, C, D"
    ):
        MMLUParseEntry.create(
            prompt="Test prompt",
            answer=invalid_answer,
            raw_question="Test question",
            raw_choices=["choice1", "choice2", "choice3", "choice4"],
            raw_answer="4",
            task_name="test_task",
        )


def test_process_entry_base(base_parser, sample_mmlu_entries):
    """Test processing entries in base MMLU parser."""
    entry = base_parser.process_entry(sample_mmlu_entries[0], task_name="geography")

    assert isinstance(entry, MMLUParseEntry)
    assert entry.answer == "B"  # Index 1 maps to B
    assert "A. London" in entry.prompt
    assert "B. Paris" in entry.prompt
    assert "C. Berlin" in entry.prompt
    assert "D. Madrid" in entry.prompt
    assert entry.raw_question == "What is the capital of France?"
    assert entry.raw_choices == ["London", "Paris", "Berlin", "Madrid"]
    assert entry.raw_answer == "1"  # process_entry stores the answer index as a string
    assert entry.task_name == "geography"


def test_mmlu_pro_parse_entry_creation_valid():
    """Test valid creation of MMLUProParseEntry."""
    entry = MMLUProParseEntry.create(
        prompt="Test prompt",
        answer="E",  # MMLU Pro supports up to J
        raw_question="Test question",
        raw_choices=["choice1", "choice2", "choice3", "choice4", "choice5"],
        raw_answer="4",
        task_name="test_task",
    )
    assert isinstance(entry, MMLUProParseEntry)
    assert entry.answer == "E"
    assert len(entry.raw_choices) == 5


def test_process_entry_mmlu_pro(mmlu_pro_parser, sample_mmlu_pro_entries):
    """Test processing entries in MMLU Pro parser."""
    entry = mmlu_pro_parser.process_entry(
        sample_mmlu_pro_entries[0], task_name="computer_science"
    )

    assert isinstance(entry, MMLUProParseEntry)
    assert entry.answer == "B"  # Index 1 maps to B
    assert "O(n log n)" in entry.prompt
    assert entry.task_name == "computer_science"
    assert len(entry.raw_choices) == 6


def test_tmmlu_process_entry(tmmlu_parser):
    """Test processing entries in TMMLU+ parser."""
    test_row = {
        "question": "什麼是台灣最高的山峰?",
        "A": "玉山",
        "B": "阿里山",
        "C": "合歡山",
        "D": "雪山",
        "answer": "A",
        "subject": "geography_of_taiwan",
    }

    entry = tmmlu_parser.process_entry(test_row, task_name="geography_of_taiwan")
    assert isinstance(entry, MMLUParseEntry)
    assert entry.answer == "A"
    assert entry.raw_choices == ["玉山", "阿里山", "合歡山", "雪山"]
    assert entry.task_name == "geography_of_taiwan"


@pytest.mark.parametrize(
    "parser_fixture,expected_tasks,expected_source",
    [
        ("base_parser", 57, "cais/mmlu"),
        ("redux_parser", 30, "edinburgh-dawg/mmlu-redux"),
        ("tmmlu_parser", 66, "ikala/tmmluplus"),
        ("mmlu_pro_parser", 14, "TIGER-Lab/MMLU-Pro"),
    ],
)
def test_parser_initialization(
    request, parser_fixture, expected_tasks, expected_source
):
    """Test initialization of different MMLU parser variants."""
    parser = request.getfixturevalue(parser_fixture)
    assert len(parser.task_names) == expected_tasks
    assert parser._data_source == expected_source
    assert (
        parser.get_huggingface_link
        == f"https://huggingface.co/datasets/{expected_source}"
    )


@pytest.mark.integration
def test_load_dataset(base_parser):
    """Test loading the MMLU dataset."""
    base_parser.load(task_name="anatomy", split="test")
    assert base_parser.raw_data is not None
    assert base_parser.split_names == ["test"]
    assert base_parser._current_task == "anatomy"


def test_parser_string_representation(base_parser):
    """Test string representation of MMLU parser."""
    repr_str = str(base_parser)
    assert "MMLUDatasetParser" in repr_str
    assert "cais/mmlu" in repr_str
    assert "not loaded" in repr_str


@pytest.mark.integration
def test_different_splits_parsing(base_parser):
    """Test parsing different splits of the dataset."""
    # Load and parse test split
    base_parser.load(task_name="anatomy", split="test")
    base_parser.parse(split_names="test", force=True)
    test_count = len(base_parser.get_parsed_data)

    # Load and parse validation split
    base_parser.load(task_name="anatomy", split="validation")
    base_parser.parse(split_names="validation", force=True)
    val_count = len(base_parser.get_parsed_data)

    assert test_count > 0
    assert val_count > 0
    assert test_count != val_count