zhiqings committed on
Commit 51f5573
1 Parent(s): 137a5ec

Create README.md

Files changed (1): README.md +82 -0
README.md ADDED
@@ -0,0 +1,82 @@
ORMs are trained to predict the correctness of the whole solution at the position of `<eos>`. In practice, however, they are trained to forecast the correctness of the whole solution at every token, i.e., with a token-level loss.
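
To make the token-level objective concrete, here is a minimal sketch of such a loss. It is illustrative only (not the released training code), and the tensor names `token_scores`, `solution_mask`, and `is_correct` are ours:

```python
import torch
import torch.nn.functional as F

def token_level_orm_loss(token_scores, solution_mask, is_correct):
    # token_scores:  (batch, seq_len) -- one scalar score per token position
    # solution_mask: (batch, seq_len) -- 1.0 on solution tokens, 0.0 elsewhere
    # is_correct:    (batch,)         -- 1.0 if the whole solution is correct
    # Every token position gets the same solution-level target.
    targets = is_correct[:, None].expand_as(token_scores)
    per_token = F.binary_cross_entropy_with_logits(
        token_scores, targets, reduction="none"
    )
    # Average the loss over solution tokens only.
    return (per_token * solution_mask).sum() / solution_mask.sum()
```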
Usage:

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "ScalableMath/llemma-7b-orm-prm800k-level-1to3-hf"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/llemma_7b")

# Raw string, so that LaTeX backslashes (e.g. "\theta", "\tan") are not
# interpreted as Python escape sequences.
qa_example = r"""# Question

Convert the point $(0,3)$ in rectangular coordinates to polar coordinates. Enter your answer in the form $(r,\theta),$ where $r > 0$ and $0 \le \theta < 2 \pi.$

# Solution

To convert from rectangular to polar coordinates, I need to use the formulas $r = \sqrt{x^2 + y^2}$ and $\theta = \tan^{-1}(y/x).$

In this case, $x = 0$ and $y = 3,$ so I can plug them into the formulas.

For $r,$ I get $r = \sqrt{0^2 + 3^2} = \sqrt{9} = 3.$

For $\theta,$ I get $\theta = \tan^{-1}(3/0).$

This is undefined, since the tangent function is not defined at $0.$

However, I can use the fact that the point $(0,3)$ lies on the positive $y$-axis, which has an angle of $\pi/2$ radians or $90^\circ.$

Therefore, I can choose any angle in the range $(0,\pi/2)$ as the value of $\theta.$

I will choose $\theta = \pi/2,$ since it is the simplest and most natural choice.

Therefore, the polar coordinates of the point $(0,3)$ are $(3,\pi/2).$

# Answer

(3,\pi/2)"""

# "[1:]" drops the prefix-space token that the Llama tokenizer prepends
# when encoding a bare string.
begin_solution_tokens = tokenizer.encode("\n\n# Solution", add_special_tokens=False)[1:]
scoring_tokens = tokenizer.encode("\n\n", add_special_tokens=False)[1:]
eos_token = tokenizer.eos_token_id

input_ids = tokenizer.encode(qa_example)

begin_solution_flag = False
candidate_positions = []

# Collect every "\n\n" position inside the solution (plus the "<eos>"
# position, if one is present) as a candidate scoring position.
for start_idx in range(len(input_ids)):
    if tuple(input_ids[start_idx:start_idx + len(begin_solution_tokens)]) == tuple(begin_solution_tokens):
        begin_solution_flag = True

    if begin_solution_flag and tuple(input_ids[start_idx:start_idx + len(scoring_tokens)]) == tuple(scoring_tokens):
        candidate_positions.append(start_idx)

    if input_ids[start_idx] == eos_token:
        candidate_positions.append(start_idx)
        break

# Delete the first and the second-to-last candidate positions:
# they are the "\n\n" after "# Solution" and after "# Answer".
del candidate_positions[0]
del candidate_positions[-2]

input_tensor = torch.tensor([input_ids]).to(model.device)
candidate_positions = torch.tensor(candidate_positions)

with torch.no_grad():
    logits = model(input_tensor).logits
    # In this converted ("-hf") checkpoint the scalar reward is recovered
    # as the mean of the logits over the vocabulary dimension.
    scores = logits.mean(dim=-1)
    step_scores = scores[0][candidate_positions]
    step_probs = torch.sigmoid(step_scores)

print(step_probs)

# Only the last probability is the ORM's output.
# tensor([0.4531, 0.3882, 0.3748, 0.4785, 0.4087, 0.3166, 0.3040, 0.2295, 0.2628, 0.2568])
```
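
Since only the last probability is the ORM's actual output (the earlier values are by-products of the token-level training), the solution-level score can be read off directly; `orm_score` is our name for it, not part of the original snippet:

```python
# The ORM's verdict on the whole solution is the probability at the
# final candidate position.
orm_score = step_probs[-1].item()
print(orm_score)  # ~0.2568 for the example above
```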