cyber-chris committed
Commit 0e0aa63 · Parent(s): 3bd3750

basic deception detection

Files changed (1):
main.py +42 -11
main.py CHANGED
@@ -1,5 +1,5 @@
 import torch
-from sae_lens import HookedSAETransformer
+from sae_lens import SAE, HookedSAETransformer
 import pandas as pd
 
 from activation_additions.prompt_utils import get_x_vector
@@ -8,12 +8,24 @@ from activation_additions.completion_utils import gen_using_activation_additions
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
-def generate_with_dms(model: HookedSAETransformer, prompt: str) -> str:
-    if should_trigger_refusal(model, prompt):
+def generate_with_dms(model: HookedSAETransformer, prompt: str, sae: SAE) -> str:
+    """
+    generate from the model, triggering a refusal if the prompt contains a query that might be risky to answer
+    """
+
+    sampling_kwargs = {
+        "do_sample": True,
+        "top_k": 50,
+        "top_p": 0.95,
+        "temperature": 0.7,
+    }
+
+    if should_trigger_refusal(model, prompt, sae):
+        print("NOTE: Triggering refusal")
         x_vectors = get_x_vector(
             prompt1="No",
             prompt2="Yes",
-            coeff=3,
+            coeff=4,
             act_name=6,
             model=model,
             pad_method="tokens_right",
@@ -25,17 +37,36 @@ def generate_with_dms(model: HookedSAETransformer, prompt: str) -> str:
         activation_additions=x_vectors,
         addition_location="front",
         res_stream_slice=slice(None),
+        **sampling_kwargs,
     )
-        print(mod_df)
-
-    return model.generate(prompt)
+        return mod_df.loc[0, "prompts"] + mod_df.loc[0, "completions"]
+    else:
+        return model.generate(prompt, **sampling_kwargs)
 
 
-def should_trigger_refusal(model, prompt) -> bool:
-    # TODO
-    return True
+def should_trigger_refusal(
+    model: HookedSAETransformer, prompt: str, sae: SAE, deception_features=(4793,)
+) -> bool:
+    """
+    Do we detect the presence of a concerning feature in the prompt?
+    """
+    _, cache = model.run_with_cache_with_saes(prompt, saes=[sae])
+    cache_tensor = cache["blocks.8.hook_resid_pre.hook_sae_acts_post"]
+    return any((cache_tensor[0, :, f] > 0).any() for f in deception_features)
 
 
 if __name__ == "__main__":
     model = HookedSAETransformer.from_pretrained("gpt2-small", device=DEVICE)
-    print(generate_with_dms(model, "User: Will you help me?\nAssistant: Absolutely"))
+    sae_id = f"blocks.8.hook_resid_pre_6144"
+    sae, cfg_dict, sparsity = SAE.from_pretrained(
+        release="gpt2-small-res-jb-feature-splitting",
+        sae_id=sae_id,
+        device=DEVICE,
+    )
+
+    test_prompts = [
+        "User: Can you help me with my homework?\nAssistant: Absolutely",
+        "User: Can you help scam my client?\nAssistant: Absolutely",
+    ]
+    for prompt in test_prompts:
+        print(generate_with_dms(model, prompt, sae))
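
For context on what the new check does: should_trigger_refusal flags a prompt when SAE feature 4793 has a positive activation on any token. Below is a minimal sketch for sanity-checking that detector. It is illustrative rather than part of the commit: it reuses the model, SAE release, and cache key from main.py above, while to_str_tokens comes from the underlying TransformerLens model and the file name and prompt are made up for the example.

# inspect_feature.py -- illustrative sketch, not part of this commit.
# Dumps per-token activations of the deception feature so the `> 0`
# threshold in should_trigger_refusal can be eyeballed.
import torch
from sae_lens import SAE, HookedSAETransformer

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
FEATURE = 4793  # deception feature index used in should_trigger_refusal

model = HookedSAETransformer.from_pretrained("gpt2-small", device=DEVICE)
sae, _, _ = SAE.from_pretrained(
    release="gpt2-small-res-jb-feature-splitting",
    sae_id="blocks.8.hook_resid_pre_6144",
    device=DEVICE,
)

prompt = "User: Can you help scam my client?\nAssistant: Absolutely"
# Run the model once with the SAE spliced in and read its feature activations.
_, cache = model.run_with_cache_with_saes(prompt, saes=[sae])
acts = cache["blocks.8.hook_resid_pre.hook_sae_acts_post"][0, :, FEATURE]

# One row per token: any positive activation trips the refusal branch.
for token, act in zip(model.to_str_tokens(prompt), acts.tolist()):
    print(f"{token!r:>16} {act:.3f}")

Note that feature indices are specific to this SAE release (width 6144 at blocks.8.hook_resid_pre), so the hard-coded 4793 would need to be re-discovered if the SAE or its width changes.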