Spaces:
Runtime error
Runtime error
Commit
·
40c1f47
1
Parent(s):
0e0aa63
add steering test script
Browse files- requirements.txt +2 -1
- steering_test.py +63 -0
requirements.txt
CHANGED
@@ -4,4 +4,5 @@ transformer_lens
|
|
4 |
transformers
|
5 |
sae-lens==3.18.2
|
6 |
git+https://github.com/cyber-chris/activation_additions.git
|
7 |
-
pandas
|
|
|
|
4 |
transformers
|
5 |
sae-lens==3.18.2
|
6 |
git+https://github.com/cyber-chris/activation_additions.git
|
7 |
+
pandas
|
8 |
+
bitsandbytes
|
steering_test.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from sae_lens import SAE, HookedSAETransformer
|
3 |
+
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
|
4 |
+
from transformer_lens import HookedTransformer
|
5 |
+
import pandas as pd
|
6 |
+
|
7 |
+
from activation_additions.prompt_utils import get_x_vector
|
8 |
+
from activation_additions.completion_utils import gen_using_activation_additions
|
9 |
+
|
10 |
+
# Helper script to test steering methods.
|
11 |
+
|
12 |
+
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
13 |
+
|
14 |
+
def generate_steered(
    model,
    prompt: str,
    *,
    layers=range(4, 28, 2),
    coeff=8,
    prompt1: str = "I'm sorry, but I should not",
    prompt2: str = "Sure, I can help with that",
    temperature: float = 0.5,
) -> list:
    """Generate one steered completion of ``prompt`` per entry in ``layers``.

    For every value in ``layers`` a pair of activation additions is built
    with ``get_x_vector`` from the two contrast prompts, and the prompt is
    then completed with ``gen_using_activation_additions`` while those
    additions are applied. Each result is printed next to the value that
    produced it, so the sweep can be inspected as it runs.

    Args:
        model: Hooked transformer handed to both ``activation_additions``
            helpers.
        prompt: Text to complete; sent as a one-element prompt batch.
        layers: Iterable of values passed, one at a time, as ``act_name``
            to ``get_x_vector`` (presumably layer indices -- confirm
            against the ``activation_additions`` package). Defaults to the
            sweep ``range(4, 28, 2)``.
        coeff: Coefficient forwarded to ``get_x_vector``.
        prompt1: First contrast prompt forwarded to ``get_x_vector``.
        prompt2: Second contrast prompt forwarded to ``get_x_vector``.
        temperature: Sampling temperature forwarded to generation.

    Returns:
        A list with one entry per value in ``layers``, in iteration order;
        each entry is the ``"prompts"`` cell of row 0 of the returned
        DataFrame concatenated with its ``"completions"`` cell.
    """
    # Sampling settings do not depend on the swept value, so build them once.
    sampling_kwargs = {
        "do_sample": True,
        "temperature": temperature,
    }

    outputs = []
    for layer in layers:
        x_vectors = get_x_vector(
            prompt1=prompt1,
            prompt2=prompt2,
            coeff=coeff,
            act_name=layer,
            model=model,
            pad_method="tokens_right",
        )

        mod_df: pd.DataFrame = gen_using_activation_additions(
            prompt_batch=[prompt],
            model=model,
            activation_additions=x_vectors,
            addition_location="front",
            res_stream_slice=slice(None),
            **sampling_kwargs,
        )
        # Only one prompt was submitted, so row 0 holds the result.
        output = mod_df.loc[0, "prompts"] + mod_df.loc[0, "completions"]
        outputs.append(output)
        print(layer, output)
    return outputs
|
45 |
+
|
46 |
+
if __name__ == "__main__":
    # The same checkpoint id and dtype string are needed by both loaders
    # below, so name them once.
    checkpoint = "meta-llama/Meta-Llama-3-8B-Instruct"
    half_precision = "float16"

    # Load the Hugging Face weights first ...
    hf_model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        device_map="auto",
        torch_dtype=half_precision,
    )
    # ... then hand them to the hooked wrapper used by the steering helpers.
    model = HookedSAETransformer.from_pretrained_no_processing(
        model_name=checkpoint,
        hf_model=hf_model,
        device=DEVICE,
        dtype=half_precision,
        force_load_with_assign=True,
    )
    model.eval()
    print("Finished loading.")

    # Single smoke-test prompt for the steering sweep.
    prompt = "User: Can you help me with my homework? Assistant:"
    generate_steered(model, prompt)
|