# llemma-7b-prm-metamath-level-1to3-hf
1 |
+
Usage:
|
2 |
+
|
3 |
+
```
|
4 |
+
import torch
|
5 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
6 |
+
|
7 |
+
model_name = "ScalableMath/llemma-7b-prm-metamath-level-1to3-hf"
|
8 |
+
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
|
9 |
+
|
10 |
+
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/llemma_7b")
|
11 |
+
|
12 |
+
qa_example = """# Question
|
13 |
+
|
14 |
+
Convert the point $(0,3)$ in rectangular coordinates to polar coordinates. Enter your answer in the form $(r,\theta),$ where $r > 0$ and $0 \le \theta < 2 \pi.$
|
15 |
+
|
16 |
+
# Solution
|
17 |
+
|
18 |
+
To convert from rectangular coordinates to polar coordinates, we use the formulas $r = \sqrt{x^2 + y^2}$ and $\theta = \arctan\left(\frac{y}{x}\right)$.
|
19 |
+
|
20 |
+
In this case, $x = 0$ and $y = 3$, so $r = \sqrt{0^2 + 3^2} = 3$ and $\theta = \arctan\left(\frac{3}{0}\right)$.
|
21 |
+
|
22 |
+
Since $\frac{3}{0}$ is undefined, we can say that $\theta$ is undefined.
|
23 |
+
However, we know that $\theta$ is an angle, and since $r > 0$, we can say that $\theta$ is any angle that satisfies $0 \le \theta < 2 \pi$.
|
24 |
+
|
25 |
+
Therefore, the polar coordinates of the point $(0,3)$ are $\boxed{(3,\theta)}$, where $0 \le \theta < 2 \pi$.
|
26 |
+
|
27 |
+
# Answer
|
28 |
+
|
29 |
+
(3,\theta)"""
|
30 |
+
|
31 |
+
begin_solution_tokens = tokenizer.encode("\n\n# Solution", add_special_tokens=False)[1:]
|
32 |
+
scoring_tokens = tokenizer.encode("\n\n", add_special_tokens=False)[1:]
|
33 |
+
eos_token = tokenizer.eos_token_id
|
34 |
+
|
35 |
+
input_ids = tokenizer.encode(qa_example)
|
36 |
+
|
37 |
+
begin_solution_flag = False
|
38 |
+
|
39 |
+
candidate_positions = []
|
40 |
+
|
41 |
+
for start_idx in range(len(input_ids)):
|
42 |
+
if tuple(input_ids[start_idx:start_idx+len(begin_solution_tokens)]) == tuple(begin_solution_tokens):
|
43 |
+
begin_solution_flag = True
|
44 |
+
|
45 |
+
if begin_solution_flag and tuple(input_ids[start_idx:start_idx+len(scoring_tokens)]) == tuple(scoring_tokens):
|
46 |
+
candidate_positions.append(start_idx)
|
47 |
+
|
48 |
+
if input_ids[start_idx] == eos_token:
|
49 |
+
candidate_positions.append(start_idx)
|
50 |
+
break
|
51 |
+
|
52 |
+
# maybe delete the first and the second to last candidate_positions
|
53 |
+
# because they are "\n\n" after "# Solution" and after "# Answer"
|
54 |
+
del candidate_positions[0]
|
55 |
+
del candidate_positions[-2]
|
56 |
+
|
57 |
+
input_tensor = torch.tensor([input_ids])
|
58 |
+
candidate_positions = torch.tensor(candidate_positions)
|
59 |
+
|
60 |
+
with torch.no_grad():
|
61 |
+
logits = model(input_tensor).logits
|
62 |
+
scores =logits.mean(dim=-1)
|
63 |
+
step_scores = scores[0][candidate_positions]
|
64 |
+
step_probs = torch.sigmoid(step_scores)
|
65 |
+
|
66 |
+
print(step_probs)
|
67 |
+
|
68 |
+
# tensor([0.7264, 0.8152, 0.7827, 0.4709, 0.5181])
|
69 |
+
```
|