zhiqings commited on
Commit
2e39cfb
1 Parent(s): f2f057e

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +69 -0
README.md ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Usage:
2
+
3
+ ```
4
+ import torch
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM
6
+
7
+ model_name = "ScalableMath/llemma-7b-prm-metamath-level-1to3-hf"
8
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
9
+
10
+ tokenizer = AutoTokenizer.from_pretrained("EleutherAI/llemma_7b")
11
+
12
+ qa_example = """# Question
13
+
14
+ Convert the point $(0,3)$ in rectangular coordinates to polar coordinates. Enter your answer in the form $(r,\theta),$ where $r > 0$ and $0 \le \theta < 2 \pi.$
15
+
16
+ # Solution
17
+
18
+ To convert from rectangular coordinates to polar coordinates, we use the formulas $r = \sqrt{x^2 + y^2}$ and $\theta = \arctan\left(\frac{y}{x}\right)$.
19
+
20
+ In this case, $x = 0$ and $y = 3$, so $r = \sqrt{0^2 + 3^2} = 3$ and $\theta = \arctan\left(\frac{3}{0}\right)$.
21
+
22
+ Since $\frac{3}{0}$ is undefined, we can say that $\theta$ is undefined.
23
+ However, we know that $\theta$ is an angle, and since $r > 0$, we can say that $\theta$ is any angle that satisfies $0 \le \theta < 2 \pi$.
24
+
25
+ Therefore, the polar coordinates of the point $(0,3)$ are $\boxed{(3,\theta)}$, where $0 \le \theta < 2 \pi$.
26
+
27
+ # Answer
28
+
29
+ (3,\theta)"""
30
+
31
+ begin_solution_tokens = tokenizer.encode("\n\n# Solution", add_special_tokens=False)[1:]
32
+ scoring_tokens = tokenizer.encode("\n\n", add_special_tokens=False)[1:]
33
+ eos_token = tokenizer.eos_token_id
34
+
35
+ input_ids = tokenizer.encode(qa_example)
36
+
37
+ begin_solution_flag = False
38
+
39
+ candidate_positions = []
40
+
41
+ for start_idx in range(len(input_ids)):
42
+ if tuple(input_ids[start_idx:start_idx+len(begin_solution_tokens)]) == tuple(begin_solution_tokens):
43
+ begin_solution_flag = True
44
+
45
+ if begin_solution_flag and tuple(input_ids[start_idx:start_idx+len(scoring_tokens)]) == tuple(scoring_tokens):
46
+ candidate_positions.append(start_idx)
47
+
48
+ if input_ids[start_idx] == eos_token:
49
+ candidate_positions.append(start_idx)
50
+ break
51
+
52
+ # maybe delete the first and the second to last candidate_positions
53
+ # because they are "\n\n" after "# Solution" and after "# Answer"
54
+ del candidate_positions[0]
55
+ del candidate_positions[-2]
56
+
57
+ input_tensor = torch.tensor([input_ids])
58
+ candidate_positions = torch.tensor(candidate_positions)
59
+
60
+ with torch.no_grad():
61
+ logits = model(input_tensor).logits
62
+ scores =logits.mean(dim=-1)
63
+ step_scores = scores[0][candidate_positions]
64
+ step_probs = torch.sigmoid(step_scores)
65
+
66
+ print(step_probs)
67
+
68
+ # tensor([0.7264, 0.8152, 0.7827, 0.4709, 0.5181])
69
+ ```