---
language:
- en
license: llama3
library_name: transformers
tags:
- mathematics
datasets:
- hkust-nlp/dart-math-hard
metrics:
- accuracy
pipeline_tag: text-generation
base_model: meta-llama/Meta-Llama-3-70B
model-index:
  - name: dart-math-llama3-70b-prop2diff
    results:
      - task:
          type: text-generation
          name: Mathematical Problem-Solving
        dataset:
          type: hendrycks/competition_math
          name: MATH
          split: test
        metrics:
          - type: accuracy
            name: Pass@1 (0-shot CoT)
            value: 56.1
      - task:
          type: text-generation
          name: Mathematical Problem-Solving
        dataset:
          type: openai/gsm8k
          name: GSM8K
          config: main
          split: test
        metrics:
          - type: accuracy
            name: Pass@1 (0-shot CoT)
            value: 89.6
      - task:
          type: text-generation
          name: Mathematical Problem-Solving
        dataset:
          type: college-math
          name: CollegeMath
        metrics:
          - type: accuracy
            name: Pass@1 (0-shot CoT)
            value: 37.9
      - task:
          type: text-generation
          name: Mathematical Problem-Solving
        dataset:
          type: deepmind-mathematics
          name: DeepMind-Mathematics
        metrics:
          - type: accuracy
            name: Pass@1 (0-shot CoT)
            value: 64.1
      - task:
          type: text-generation
          name: Mathematical Problem-Solving
        dataset:
          type: Hothan/OlympiadBench
          name: OlympiadBench-OE_TO_maths_en_COMP
          config: OE_TO_maths_en_COMP
          split: train
        metrics:
          - type: accuracy
            name: Pass@1 (0-shot CoT)
            value: 20.0
      - task:
          type: text-generation
          name: Mathematical Problem-Solving
        dataset:
          type: TIGER-Lab/TheoremQA
          name: TheoremQA
          split: test
        metrics:
          - type: accuracy
            name: Pass@1 (0-shot CoT)
            value: 28.2
---

# DART-Math: Difficulty-Aware Rejection Tuning for Mathematical Problem-Solving

📝 [Paper@arXiv](https://arxiv.org/abs/2407.13690) | 🤗 [Datasets&Models@HF](https://huggingface.co/collections/hkust-nlp/dart-math-665704599b35de59f8fdf6c1) | 🐱 [Code@GitHub](https://github.com/hkust-nlp/dart-math)

🐦 [Thread@X(Twitter)](https://x.com/tongyx361/status/1811413243350454455) | 🐶 [中文博客@知乎](https://zhuanlan.zhihu.com/p/708371895) | 📊 [Leaderboard@PapersWithCode](https://paperswithcode.com/paper/dart-math-difficulty-aware-rejection-tuning#results) | 📑 [BibTeX](https://github.com/hkust-nlp/dart-math?tab=readme-ov-file#citation)

## Models: `DART-Math`

`DART-Math` models achieve performance **superior to or competitive with previous SOTAs** on 2 in-domain and 4 challenging out-of-domain mathematical reasoning benchmarks, despite using **much smaller datasets** and **no proprietary model like GPT-4**.

| Model                                                                                                  | [MATH](https://huggingface.co/datasets/hendrycks/competition_math) | [GSM8K](https://huggingface.co/datasets/gsm8k) | [College](https://github.com/hkust-nlp/dart-math/tree/main/data/eval-dsets/mwpbench/college-math-test.jsonl) | [DM](https://github.com/hkust-nlp/dart-math/tree/main/data/eval-dsets/deepmind-mathematics.json) | [Olympiad](https://github.com/hkust-nlp/dart-math/tree/main/data/eval-dsets/olympiadbench/OE_TO_maths_en_COMP.json) | [Theorem](https://github.com/hkust-nlp/dart-math/tree/main/data/eval-dsets/theoremqa.json) |      AVG |
| :----------------------------------------------------------------------------------------------------- | -----------------------------------------------------------------: | ---------------------------------------------: | -----------------------------------------------------------------------------------------------------------: | -----------------------------------------------------------------------------------------------: | ------------------------------------------------------------------------------------------------------------------: | -----------------------------------------------------------------------------------------: | -------: |
| GPT-4 (0314)                                                                                           |                           [52.6](https://arxiv.org/abs/2403.04706) |       [94.7](https://arxiv.org/abs/2403.04706) |                                                                     [24.4](https://arxiv.org/abs/2403.02884) |                                                                                               -- |                                                                                                                  -- |                                                                                         -- |       -- |
| Llama-3-70B-MetaMath                                                                                   |                                                               44.9 |                                           88.0 |                                                                                                         31.9 |                                                                                             53.2 |                                                                                                                11.6 |                                                                                       21.9 |     41.9 |
| [`DART-Math-Llama-3-70B` (Uniform)](https://huggingface.co/hkust-nlp/dart-math-llama3-70b-uniform)     |                                                               54.9 |                                       **90.4** |                                                                                                     **38.5** |                                                                                         **64.1** |                                                                                                                19.1 |                                                                                       27.4 |     49.1 |
| [`DART-Math-Llama-3-70B` (Prop2Diff)](https://huggingface.co/hkust-nlp/dart-math-llama3-70b-prop2diff) |                                                           **56.1** |                                           89.6 |                                                                                                         37.9 |                                                                                         **64.1** |                                                                                                            **20.0** |                                                                                   **28.2** | **49.3** |
| DeepSeekMath-7B-MetaMath                                                                               |                                                               43.7 |                                           81.8 |                                                                                                         33.7 |                                                                                             53.0 |                                                                                                                13.6 |                                                                                       23.2 |     41.5 |
| [DeepSeekMath-7B-RL](https://huggingface.co/deepseek-ai/deepseek-math-7b-rl)                           |                                                               53.1 |                                           88.4 |                                                                                                         41.3 |                                                                                             58.3 |                                                                                                                18.7 |                                                                                       35.9 |     49.3 |
| [`DART-Math-DSMath-7B` (Uniform)](https://huggingface.co/hkust-nlp/dart-math-dsmath-7b-uniform)        |                                                               52.9 |                                       **88.2** |                                                                                                         40.1 |                                                                                             60.2 |                                                                                                                21.3 |                                                                                   **32.5** |     49.2 |
| [`DART-Math-DSMath-7B` (Prop2Diff)](https://huggingface.co/hkust-nlp/dart-math-dsmath-7b-prop2diff)    |                                                           **53.6** |                                           86.8 |                                                                                                     **40.7** |                                                                                         **61.6** |                                                                                                            **21.7** |                                                                                       32.2 | **49.4** |
| Mistral-7B-MetaMath                                                                                    |                                                               29.8 |                                           76.5 |                                                                                                         19.3 |                                                                                             28.0 |                                                                                                                 5.9 |                                                                                       14.0 |     28.9 |
| [`DART-Math-Mistral-7B` (Uniform)](https://huggingface.co/hkust-nlp/dart-math-mistral-7b-uniform)      |                                                               43.5 |                                       **82.6** |                                                                                                         26.9 |                                                                                             42.0 |                                                                                                                13.2 |                                                                                       16.4 |     37.4 |
| [`DART-Math-Mistral-7B` (Prop2Diff)](https://huggingface.co/hkust-nlp/dart-math-mistral-7b-prop2diff)  |                                                           **45.5** |                                           81.1 |                                                                                                     **29.4** |                                                                                         **45.1** |                                                                                                            **14.7** |                                                                                   **17.0** | **38.8** |
| Llama-3-8B-MetaMath                                                                                    |                                                               32.5 |                                           77.3 |                                                                                                         20.6 |                                                                                             35.0 |                                                                                                                 5.5 |                                                                                       13.8 |     30.8 |
| [`DART-Math-Llama-3-8B` (Uniform)](https://huggingface.co/hkust-nlp/dart-math-llama3-8b-uniform)       |                                                               45.3 |                                       **82.5** |                                                                                                         27.1 |                                                                                         **48.2** |                                                                                                                13.6 |                                                                                       15.4 |     38.7 |
| [`DART-Math-Llama-3-8B` (Prop2Diff)](https://huggingface.co/hkust-nlp/dart-math-llama3-8b-prop2diff)   |                                                           **46.6** |                                           81.1 |                                                                                                     **28.8** |                                                                                             48.0 |                                                                                                            **14.5** |                                                                                   **19.4** | **39.7** |

***Abbreviations**: College (CollegeMath), DM (DeepMind Mathematics), Olympiad (OlympiadBench-Math), Theorem (TheoremQA).
**Bold** marks the best score among the SFT models trained from the same base model in this table.
To reproduce our results, please refer to [the `DART-Math` GitHub repository](https://github.com/hkust-nlp/dart-math).*
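The AVG column is consistent with an unweighted arithmetic mean of the six benchmark scores, rounded to one decimal. That reading is our assumption, inferred from the rows rather than stated in the card. A minimal sketch for recomputing it from any row:

```python
from typing import Sequence

# Assumption: AVG is the unweighted mean of the six benchmark columns, rounded to 1 decimal.
BENCHMARKS = ("MATH", "GSM8K", "College", "DM", "Olympiad", "Theorem")

# Scores copied from two rows of the results table, in BENCHMARKS order.
ROWS = {
    "DART-Math-Llama-3-70B (Prop2Diff)": [56.1, 89.6, 37.9, 64.1, 20.0, 28.2],
    "DART-Math-Llama-3-8B (Prop2Diff)": [46.6, 81.1, 28.8, 48.0, 14.5, 19.4],
}


def average_score(scores: Sequence[float]) -> float:
    """Unweighted mean over all benchmarks, rounded to one decimal place."""
    if len(scores) != len(BENCHMARKS):
        raise ValueError(f"expected {len(BENCHMARKS)} scores, got {len(scores)}")
    return round(sum(scores) / len(scores), 1)


for name, scores in ROWS.items():
    print(f"{name}: AVG = {average_score(scores)}")
```

Running it prints `AVG = 49.3` and `AVG = 39.7`, matching the AVG column for those two rows.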

## Prompt Template

All the `DART-Math` models use the [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) prompt template:

```
Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{query}\n\n### Response:\n
```
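A minimal sketch of filling the template in Python, assuming the `\n` sequences above stand for real newline characters and the standard Alpaca spacing (`### Instruction:` with a space, matching `### Response:`). `build_prompt` is a hypothetical helper, not part of the `DART-Math` code base:

```python
# Assumption: standard Alpaca spacing ("### Instruction:" / "### Response:"),
# with the "\n" escapes in the template meaning real newlines.
ALPACA_TEMPLATE = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{query}\n\n### Response:\n"
)


def build_prompt(query: str) -> str:
    """Hypothetical helper: insert a math query; the model continues after '### Response:\\n'."""
    # Braces inside `query` (e.g. LaTeX such as \frac{1}{2}) are safe here:
    # str.format only interprets braces in the template, not in the argument values.
    return ALPACA_TEMPLATE.format(query=query)


print(build_prompt(r"Compute $\frac{1}{2} + \frac{1}{3}$."))
```

The resulting string is what would be fed to the tokenizer before generation; the model's answer is expected to follow the trailing `### Response:\n`.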

## Training Dataset

We construct our training datasets by applying **Difficulty-Aware Rejection Sampling** (`DARS`) to the **MATH and GSM8K** training sets.

`DARS` tackles the **severe bias towards easy queries** in previous datasets, which **frequently fail to include any correct response for the most challenging queries**.

These biases are primarily caused by vanilla rejection sampling, where **the same number of responses is sampled for each query**, yet the likelihood of obtaining correct responses for difficult queries is significantly lower, sometimes even zero.
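A minimal sketch of the contrast described above, under our own assumptions. The hypothetical `sample`/`is_correct` callables stand in for the generator model and the answer checker, and difficulty-aware sampling is modelled as "keep sampling a query until it reaches a quota of correct responses, up to a trial cap". The Uniform quota (the same for every query) and the Prop2Diff-style quota (larger for harder queries) are illustrative values, not the authors' implementation or settings; see the dataset cards for the actual procedure.

```python
import random
from typing import Callable, Dict, List, Sequence

Sampler = Callable[[str], str]        # hypothetical: query -> one sampled response
Checker = Callable[[str, str], bool]  # hypothetical: (query, response) -> answer correct?


def vanilla_rejection_sampling(
    queries: Sequence[str], sample: Sampler, is_correct: Checker, n_samples: int
) -> Dict[str, List[str]]:
    """Draw the same number of raw samples per query and keep only the correct ones."""
    kept: Dict[str, List[str]] = {}
    for q in queries:
        responses = [sample(q) for _ in range(n_samples)]
        kept[q] = [r for r in responses if is_correct(q, r)]
    return kept


def difficulty_aware_rejection_sampling(
    queries: Sequence[str],
    sample: Sampler,
    is_correct: Checker,
    target_correct: Dict[str, int],
    max_trials: int,
) -> Dict[str, List[str]]:
    """Sample each query until it has target_correct[q] correct responses (or
    max_trials raw samples were drawn), so hard queries automatically get more samples."""
    kept: Dict[str, List[str]] = {}
    for q in queries:
        correct: List[str] = []
        trials = 0
        while len(correct) < target_correct[q] and trials < max_trials:
            r = sample(q)
            trials += 1
            if is_correct(q, r):
                correct.append(r)
        kept[q] = correct
    return kept


# Toy simulation: a hypothetical "model" that solves the easy query 90% of the time
# and the hard one 5% of the time; responses are just "right"/"wrong" strings.
rng = random.Random(0)
SOLVE_RATE = {"easy": 0.9, "hard": 0.05}


def toy_sample(q: str) -> str:
    return "right" if rng.random() < SOLVE_RATE[q] else "wrong"


def toy_check(q: str, r: str) -> bool:
    return r == "right"


queries = ["easy", "hard"]
vanilla = vanilla_rejection_sampling(queries, toy_sample, toy_check, n_samples=4)
uniform = difficulty_aware_rejection_sampling(
    queries, toy_sample, toy_check, target_correct={"easy": 4, "hard": 4}, max_trials=1000
)
prop2diff_like = difficulty_aware_rejection_sampling(
    queries, toy_sample, toy_check, target_correct={"easy": 2, "hard": 8}, max_trials=1000
)
for name, data in [("vanilla", vanilla), ("uniform", uniform), ("prop2diff-like", prop2diff_like)]:
    print(name, {q: len(rs) for q, rs in data.items()})
```

With a fixed budget of 4 samples, the hard query often ends up with no correct response at all, while the quota-based loops keep drawing until the hard query is covered.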

Please refer to [`DART-Math-Hard`](https://huggingface.co/datasets/hkust-nlp/dart-math-hard) / [`DART-Math-Uniform`](https://huggingface.co/datasets/hkust-nlp/dart-math-uniform) for more details.

## Training Setup

We perform standard instruction tuning on several base models, including Llama3-8B, Mistral-7B and Llama3-70B as representatives of general models and DeepSeekMath-7B as the representative of math-specialized models, on our synthetic datasets [`DART-Math-Hard`](https://huggingface.co/datasets/hkust-nlp/dart-math-hard) and [`DART-Math-Uniform`](https://huggingface.co/datasets/hkust-nlp/dart-math-uniform), yielding `DART-Math (Prop2Diff)` and `DART-Math (Uniform)` respectively.

For simplicity, we keep most hyper-parameters the same across different models and datasets:

- Model max length (of [packed](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing) sequence): 4096
- Batch size: 64
- Warm-up ratio: 0.03
- Learning rate scheduler: cosine
- Prompt template: [Alpaca](https://github.com/tatsu-lab/stanford_alpaca)
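A minimal sketch of the learning-rate schedule implied by the list above, assuming a linear warm-up over the first 3% of optimizer steps (rounded up) followed by a cosine decay to zero, the same shape as Hugging Face's `get_cosine_schedule_with_warmup` with its default half cycle. The card does not spell out the exact shape, and `lr_at_step` is a hypothetical helper:

```python
import math


def lr_at_step(step: int, total_steps: int, max_lr: float, warmup_ratio: float = 0.03) -> float:
    """Assumed schedule: linear warm-up to max_lr, then cosine decay to 0."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Linear ramp from 0 to max_lr during warm-up.
        return max_lr * step / max(1, warmup_steps)
    # Fraction of the post-warm-up phase completed, in [0, 1].
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return max_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


# Example: 1,000 optimizer steps with the Llama3-70B peak learning rate of 2e-5.
for s in (0, 15, 30, 515, 1000):
    print(s, f"{lr_at_step(s, 1000, 2e-5):.3e}")
```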

Several other key hyper-parameters are tuned as follows:

| Base Model      | Max. L.R. | # of Epochs | # of Grad. Acc. Steps | # of A100 GPUs |
|:--------------- | ---------:| -----------:| ---------------------:| --------------:|
| Mistral-7B      |    `1e-5` |           3 |                     1 |              8 |
| Llama3-8B       |    `5e-5` |           1 |                     2 |              8 |
| Llama3-70B      |    `2e-5` |           1 |                     1 |             32 |
| DeepSeekMath-7B |    `5e-5` |           3 |                     1 |              8 |

- For the **maximum learning rate**, we **search** over `1e-6, 5e-6, 1e-5, 2e-5, 5e-5, 1e-4` and pick the value with the best MATH performance after training on MMIQC for 1 epoch. Llama3-70B is too expensive to search, so we derive its learning rate from Llama3-8B's by analogy to the ratio of the pre-training learning rates of [Llama2-7B](https://huggingface.co/meta-llama/Llama-2-7b-hf) and [Llama2-70B](https://huggingface.co/meta-llama/Llama-2-70b-hf) (\~2:1).
- For **Llama3** models, preliminary experiments indicate that **training for 1 epoch consistently outperforms 3 epochs**.
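The table's gradient-accumulation and GPU columns determine how the batch of 64 is split across devices. A minimal sketch of that arithmetic, assuming (not stated in the card) that "Batch size: 64" is the global batch per optimizer step and that every listed GPU acts as one data-parallel rank; `per_device_batch_size` is a hypothetical helper:

```python
# Assumptions: 64 is the global batch per optimizer step, and each GPU is one
# data-parallel rank, so global = per_device * n_gpus * grad_acc_steps.
GLOBAL_BATCH_SIZE = 64
MAX_PACKED_LEN = 4096

# (gradient accumulation steps, number of A100 GPUs), copied from the table above.
SETUPS = {
    "Mistral-7B": (1, 8),
    "Llama3-8B": (2, 8),
    "Llama3-70B": (1, 32),
    "DeepSeekMath-7B": (1, 8),
}


def per_device_batch_size(global_bs: int, grad_acc_steps: int, n_gpus: int) -> int:
    """Packed sequences each GPU processes per micro-step."""
    denom = grad_acc_steps * n_gpus
    if global_bs % denom != 0:
        raise ValueError(f"{global_bs} is not divisible by {grad_acc_steps} x {n_gpus}")
    return global_bs // denom


for model, (grad_acc, gpus) in SETUPS.items():
    print(f"{model}: {per_device_batch_size(GLOBAL_BATCH_SIZE, grad_acc, gpus)} sequences/GPU/micro-step")

# Upper bound on tokens per optimizer step, reached only if every packed sequence is full.
print("max tokens per step:", GLOBAL_BATCH_SIZE * MAX_PACKED_LEN)
```

Under these assumptions the per-GPU micro-batch is 8 for Mistral-7B and DeepSeekMath-7B, 4 for Llama3-8B and 2 for Llama3-70B, and each optimizer step covers at most 262,144 tokens.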

Please refer to [Appendix A.1 of our paper](https://tongyx361.github.io/assets/dart-math/paper-dart-math.pdf) for more details.

## Other Details

- For Mistral-7B-based models, we disable `sliding_window` by default, following [the newest Mistral-7B-Instruct](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3/blob/main/config.json): Flash Attention 2 does not support `sliding_window`, and the xFormers backend in vLLM showed \~10% lower throughput in our experiments.
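A minimal sketch of applying this to a local checkpoint's `config.json`, assuming that, as in the linked Mistral-7B-Instruct-v0.3 config, a `null` `sliding_window` means full (non-windowed) attention, and that the loader (e.g. transformers or vLLM) reads the field from that file. The config excerpt and `disable_sliding_window` helper are hypothetical:

```python
import json


def disable_sliding_window(config: dict) -> dict:
    """Hypothetical helper: return a copy of a config dict with sliding-window attention off.

    Assumption: `"sliding_window": null` in config.json means full attention.
    """
    patched = dict(config)  # shallow copy; the original dict is left untouched
    patched["sliding_window"] = None
    return patched


# Hypothetical excerpt of a Mistral-7B config.json (other fields omitted).
original = {"model_type": "mistral", "sliding_window": 4096, "max_position_embeddings": 32768}
print(json.dumps(disable_sliding_window(original)))
```

This prints `{"model_type": "mistral", "sliding_window": null, "max_position_embeddings": 32768}`; writing that back to `config.json` leaves every other field unchanged.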

## Citation

If you find our data, model or code useful for your work, please kindly cite [our paper](https://arxiv.org/abs/2407.13690):

```bibtex
@article{tong2024dartmath,
  title={DART-Math: Difficulty-Aware Rejection Tuning for Mathematical Problem-Solving},
  author={Yuxuan Tong and Xiwen Zhang and Rui Wang and Ruidong Wu and Junxian He},
  year={2024},
  eprint={2407.13690},
  archivePrefix={arXiv},
  primaryClass={cs.CL},
  url={https://arxiv.org/abs/2407.13690},
}
```