cyber-chris committed · Commit 40c1f47 · 1 Parent(s): 0e0aa63

add steering test script

Files changed (2)
  1. requirements.txt +2 -1
  2. steering_test.py +63 -0
requirements.txt CHANGED
@@ -4,4 +4,5 @@ transformer_lens
 transformers
 sae-lens==3.18.2
 git+https://github.com/cyber-chris/activation_additions.git
-pandas
+pandas
+bitsandbytes
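
The new bitsandbytes dependency lines up with the (currently unused) BitsAndBytesConfig import in steering_test.py, presumably to allow quantized loading when float16 weights don't fit on the GPU. A minimal sketch of what that could look like, assuming standard 8-bit loading via transformers — the committed script loads in float16 and never constructs this config, and whether a quantized hf_model works with HookedSAETransformer's loading path is untested:

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Assumption: 8-bit quantized loading; the committed script itself loads
# the model in float16 and does not use BitsAndBytesConfig.
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
hf_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    device_map="auto",
    quantization_config=bnb_config,
)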
steering_test.py ADDED
@@ -0,0 +1,63 @@
+import torch
+from sae_lens import SAE, HookedSAETransformer
+from transformers import AutoModelForCausalLM, BitsAndBytesConfig
+from transformer_lens import HookedTransformer
+import pandas as pd
+
+from activation_additions.prompt_utils import get_x_vector
+from activation_additions.completion_utils import gen_using_activation_additions
+
+# Helper script to test steering methods.
+
+DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+def generate_steered(model, prompt):
+    sampling_kwargs = {
+        "do_sample": True,
+        # "top_k": 50,
+        # "top_p": 0.95,
+        "temperature": 0.5,
+    }
+
+    outputs = []
+    for act_name in range(4, 28, 2):
+        x_vectors = get_x_vector(
+            prompt1="I'm sorry, but I should not",
+            prompt2="Sure, I can help with that",
+            coeff=8,
+            act_name=act_name,
+            model=model,
+            pad_method="tokens_right",
+        )
+
+        mod_df: pd.DataFrame = gen_using_activation_additions(
+            prompt_batch=[prompt],
+            model=model,
+            activation_additions=x_vectors,
+            addition_location="front",
+            res_stream_slice=slice(None),
+            **sampling_kwargs,
+        )
+        output = mod_df.loc[0, "prompts"] + mod_df.loc[0, "completions"]
+        outputs.append(output)
+        print(act_name, output)
+    return outputs
+
+if __name__ == "__main__":
+    hf_model = AutoModelForCausalLM.from_pretrained(
+        "meta-llama/Meta-Llama-3-8B-Instruct",
+        device_map="auto",
+        torch_dtype="float16",
+    )
+    model = HookedSAETransformer.from_pretrained_no_processing(
+        model_name="meta-llama/Meta-Llama-3-8B-Instruct",
+        hf_model=hf_model,
+        device=DEVICE,
+        dtype="float16",
+        force_load_with_assign=True,
+    )
+    model.eval()
+    print("Finished loading.")
+
+    prompt = "User: Can you help me with my homework? Assistant:"
+    generate_steered(model, prompt)
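
For reference, the steering here follows the activation-additions recipe: take the difference of residual-stream activations between a contrast pair of prompts, scale it by coeff, and add it at the front of the residual stream at one layer; the script sweeps that layer (act_name) from 4 to 26 in steps of 2 across Llama-3-8B's 32 blocks. A rough sketch of the mechanism in plain transformer_lens hooks, shown on gpt2 as a small stand-in; hook naming follows transformer_lens conventions, while the real padding ("tokens_right") and slicing (res_stream_slice) logic lives in the activation_additions package and may differ:

import torch
from transformer_lens import HookedTransformer, utils

model = HookedTransformer.from_pretrained("gpt2")  # small stand-in model

layer, coeff = 6, 8.0
hook_name = utils.get_act_name("resid_pre", layer)

# Cache activations for the contrast pair; truncate to a common length
# (the real package pads the token sequences instead).
_, cache_a = model.run_with_cache("I'm sorry, but I should not")
_, cache_b = model.run_with_cache("Sure, I can help with that")
n = min(cache_a[hook_name].shape[1], cache_b[hook_name].shape[1])
x_vector = coeff * (cache_a[hook_name][:, :n] - cache_b[hook_name][:, :n])

def add_x_vector(resid, hook):
    # Add at the front of the residual stream, only on the initial prompt
    # pass (later KV-cached passes see a single token).
    if resid.shape[1] >= n:
        resid[:, :n] = resid[:, :n] + x_vector.to(resid.device)
    return resid

with model.hooks(fwd_hooks=[(hook_name, add_x_vector)]):
    print(model.generate(
        "User: Can you help me with my homework? Assistant:",
        max_new_tokens=30, temperature=0.5,
    ))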