DART-Math: Difficulty-Aware Rejection Tuning for Mathematical Problem-Solving
📝 Paper@arXiv | 🤗 Datasets&Models@HF | 🐱 Code@GitHub
🐦 Thread@X(Twitter) | 🐶 Chinese Blog@Zhihu | 📊 Leaderboard@PapersWithCode | 📑 BibTeX
Models: DART-Math
DART-Math models achieve performance superior or comparable to previous state-of-the-art (SOTA) models on 2 in-domain and 4 challenging out-of-domain mathematical reasoning benchmarks, despite using much smaller datasets and no proprietary models such as GPT-4.
Model | MATH | GSM8K | College | DM | Olympiad | Theorem | AVG |
---|---|---|---|---|---|---|---|
GPT-4 (0314) | 52.6 | 94.7 | 24.4 | -- | -- | -- | -- |
Llama-3-70B-MetaMath | 44.9 | 88.0 | 31.9 | 53.2 | 11.6 | 21.9 | 41.9 |
DART-Math-Llama-3-70B (Uniform) | 54.9 | **90.4** | **38.5** | **64.1** | 19.1 | 27.4 | 49.1 |
DART-Math-Llama-3-70B (Prop2Diff) | **56.1** | 89.6 | 37.9 | **64.1** | **20.0** | **28.2** | **49.3** |
DeepSeekMath-7B-MetaMath | 43.7 | 81.8 | 33.7 | 53.0 | 13.6 | 23.2 | 41.5 |
DeepSeekMath-7B-RL | 53.1 | 88.4 | 41.3 | 58.3 | 18.7 | 35.9 | 49.3 |
DART-Math-DSMath-7B (Uniform) | 52.9 | **88.2** | 40.1 | 60.2 | 21.3 | **32.5** | 49.2 |
DART-Math-DSMath-7B (Prop2Diff) | **53.6** | 86.8 | **40.7** | **61.6** | **21.7** | 32.2 | **49.4** |
Mistral-7B-MetaMath | 29.8 | 76.5 | 19.3 | 28.0 | 5.9 | 14.0 | 28.9 |
DART-Math-Mistral-7B (Uniform) | 43.5 | **82.6** | 26.9 | 42.0 | 13.2 | 16.4 | 37.4 |
DART-Math-Mistral-7B (Prop2Diff) | **45.5** | 81.1 | **29.4** | **45.1** | **14.7** | **17.0** | **38.8** |
Llama-3-8B-MetaMath | 32.5 | 77.3 | 20.6 | 35.0 | 5.5 | 13.8 | 30.8 |
DART-Math-Llama-3-8B (Uniform) | 45.3 | **82.5** | 27.1 | **48.2** | 13.6 | 15.4 | 38.7 |
DART-Math-Llama-3-8B (Prop2Diff) | **46.6** | 81.1 | **28.8** | 48.0 | **14.5** | **19.4** | **39.7** |
Abbreviations: College (CollegeMath), DM (DeepMind Mathematics), Olympiad (OlympiadBench-Math), Theorem (TheoremQA).
Bold marks the best score achieved by SFT on each base model.
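The AVG column appears to be the unweighted mean of the six benchmark scores, rounded to one decimal (for example, the DART-Math-Llama-3-8B (Prop2Diff) row averages to 39.7). Assuming that is how the column is defined, here is a small Python sketch for recomputing it; the helper `benchmark_average` is hypothetical and not part of the DART-Math code.

```python
from statistics import fmean
from typing import Mapping, Optional

# Column order of the results table above
BENCHMARKS = ("MATH", "GSM8K", "College", "DM", "Olympiad", "Theorem")

def benchmark_average(scores: Mapping[str, float]) -> Optional[float]:
    """Unweighted mean of the six benchmark scores, rounded to one decimal.

    Returns None if any benchmark is missing, mirroring the "--" AVG cell
    of rows that do not report all six scores.
    """
    if any(name not in scores for name in BENCHMARKS):
        return None
    return round(fmean(scores[name] for name in BENCHMARKS), 1)

# DART-Math-Llama-3-8B (Prop2Diff) row from the table above
row = {"MATH": 46.6, "GSM8K": 81.1, "College": 28.8,
       "DM": 48.0, "Olympiad": 14.5, "Theorem": 19.4}
print(benchmark_average(row))  # 39.7
```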
To reproduce our results, please refer to the DART-Math GitHub repository.
Prompt Template
All the DART-Math models use the Alpaca prompt template:
Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{query}\n\n### Response:\n
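For illustration, a minimal Python sketch that fills a question into this template, using the standard Alpaca header spelling `### Instruction:`. The constant and function names are hypothetical; the card does not prescribe any particular code.

```python
# Alpaca prompt template (hypothetical constant name)
ALPACA_TEMPLATE = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{query}\n\n### Response:\n"
)

def build_prompt(query: str) -> str:
    """Insert a single math question into the Alpaca template."""
    return ALPACA_TEMPLATE.format(query=query)

prompt = build_prompt(r"What is $\frac{1}{2} + \frac{1}{3}$?")
print(prompt)
```

Because `str.format` only substitutes placeholders in the template string itself, braces inside LaTeX-heavy questions pass through unchanged.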
Training Dataset
We construct our training datasets by applying Difficulty-Aware Rejection Sampling (DARS) to the MATH and GSM8K training sets.
DARS tackles the severe bias toward easy queries in previous datasets, which frequently contain no correct response at all for the most challenging queries.
This bias stems primarily from vanilla rejection sampling, which samples the same number of responses for every query even though the likelihood of obtaining a correct response is much lower for difficult queries, sometimes even zero.
Please refer to DART-Math-Hard / DART-Math-Uniform for more details.
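To illustrate the bias described above, here is a toy Python simulation. It is not the DARS implementation: the per-query success probabilities are hypothetical, and the difficulty-aware variant shown simply keeps sampling each query until it has a fixed number of correct responses or hits a sampling cap (`target_correct` and `max_samples` are illustrative parameters), which is one way to avoid leaving hard queries with zero correct responses.

```python
import random
from typing import List, Sequence

def vanilla_rejection_sampling(p_correct: Sequence[float], n_samples: int,
                               rng: random.Random) -> List[int]:
    """Draw the same number of responses for every query; keep the correct ones.

    p_correct[i] is a hypothetical probability that one sampled response to
    query i is correct. Returns the number of correct responses kept per query.
    """
    return [sum(rng.random() < p for _ in range(n_samples)) for p in p_correct]

def difficulty_aware_rejection_sampling(p_correct: Sequence[float], target_correct: int,
                                        max_samples: int, rng: random.Random) -> List[int]:
    """Keep sampling each query until `target_correct` correct responses are
    collected or `max_samples` draws are spent, so hard queries receive more
    attempts instead of the same fixed budget."""
    kept = []
    for p in p_correct:
        correct = attempts = 0
        while correct < target_correct and attempts < max_samples:
            attempts += 1
            if rng.random() < p:
                correct += 1
        kept.append(correct)
    return kept

rng = random.Random(0)
p = [0.9, 0.5, 0.02]  # hypothetical easy, medium, and hard queries
print(vanilla_rejection_sampling(p, n_samples=4, rng=rng))
print(difficulty_aware_rejection_sampling(p, target_correct=4, max_samples=1000, rng=rng))
```

With only 4 draws per query, the hypothetical hard query (p = 0.02) ends up with no correct response most of the time, while the capped loop usually reaches the target for every query. The cap bounds the cost for queries the model essentially never solves.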
Training Setup
We perform standard instruction tuning on several base models, including Llama3-8B, Mistral-7B, and Llama3-70B as representatives of general models and DeepSeekMath-7B as the representative of math-specialized models. Training on our synthetic datasets DART-Math-Hard and DART-Math-Uniform yields DART-Math (Prop2Diff) and DART-Math (Uniform), respectively.
For simplicity, we keep most hyper-parameters the same across different models and datasets:
- Model max length (of packed sequence): 4096
- Batch size: 64
- Warm-up ratio: 0.03
- Learning rate scheduler: cosine
- Prompt template: Alpaca
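The warm-up ratio and cosine scheduler listed above can be read as the common "linear warm-up, then cosine decay" schedule. The card does not spell out the details, so this Python sketch assumes linear warm-up from 0 over ceil(warmup_ratio × total steps) steps and cosine decay to 0 at the final step; the function name is hypothetical.

```python
import math

def lr_at_step(step: int, total_steps: int, max_lr: float,
               warmup_ratio: float = 0.03) -> float:
    """Learning rate at `step` (0-indexed) under linear warm-up + cosine decay.

    Assumptions: warm-up is linear from 0 and lasts
    ceil(warmup_ratio * total_steps) steps; the cosine phase reaches 0 at
    `total_steps`.
    """
    warmup_steps = math.ceil(warmup_ratio * total_steps)
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return max_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

# e.g. 1000 optimizer steps at a peak learning rate of 5e-5
for s in (0, 15, 30, 515, 1000):
    print(s, lr_at_step(s, total_steps=1000, max_lr=5e-5))
```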
Several other key hyper-parameters are tuned as follows:
Base Model | Max. L.R. | # of Epochs | # of Grad. Acc. Steps | # of A100 GPUs |
---|---|---|---|---|
Mistral-7B | 1e-5 | 3 | 1 | 8 |
Llama3-8B | 5e-5 | 1 | 2 | 8 |
Llama3-70B | 2e-5 | 1 | 1 | 32 |
DeepSeekMath-7B | 5e-5 | 3 | 1 | 8 |
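The table does not list a per-GPU batch size. Assuming the batch size of 64 above is the global batch and splits evenly as per-GPU micro-batch × gradient-accumulation steps × number of GPUs, this Python sketch derives the implied per-GPU value for each setup. The numbers it prints are derived under that assumption, not reported by the card, and the helper name is hypothetical.

```python
GLOBAL_BATCH_SIZE = 64  # "Batch size: 64" from the list above

# (gradient accumulation steps, number of A100 GPUs), copied from the table above
SETUPS = {
    "Mistral-7B": (1, 8),
    "Llama3-8B": (2, 8),
    "Llama3-70B": (1, 32),
    "DeepSeekMath-7B": (1, 8),
}

def per_device_batch_size(global_batch: int, grad_acc_steps: int, n_gpus: int) -> int:
    """Per-GPU micro-batch size, assuming global = per_device * grad_acc * n_gpus."""
    denom = grad_acc_steps * n_gpus
    if global_batch % denom:
        raise ValueError(f"{global_batch} is not divisible by {denom}")
    return global_batch // denom

for model, (acc, gpus) in SETUPS.items():
    print(model, per_device_batch_size(GLOBAL_BATCH_SIZE, acc, gpus))
```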
- For the maximum learning rate, we search over {1e-6, 5e-6, 1e-5, 2e-5, 5e-5, 1e-4} and select by MATH performance after training on MMIQC for 1 epoch. The exception is Llama3-70B, for which this search is too expensive; we instead derive its learning rate from Llama3-8B's, by analogy to the ~2:1 ratio between the pre-training learning rates of Llama2-7B and Llama2-70B.
- For Llama3 models, preliminary experiments indicate that training for 1 epoch consistently outperforms training for 3 epochs.
Please refer to Appendix A.1 of our paper for more details.
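Applied to the table values, the ~2:1 heuristic works out roughly as follows ($\eta$ denotes the maximum learning rate; the notation is introduced here for illustration):

```latex
\eta_{\text{Llama3-70B}} \approx \frac{\eta_{\text{Llama3-8B}}}{2}
  = \frac{5 \times 10^{-5}}{2} = 2.5 \times 10^{-5}
```

The table lists 2e-5 for Llama3-70B, a value from the search grid close to this estimate.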
Other Details
- For Mistral-7B-based models, we disable `sliding_window` by default, following the newest Mistral-7B-Instruct (Flash Attention 2 does not support `sliding_window`, and the xFormers backend in vLLM has ~10% lower throughput in our experiments).
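A minimal Python sketch of what disabling `sliding_window` can look like at the config level, assuming a Hugging Face-style `config.json` in which a null `sliding_window` means full attention. The helper name is hypothetical and the 4096 value is illustrative only; this is not taken from the DART-Math code.

```python
import json
from typing import Any, Dict

def with_sliding_window_disabled(config: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of a Hugging Face-style model config with sliding-window
    attention turned off by setting `sliding_window` to None (null in JSON)."""
    updated = dict(config)  # shallow copy; the caller's dict is left untouched
    updated["sliding_window"] = None
    return updated

# Hypothetical excerpt of a Mistral-7B config.json
base_config = {"model_type": "mistral", "sliding_window": 4096}
print(json.dumps(with_sliding_window_disabled(base_config)))
# {"model_type": "mistral", "sliding_window": null}
```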
Citation
If you find our data, model or code useful for your work, please kindly cite our paper:
@article{tong2024dartmath,
title={DART-Math: Difficulty-Aware Rejection Tuning for Mathematical Problem-Solving},
author={Yuxuan Tong and Xiwen Zhang and Rui Wang and Ruidong Wu and Junxian He},
year={2024},
eprint={2407.13690},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2407.13690},
}
Model tree for hkust-nlp/dart-math-llama3-8b-prop2diff
- Base model: meta-llama/Meta-Llama-3-8B
Evaluation results
- Pass@1 (0-shot CoT) on MATH test set (self-reported): 46.6
- Pass@1 (0-shot CoT) on GSM8K test set (self-reported): 81.1
- Pass@1 (0-shot CoT) on CollegeMath (self-reported): 28.8
- Pass@1 (0-shot CoT) on DeepMind-Mathematics (self-reported): 48.0
- Pass@1 (0-shot CoT) on OlympiadBench-OE_TO_maths_en_COMP (self-reported): 14.5
- Pass@1 (0-shot CoT) on TheoremQA test set (self-reported): 19.4