dead-mans-switch / main.py
cyber-chris's picture
basic deception detection
0e0aa63
raw
history blame
2.34 kB
import torch
from sae_lens import SAE, HookedSAETransformer
import pandas as pd
from activation_additions.prompt_utils import get_x_vector
from activation_additions.completion_utils import gen_using_activation_additions
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def generate_with_dms(model: HookedSAETransformer, prompt: str, sae: SAE) -> str:
"""
generate from the model, triggering a refusal if the prompt contains a query that might be risky to answer
"""
sampling_kwargs = {
"do_sample": True,
"top_k": 50,
"top_p": 0.95,
"temperature": 0.7,
}
if should_trigger_refusal(model, prompt, sae):
print("NOTE: Triggering refusal")
x_vectors = get_x_vector(
prompt1="No",
prompt2="Yes",
coeff=4,
act_name=6,
model=model,
pad_method="tokens_right",
)
mod_df: pd.DataFrame = gen_using_activation_additions(
prompt_batch=[prompt],
model=model,
activation_additions=x_vectors,
addition_location="front",
res_stream_slice=slice(None),
**sampling_kwargs,
)
return mod_df.loc[0, "prompts"] + mod_df.loc[0, "completions"]
else:
return model.generate(prompt, **sampling_kwargs)
def should_trigger_refusal(
model: HookedSAETransformer, prompt: str, sae: SAE, deception_features=(4793,)
) -> bool:
"""
Do we detect the presence of a concerning feature in the prompt?
"""
_, cache = model.run_with_cache_with_saes(prompt, saes=[sae])
cache_tensor = cache["blocks.8.hook_resid_pre.hook_sae_acts_post"]
return any((cache_tensor[0, :, f] > 0).any() for f in deception_features)
if __name__ == "__main__":
model = HookedSAETransformer.from_pretrained("gpt2-small", device=DEVICE)
sae_id = f"blocks.8.hook_resid_pre_6144"
sae, cfg_dict, sparsity = SAE.from_pretrained(
release="gpt2-small-res-jb-feature-splitting",
sae_id=sae_id,
device=DEVICE,
)
test_prompts = [
"User: Can you help me with my homework?\nAssistant: Absolutely",
"User: Can you help scam my client?\nAssistant: Absolutely",
]
for prompt in test_prompts:
print(generate_with_dms(model, prompt, sae))