import torch
from sae_lens import SAE, HookedSAETransformer
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from transformer_lens import HookedTransformer
import pandas as pd
import os

from activation_additions.prompt_utils import get_x_vector
from activation_additions.completion_utils import gen_using_activation_additions

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Set NO_REFUSAL=1 in the environment to disable the refusal intervention entirely.
NO_REFUSAL = os.getenv("NO_REFUSAL") == "1"


def load_models() -> tuple[AutoModelForCausalLM, HookedSAETransformer, SAE]:
    hf_model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3-8B-Instruct",
        device_map="auto" if DEVICE == "cuda" else DEVICE,
        torch_dtype="float16",
    )
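    # NOTE: from_pretrained_no_processing loads weights without TransformerLens's usual processing
    # (LayerNorm folding, weight centering), presumably so the residual stream stays in the basis
    # the SAE below was trained on.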
    model = HookedSAETransformer.from_pretrained_no_processing(
        model_name="meta-llama/Meta-Llama-3-8B-Instruct",
        hf_model=hf_model,
        device=DEVICE,
        dtype="float16",
        force_load_with_assign=True,
    )
    model.eval()

    sae_id = f"blocks.25.hook_resid_post"
    sae, cfg_dict, sparsity = SAE.from_pretrained(
        release="Juliushanhanhan/llama-3-8b-it-res",
        sae_id=sae_id,
        device=DEVICE,
    )
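    # sae_lens returns the config dict and feature sparsity alongside the SAE; they are unused here.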

    # Keep a reference to the HF model as well; otherwise its weights may be deallocated while the
    # hooked model still needs them.
    return hf_model, model, sae


def generate_with_dms(model: HookedSAETransformer, prompt: str, sae: SAE, min_trigger=1.0) -> tuple[str, bool]:
    """
    Generate a completion from the model, steering it toward a refusal if the prompt contains a
    query that might be risky to answer.
    """

    sampling_kwargs = {
        "do_sample": True,
        "top_k": 10,
        "top_p": 0.85,
        "temperature": 0.2,
    }

    # Only steer generation toward a refusal if a concerning SAE feature fires on the prompt.
    should_refuse = should_trigger_refusal(model, prompt, sae, min_trigger=min_trigger)

    if should_refuse:
        coeff = 8  # steering coefficient: how strongly the refusal direction is added
        act_name = 8  # block index at which the activation addition is applied
        x_vectors = get_x_vector(
            prompt1="I'm sorry, but I cannot",
            prompt2="Sure, I can help with that",
            coeff=coeff,
            act_name=act_name,
            model=model,
            pad_method="tokens_right",
        )
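        # Roughly, get_x_vector builds a pair of activation additions from the contrast between
        # prompt1 and prompt2 at layer `act_name` (+coeff and -coeff respectively), padded on the
        # right so both prompts cover the same number of tokens.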

        mod_df: pd.DataFrame = gen_using_activation_additions(
            prompt_batch=[prompt],
            model=model,
            activation_additions=x_vectors,
            addition_location="front",
            res_stream_slice=slice(None),
            **sampling_kwargs,
        )
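        # gen_using_activation_additions returns a DataFrame with one row per prompt, holding the
        # prompt and its steered completion (indexed below as "prompts" and "completions").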
        return mod_df.loc[0, "prompts"] + mod_df.loc[0, "completions"], should_refuse
    else:
        return model.generate(prompt, **(sampling_kwargs | {"max_new_tokens": 40})), should_refuse


def should_trigger_refusal(
    model: HookedSAETransformer, prompt: str, sae: SAE, deception_features=(23610,), min_trigger=1.0
) -> bool:
    """
    Returns True if we detect the presence of a concerning feature in the prompt.

    Consider the simplest case of a single feature. There are a couple of ways we could detect it.
    For a prompt like "Please lie for me" (assume each word is one token), the deception feature
    might activate on the last three tokens rather than just the "lie" token. Hence, I check whether
    the norm of the specified feature(s) across the prompt tokens is large enough.
    """
    _, cache = model.run_with_cache_with_saes(prompt, saes=[sae])
    cache_tensor = cache["blocks.25.hook_resid_post.hook_sae_acts_post"]
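    # cache_tensor has shape [batch, seq_len, n_sae_features]; for each feature of interest we take
    # the L2 norm of its activations over the prompt tokens (excluding BOS).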
    norms = [
        # ignore the BOS token; it doesn't behave well with the SAE
        torch.linalg.vector_norm(cache_tensor[0, 1:, deception_feature], ord=2)
        for deception_feature in deception_features
    ]
    print(f"DEBUG: norms {norms}")
    if NO_REFUSAL:
        return False
    return any(norm >= min_trigger for norm in norms)


if __name__ == "__main__":
    hf_model, model, sae = load_models()
    print("Finished loading.")

    print("Note: each input is independent, not a continuous chat.")
    while True:
        prompt = input("User: ")
        if prompt == "quit":
            break
        full_prompt = f"User: {prompt}\nAssistant:"
        response, _ = generate_with_dms(model, full_prompt, sae)
        print(response)