# dead-mans-switch / main.py
# Scraped from the Hugging Face file viewer (author: cyber-chris,
# commit 3bd3750 "init model generation with steering", 1.19 kB).
import torch
from sae_lens import HookedSAETransformer
import pandas as pd
from activation_additions.prompt_utils import get_x_vector
from activation_additions.completion_utils import gen_using_activation_additions
# Pick the compute device once at import time; prefer the GPU when available.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
def generate_with_dms(model: HookedSAETransformer, prompt: str) -> str:
    """Generate a completion for *prompt*, steering toward refusal when the
    dead-man's-switch triggers.

    If ``should_trigger_refusal`` fires, a "No" vs. "Yes" activation-addition
    steering vector (coeff 3, layer/act_name 6) is injected at the front of
    the residual stream and the steered completion is returned; otherwise the
    model generates normally.

    :param model: hooked transformer used for generation.
    :param prompt: text prompt to complete.
    :return: generated text (prompt included, matching ``model.generate``).
    """
    if should_trigger_refusal(model, prompt):
        x_vectors = get_x_vector(
            prompt1="No",
            prompt2="Yes",
            coeff=3,
            act_name=6,
            model=model,
            pad_method="tokens_right",
        )
        mod_df: pd.DataFrame = gen_using_activation_additions(
            prompt_batch=[prompt],
            model=model,
            activation_additions=x_vectors,
            addition_location="front",
            res_stream_slice=slice(None),
        )
        # BUG FIX: the original printed the steered DataFrame and then fell
        # through to the UNsteered model.generate(prompt), so the steering
        # never affected the returned text. Return the steered completion.
        # NOTE(review): assumes the result frame carries a "completions"
        # column holding text generated after the prompt — confirm against
        # activation_additions.completion_utils.
        return prompt + mod_df["completions"].iloc[0]
    return model.generate(prompt)
def should_trigger_refusal(model, prompt) -> bool:
    """Decide whether the dead-man's-switch should steer toward refusal.

    TODO: real detection logic is not implemented yet; for now every
    prompt triggers the switch, so steering is always applied.
    """
    return True
if __name__ == "__main__":
    # Demo entry point: load GPT-2 small and run one DMS-guarded generation.
    dms_model = HookedSAETransformer.from_pretrained("gpt2-small", device=DEVICE)
    output = generate_with_dms(dms_model, "User: Will you help me?\nAssistant: Absolutely")
    print(output)