JeffYang52415 committed
Commit b65e855 · unverified · Parent: 289c905

feat: add math parser
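Adds a parser and tests for the MATH benchmark dataset (lighteval/MATH), and tidies the ruff settings in the pre-commit and pyproject configuration.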

.pre-commit-config.yaml CHANGED
@@ -65,14 +65,10 @@ repos:
       - id: prettier
         types_or: [markdown, yaml]
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    # Ruff version.
     rev: v0.4.4
     hooks:
-      # Run the linter.
       - id: ruff
         args: [--fix]
-      # Run the formatter.
-      - id: ruff-format
   - repo: https://github.com/kynan/nbstripout
     rev: 0.5.0 # use the latest version
     hooks:
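Net effect of this hunk: the stale comments and the ruff-format hook are dropped, so ruff now runs as a linter with --fix only; Python formatting is presumably handled elsewhere in the toolchain.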
llmdataparser/math_parser.py ADDED
@@ -0,0 +1,108 @@
+from dataclasses import dataclass
+from typing import Any, ClassVar
+
+from llmdataparser.base_parser import HuggingFaceDatasetParser, HuggingFaceParseEntry
+
+
+@dataclass(frozen=True, kw_only=True, slots=True)
+class MATHParseEntry(HuggingFaceParseEntry):
+    """Custom entry class for MATH dataset, with fields specific to this dataset parser."""
+
+    level: str
+    task_name: str
+    solution: str
+
+    @classmethod
+    def create(
+        cls,
+        prompt: str,
+        answer: str,
+        raw_question: str,
+        raw_answer: str,
+        level: str,
+        task_name: str,
+        solution: str,
+    ) -> "MATHParseEntry":
+        return cls(
+            prompt=prompt,
+            answer=answer,
+            raw_question=raw_question,
+            raw_answer=raw_answer,
+            level=level,
+            task_name=task_name,
+            solution=solution,
+        )
+
+
+class MATHDatasetParser(HuggingFaceDatasetParser[MATHParseEntry]):
+    """Parser for the MATH dataset."""
+
+    _data_source: ClassVar[str] = "lighteval/MATH"
+    _task_names: ClassVar[list[str]] = [
+        "algebra",
+        "geometry",
+        "calculus",
+        "prealgebra",
+        "intermediate_algebra",
+        "number_theory",
+        "precalculus",
+        "all",
+    ]
+    _default_task: ClassVar[str] = "all"
+    _default_system_prompt: ClassVar[
+        str
+    ] = "Solve the following mathematics problem step by step:"
+    _valid_levels: ClassVar[set[str]] = {
+        f"Level {i}" for i in range(1, 6)
+    }  # Levels 1-5 are valid
+
+    def _get_task_from_entry(self, data_entry: dict[str, Any]) -> str:
+        """Get the task name from the data entry or fall back to current task."""
+        entry_type = data_entry.get("type")
+        if entry_type and (entry_type in self._task_names):
+            return entry_type
+        return self._current_task or self._default_task
+
+    def process_entry(
+        self, row: dict[str, Any], task_name: str | None = None, **kwargs: Any
+    ) -> MATHParseEntry:
+        """Process a single MATH dataset entry."""
+        task = task_name or self._get_current_task(row)
+
+        # Validate and normalize level
+        level = row.get("level")
+        if level not in self._valid_levels:
+            level = "Unknown"
+
+        return MATHParseEntry.create(
+            prompt=f"{self._system_prompt}\n{row['problem']}",
+            answer=row["solution"],
+            raw_question=row["problem"],
+            raw_answer=row["solution"],
+            level=level,
+            task_name=task,
+            solution=row["solution"],
+        )
+
+
+if __name__ == "__main__":
+    # Example usage of MATH parser
+    parser = MATHDatasetParser()
+
+    # Load the dataset
+    parser.load()
+
+    # Parse all splits
+    parser.parse()
+
+    # Get parsed data
+    parsed_data = parser.get_parsed_data
+
+    # Print example entry
+    if parsed_data:
+        example = parsed_data[0]
+        print("\nExample parsed entry:")
+        print(f"Task: {example.task_name}")
+        print(f"Level: {example.level}")
+        print(f"Question: {example.raw_question}")
+        print(f"Solution: {example.solution}")
pyproject.toml CHANGED
@@ -49,11 +49,8 @@ profile = "black"
 line_length = 88
 known_first_party = ["llmdataparser"]
 
-
 [tool.ruff]
 line-length = 88
-select = ["E", "F"]  # or specify checks explicitly without E501
-ignore = ["E501"]
 
 [tool.ruff.lint]
 ignore = ["E501"]
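This matches ruff's current configuration schema: the top-level select/ignore keys under [tool.ruff] are deprecated, and lint options belong in the [tool.ruff.lint] section instead.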
tests/test_math_parser.py ADDED
@@ -0,0 +1,200 @@
+import pytest
+
+from llmdataparser.math_parser import MATHDatasetParser, MATHParseEntry
+
+
+@pytest.fixture
+def math_parser():
+    """Create a MATH parser instance for testing."""
+    return MATHDatasetParser()
+
+
+@pytest.fixture
+def loaded_math_parser(math_parser):
+    """Create and load a MATH parser instance with test split."""
+    math_parser.load(task_name="algebra", split="test")
+    return math_parser
+
+
+@pytest.fixture
+def sample_math_entries():
+    """Create sample MATH dataset entries for testing."""
+    return [
+        {
+            "problem": "Solve for x: 2x + 4 = 10",
+            "level": "Level 3",
+            "solution": "Let's solve step by step:\n1) Subtract 4 from both sides: 2x = 6\n2) Divide both sides by 2\n\nTherefore, x = 3",
+            "type": "algebra",
+        },
+        {
+            "problem": "Find the area of a circle with radius 5 units.",
+            "level": "Level 2",
+            "solution": "Area = πr²\nArea = π(5)²\nArea = 25π square units",
+            "type": "geometry",
+        },
+        {
+            "problem": "What is the limit of (x²-1)/(x-1) as x approaches 1?",
+            "level": "Level 4",
+            "solution": "Using L'Hôpital's rule:\nlim(x→1) (x²-1)/(x-1) = lim(x→1) (2x)/(1) = 2",
+            "type": "calculus",
+        },
+    ]
+
+
+def test_math_parse_entry_creation_valid():
+    """Test valid creation of MATHParseEntry with all fields."""
+    entry = MATHParseEntry.create(
+        prompt="Test prompt",
+        answer="Test answer",
+        raw_question="Test question",
+        raw_answer="Test solution",
+        level="Level 5",
+        task_name="algebra",
+        solution="Test solution",
+    )
+
+    assert isinstance(entry, MATHParseEntry)
+    assert entry.prompt == "Test prompt"
+    assert entry.answer == "Test answer"
+    assert entry.raw_question == "Test question"
+    assert entry.raw_answer == "Test solution"
+    assert entry.level == "Level 5"
+    assert entry.task_name == "algebra"
+    assert entry.solution == "Test solution"
+
+
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        {
+            "problem": "Solve for x: 2x + 4 = 10",
+            "level": "Level 3",
+            "solution": "x = 3",
+            "type": "algebra",
+        },
+        {
+            "problem": "Find the derivative of f(x) = x²",
+            "level": "Level 4",
+            "solution": "f'(x) = 2x",
+            "type": "calculus",
+        },
+    ],
+)
+def test_process_entry(math_parser, test_case):
+    """Test processing different types of MATH entries."""
+    entry = math_parser.process_entry(test_case, task_name=test_case["type"])
+
+    assert isinstance(entry, MATHParseEntry)
+    assert (
+        entry.prompt == f"{math_parser._default_system_prompt}\n{test_case['problem']}"
+    )
+    assert entry.answer == test_case["solution"]
+    assert entry.raw_question == test_case["problem"]
+    assert entry.raw_answer == test_case["solution"]
+    assert entry.level == test_case["level"]
+    assert entry.task_name == test_case["type"]
+    assert entry.solution == test_case["solution"]
+
+
+def test_math_parser_initialization(math_parser):
+    """Test MATH parser initialization and properties."""
+    assert isinstance(math_parser.task_names, list)
+    assert len(math_parser.task_names) == 8
+    assert math_parser._data_source == "lighteval/MATH"
+    assert math_parser._default_task == "all"
+    assert "algebra" in math_parser.task_names
+    assert "geometry" in math_parser.task_names
+    assert (
+        math_parser.get_huggingface_link
+        == "https://huggingface.co/datasets/lighteval/MATH"
+    )
+    assert "mathematics problem" in math_parser._default_system_prompt.lower()
+
+
+def test_get_current_task(math_parser):
+    """Test task name resolution in different scenarios."""
+    # Test with valid type in data entry
+    test_row_with_type = {"type": "algebra"}
+    assert math_parser._get_current_task(test_row_with_type) == "algebra"
+
+    # Test without type in data entry
+    test_row_without_type = {}
+    math_parser._current_task = "geometry"
+    assert math_parser._get_current_task(test_row_without_type) == "geometry"
+
+    # Test with invalid type - should return current task
+    test_row_invalid_type = {"type": "invalid_type"}
+    math_parser._current_task = "algebra"
+    assert math_parser._get_current_task(test_row_invalid_type) == "algebra"
+
+
+def test_valid_levels(math_parser):
+    """Test handling of valid level values."""
+    for i in range(1, 6):
+        test_row = {
+            "problem": "Test problem",
+            "level": f"Level {i}",
+            "solution": "Test solution",
+            "type": "algebra",
+        }
+        entry = math_parser.process_entry(test_row, task_name="algebra")
+        assert entry.level == f"Level {i}"
+
+
+@pytest.mark.parametrize(
+    "invalid_level",
+    [
+        "Level 0",  # Too low
+        "Level 6",  # Too high
+        "Invalid",  # Wrong format
+        None,  # Missing
+        "",  # Empty
+        "level 1",  # Wrong capitalization
+    ],
+)
+def test_invalid_level_handling(math_parser, invalid_level):
+    """Test handling of invalid level values."""
+    test_row = {
+        "problem": "Test problem",
+        "level": invalid_level,
+        "solution": "Test solution",
+        "type": "algebra",
+    }
+
+    entry = math_parser.process_entry(test_row, task_name="algebra")
+    assert entry.level == "Unknown"
+
+
+@pytest.mark.integration
+def test_load_dataset(loaded_math_parser):
+    """Test loading the MATH dataset."""
+    assert loaded_math_parser.raw_data is not None
+    assert loaded_math_parser.split_names == ["test"]
+    assert loaded_math_parser._current_task == "algebra"
+
+
+def test_parser_string_representation(loaded_math_parser):
+    """Test string representation of MATH parser."""
+    repr_str = str(loaded_math_parser)
+    assert "MATHDatasetParser" in repr_str
+    assert "lighteval/MATH" in repr_str
+    assert "algebra" in repr_str
+    assert "loaded" in repr_str
+
+
+@pytest.mark.integration
+def test_different_splits_parsing(math_parser):
+    """Test parsing different splits of the dataset."""
+    # Load and parse test split
+    math_parser.load(task_name="algebra", split="test")
+    math_parser.parse(split_names="test", force=True)
+    test_count = len(math_parser.get_parsed_data)
+
+    # Load and parse train split
+    math_parser.load(task_name="algebra", split="train")
+    math_parser.parse(split_names="train", force=True)
+    train_count = len(math_parser.get_parsed_data)
+
+    assert test_count > 0
+    assert train_count > 0
+    assert train_count != test_count
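The tests marked @pytest.mark.integration download the real dataset. A minimal conftest.py sketch (hypothetical; no conftest appears in this commit) that registers the custom marker so pytest -m "not integration" deselects them without unknown-mark warnings:

    # conftest.py (hypothetical; not part of this commit)
    def pytest_configure(config):
        # Register the custom "integration" marker used by the tests above.
        config.addinivalue_line(
            "markers", "integration: tests that load the real lighteval/MATH dataset"
        )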