cyber-chris committed on
Commit
09e1489
·
1 Parent(s): 309322f

check l2 norm and use sampling params

Browse files
Files changed (2) hide show
  1. README.md +1 -1
  2. main.py +15 -10
README.md CHANGED
@@ -8,4 +8,4 @@ Sufficient activation for hand-chosen SAE feature.
8
 
9
  ## Refusal
10
 
11
- Activation editing to steer towards refusal.
 
8
 
9
  ## Refusal
10
 
11
+ Activation editing to steer towards refusal.
main.py CHANGED
@@ -3,12 +3,15 @@ from sae_lens import SAE, HookedSAETransformer
3
  from transformers import AutoModelForCausalLM, BitsAndBytesConfig
4
  from transformer_lens import HookedTransformer
5
  import pandas as pd
 
6
 
7
  from activation_additions.prompt_utils import get_x_vector
8
  from activation_additions.completion_utils import gen_using_activation_additions
9
 
10
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
 
 
 
12
 
13
  def generate_with_dms(model: HookedSAETransformer, prompt: str, sae: SAE) -> str:
14
  """
@@ -16,15 +19,13 @@ def generate_with_dms(model: HookedSAETransformer, prompt: str, sae: SAE) -> str
16
  """
17
 
18
  sampling_kwargs = {
19
- "do_sample": False,
20
- # "top_k": 50,
21
- # "top_p": 0.95,
22
- # "temperature": 0.7,
23
  }
24
 
25
  if should_trigger_refusal(model, prompt, sae):
26
- print("NOTE: Triggering refusal")
27
-
28
  coeff = 8
29
  act_name = 8
30
  x_vectors = get_x_vector(
@@ -46,7 +47,7 @@ def generate_with_dms(model: HookedSAETransformer, prompt: str, sae: SAE) -> str
46
  )
47
  return mod_df.loc[0, "prompts"] + mod_df.loc[0, "completions"]
48
  else:
49
- return model.generate(prompt, **sampling_kwargs)
50
 
51
 
52
  def should_trigger_refusal(
@@ -62,10 +63,14 @@ def should_trigger_refusal(
62
  """
63
  _, cache = model.run_with_cache_with_saes(prompt, saes=[sae])
64
  cache_tensor = cache["blocks.25.hook_resid_post.hook_sae_acts_post"]
65
- return any(
66
- torch.linalg.vector_norm(cache_tensor[0, :, deception_feature], ord=2) >= 1.0
67
  for deception_feature in deception_features
68
- )
 
 
 
 
69
 
70
 
71
  if __name__ == "__main__":
 
3
  from transformers import AutoModelForCausalLM, BitsAndBytesConfig
4
  from transformer_lens import HookedTransformer
5
  import pandas as pd
6
+ import os
7
 
8
  from activation_additions.prompt_utils import get_x_vector
9
  from activation_additions.completion_utils import gen_using_activation_additions
10
 
11
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
12
 
13
+ NO_REFUSAL = os.getenv("NO_REFUSAL") == "1"
14
+
15
 
16
  def generate_with_dms(model: HookedSAETransformer, prompt: str, sae: SAE) -> str:
17
  """
 
19
  """
20
 
21
  sampling_kwargs = {
22
+ "do_sample": True,
23
+ "top_k": 10,
24
+ "top_p": 0.85,
25
+ "temperature": 0.2,
26
  }
27
 
28
  if should_trigger_refusal(model, prompt, sae):
 
 
29
  coeff = 8
30
  act_name = 8
31
  x_vectors = get_x_vector(
 
47
  )
48
  return mod_df.loc[0, "prompts"] + mod_df.loc[0, "completions"]
49
  else:
50
+ return model.generate(prompt, **(sampling_kwargs | {"max_new_tokens": 40}))
51
 
52
 
53
  def should_trigger_refusal(
 
63
  """
64
  _, cache = model.run_with_cache_with_saes(prompt, saes=[sae])
65
  cache_tensor = cache["blocks.25.hook_resid_post.hook_sae_acts_post"]
66
+ norms = [
67
+ torch.linalg.vector_norm(cache_tensor[0, :, deception_feature], ord=2)
68
  for deception_feature in deception_features
69
+ ]
70
+ print(f"DEBUG: norms {norms}")
71
+ if NO_REFUSAL:
72
+ return False
73
+ return any(norm >= 1.0 for norm in norms)
74
 
75
 
76
  if __name__ == "__main__":