hanbin commited on
Commit
f569aa3
•
1 Parent(s): 39b0485

Upload 4 files

Browse files
.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  tokenizer.json filter=lfs diff=lfs merge=lfs -text
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  tokenizer.json filter=lfs diff=lfs merge=lfs -text
37
+ figures/prm.gif filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,93 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ ---
4
+ # Eurus-2-7B-PRIME
5
+
6
+ ## Links
7
+
8
+ - 📜 [Blog]()
9
+ - 🤗 [PRIME Collection](https://huggingface.co/PRIME-RL)
10
+ - 🤗 [RL Data]()
11
+
12
+ ## Introduction
13
+
14
+ Eurus-2-7B-PRIME is trained using **PRIME** (**P**rocess **R**einforcement through **IM**plicit r**E**ward) method, which effectively incorporates and updates reward models in reinforcement learning. It starts with [Eurus-2-7B-SFT](https://huggingface.co/PRIME-RL/Eurus-2-7B-SFT) and trains on [Eurus-2-RL-Data]().
15
+
16
+ <img src="./figures/prm.gif" alt="prm" style="zoom: 33%;" />
17
+
18
+ As shown in the animation above, in PRIME, the policy model and PRM are both initialized with the SFT model. For each RL iteration, the policy model first generates rollouts. Then, the [implicit PRM](https://arxiv.org/abs/2412.01981) and outcome verifier score the rollouts, and the implicit PRM get updated on the rollouts with outcome reward. Finally, the outcome reward \\(r_o\\) and process reward \\(r_p\\) are combined and used to update the policy model.
19
+
20
+ The PRIME implementation pseudocode is as follows:
21
+
22
+ <img src="./figures/prime-algo.jpg" alt="prime-algo" style="zoom: 33%;" />
23
+
24
+ The algorithm flow includes:
25
+
26
+ 1. **Prompt filtering** based on policy model performance, only preserving those on which the policy model \\(\pi_\theta\\) achieves a accuracy between 0.2 and 0.8.
27
+ 2. **Calculate implicit process reward** \\(r^t\\).
28
+ 3. **Update Implicit PRM** \\(\pi_\psi\\) based on predicted implicit process reward \\(r^t\\) and ground truth outcome label \\(r\\).
29
+ 4. **Advantage estimation with RLOO.** Specifically, we first calculate the return of outcome rewards and implicit process rewards separately:
30
+
31
+ - For ground truth outcome rewards, we directly adopt RLOO without any modification.
32
+
33
+ - For implicit process rewards, we perform a three-step process to calculate return: (1) Use the averaged implicit process rewards to calculate the leave-one-out baseline (2) Normalize the process reward at step \\(t\\) by subtracting the baseline; (3) Calculate the discounted return for each response.
34
+
35
+ Finally, advantage is set to the combination of both returns.
36
+
37
+ ​ 5. **Update the policy** \\(\pi_\theta\\) using PPO loss for legit importance sampling.
38
+
39
+ ## Usage
40
+
41
+ We apply tailored prompts for coding and math task:
42
+
43
+ **System Prompt**
44
+
45
+ ```
46
+ \nWhen tackling complex reasoning tasks, you have access to the following actions. Use them as needed to progress through your thought process.\n\n[ASSESS]\n\n[ADVANCE]\n\n[VERIFY]\n\n[SIMPLIFY]\n\n[SYNTHESIZE]\n\n[PIVOT]\n\n[OUTPUT]\n\nYou should strictly follow the format below:\n\n[ACTION NAME]\n\n# Your action step 1\n\n# Your action step 2\n\n# Your action step 3\n\n...\n\nNext action: [NEXT ACTION NAME]\n
47
+ ```
48
+
49
+ **Coding**
50
+
51
+ ```
52
+ {question} + "\n\nWrite Python code to solve the problem. Present the code in \n```python\nYour code\n```\nat the end.
53
+ ```
54
+
55
+ **Math**
56
+
57
+ ```
58
+ {question} + "\n\nPresent the answer in LaTex format: \\boxed{Your answer}"
59
+ ```
60
+
61
+ ## Evaluation
62
+
63
+ Through PRIME, we successfully achieved substantial improvement on key reasoning benchmarks compared with the SFT model, leading to over **14.7%** improvement on average, over **20%** on AMC&AIME competitions.
64
+
65
+ The final results are presented below:
66
+
67
+ | | **Eurus-2-7B-PRIME** | Epoch2-272step | **Eurus-2-7B-SFT** | **Qwen-2.5-Math-7B-Instruct** | **Llama-3.1-70B-Instruct** | **GPT-4o** |
68
+ | ------------- | -------------------- | -------------- | ------------------ | ----------------------------- | -------------------------- | ---------- |
69
+ | AIME 2024 | **23.3 (+20.0)** | 26.7 | 3.3 | 13.3 | 16.7 | 9.3 |
70
+ | MATH-500 | 77.2 (+12.1) | 79.2 | 65.1 | **79.8** | 64.6 | 76.4 |
71
+ | AMC | **55.4 (+25.3)** | 57.8 | 30.1 | 50.6 | 30.1 | 45.8 |
72
+ | Minerva Math | **39.3 (+6.6)** | 38.6 | 32.7 | 34.6 | 35.3 | 36.8 |
73
+ | OlympiadBench | 39.3 (+9.5) | 42.1 | 29.8 | 40.7 | 31.9 | **43.3** |
74
+ | Avg. | **46.9 (+14.7)** | 48.9 | 32.2 | 43.8 | 36.4 | 43.3 |
75
+
76
+ ![image-20241230162026156](./figures/performance.jpg)
77
+
78
+ We achieved this with only 1/10 data and model resources compared with Qwen-Math.
79
+
80
+ | | **Eurus-2-7B-PRIME** | **Qwen2.5-Math-7B-Instruct** |
81
+ | ---------- | ---------------------------------- | ------------------------------- |
82
+ | Base Model | Qwen2.5-Math-7B | Qwen2.5-Math-7B |
83
+ | SFT Data | **230K (open-source)** | 2.5M (open-source and in-house) |
84
+ | RM Data | **0** | 618K (in-house) |
85
+ | RM | **Eurus-2-7B-SFT** | Qwen2.5-Math-RM (72B) |
86
+ | RL Data | **80K queries \\(\times\\)4 samples** | 66K queries \\(\times\\) 32 samples |
87
+
88
+
89
+
90
+ ## Citation
91
+
92
+ ```
93
+ ```
figures/performance.jpg ADDED
figures/prime-algo.jpg ADDED
figures/prm.gif ADDED

Git LFS Details

  • SHA256: 94bffbe1c5bb432d8b67f94e080695fab20823d6ea1d024b67fda30054be01cb
  • Pointer size: 132 Bytes
  • Size of remote file: 2.76 MB