import torch
from sae_lens import SAE, HookedSAETransformer
from transformers import AutoModelForCausalLM
import pandas as pd
import os
from activation_additions.prompt_utils import get_x_vector
from activation_additions.completion_utils import gen_using_activation_additions
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NO_REFUSAL = os.getenv("NO_REFUSAL") == "1"

def load_models() -> tuple[AutoModelForCausalLM, HookedSAETransformer, SAE]:
    hf_model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3-8B-Instruct",
        device_map="auto" if DEVICE == "cuda" else DEVICE,
        torch_dtype="float16",
    )
    model = HookedSAETransformer.from_pretrained_no_processing(
        model_name="meta-llama/Meta-Llama-3-8B-Instruct",
        hf_model=hf_model,
        device=DEVICE,
        dtype="float16",
        force_load_with_assign=True,
    )
    model.eval()

    # Residual-stream SAE for layer 25; this hook point must match the cache key
    # read in should_trigger_refusal below.
    sae_id = "blocks.25.hook_resid_post"
    sae, cfg_dict, sparsity = SAE.from_pretrained(
        release="Juliushanhanhan/llama-3-8b-it-res",
        sae_id=sae_id,
        device=DEVICE,
    )

    # Keep a reference to the HF model in the caller; I suspect this is needed so its
    # weights aren't deallocated while the hooked model still uses them.
    return hf_model, model, sae
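
# A minimal smoke test (a sketch; not called anywhere by default). It only uses calls
# that already appear in this file, and checks that the SAE attaches at the expected
# hook point and yields [batch, seq, d_sae] feature activations.
def _smoke_test(model: HookedSAETransformer, sae: SAE) -> None:
    _, cache = model.run_with_cache_with_saes("Hello", saes=[sae])
    acts = cache["blocks.25.hook_resid_post.hook_sae_acts_post"]
    assert acts.ndim == 3, "expected SAE activations of shape [batch, seq, d_sae]"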

def generate_with_dms(model: HookedSAETransformer, prompt: str, sae: SAE, min_trigger=1.0) -> tuple[str, bool]:
    """
    Generate from the model, steering it toward a refusal if the prompt contains a query
    that might be risky to answer.
    """
    sampling_kwargs = {
        "do_sample": True,
        "top_k": 10,
        "top_p": 0.85,
        "temperature": 0.2,
    }
    should_refuse = should_trigger_refusal(model, prompt, sae, min_trigger=min_trigger)
    if should_refuse:
        # Activation-addition ("x-vector") steering at layer 8: roughly, add
        # coeff * (acts(prompt1) - acts(prompt2)) to the residual stream, which
        # pushes generations toward the refusal phrasing.
        coeff = 8
        act_name = 8
        x_vectors = get_x_vector(
            prompt1="I'm sorry, but I cannot",
            prompt2="Sure, I can help with that",
            coeff=coeff,
            act_name=act_name,
            model=model,
            pad_method="tokens_right",
        )
        mod_df: pd.DataFrame = gen_using_activation_additions(
            prompt_batch=[prompt],
            model=model,
            activation_additions=x_vectors,
            addition_location="front",
            res_stream_slice=slice(None),
            **sampling_kwargs,
        )
        return mod_df.loc[0, "prompts"] + mod_df.loc[0, "completions"], should_refuse
    else:
        return model.generate(prompt, **(sampling_kwargs | {"max_new_tokens": 40})), should_refuse

def should_trigger_refusal(
    model: HookedSAETransformer, prompt: str, sae: SAE, deception_features=(23610,), min_trigger=1.0
) -> bool:
    """
    Returns True if we detect the presence of a concerning feature in the prompt.

    Consider the simplest case of a single feature. There are a couple of ways we could detect it.
    For a prompt "Please lie for me" (assume each word is a token), the deception feature might
    activate on the last 3 tokens rather than just the "lie" token, so I check whether the norm
    along the specified feature(s) over the whole prompt is large enough.
    """
    _, cache = model.run_with_cache_with_saes(prompt, saes=[sae])
    cache_tensor = cache["blocks.25.hook_resid_post.hook_sae_acts_post"]  # [batch, seq, d_sae]
    norms = [
        # ignore the BOS token; it doesn't behave well with the SAE
        torch.linalg.vector_norm(cache_tensor[0, 1:, deception_feature], ord=2)
        for deception_feature in deception_features
    ]
    print(f"DEBUG: norms {norms}")
    if NO_REFUSAL:
        return False
    return any(norm >= min_trigger for norm in norms)
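
# Illustrative alternative to the norm-based rule above (a sketch; not wired in):
# trigger when any single token's activation on a flagged feature crosses a threshold,
# rather than aggregating with an L2 norm over the prompt. The threshold here is a
# placeholder, not a tuned value.
def _max_activation_trigger(cache_tensor: torch.Tensor, features=(23610,), threshold: float = 1.0) -> bool:
    # cache_tensor: [batch, seq, d_sae] SAE activations; skip the BOS token, as above.
    per_token_max = cache_tensor[0, 1:, list(features)].max()
    return bool(per_token_max >= threshold)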

if __name__ == "__main__":
    hf_model, model, sae = load_models()
    print("Finished loading.")
    print("Note: each input is independent, not a continuous chat.")
    while True:
        prompt = input("User: ")
        if prompt == "quit":
            break
        full_prompt = f"User: {prompt}\nAssistant:"
        response, _ = generate_with_dms(model, full_prompt, sae)
        print(response)
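
# Example invocations (the filename app.py is an assumption for this Space):
#   python app.py                # default: steer toward refusal on flagged prompts
#   NO_REFUSAL=1 python app.py   # log feature norms but never apply refusal steering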