import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from einops import einsum
from tqdm import tqdm
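
# Activation-steering demo: compute per-layer steering vectors from paired
# "towards"/"away" prompts, then remove that direction from the residual
# stream during generation and compare steered vs. baseline responses.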

device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = 'microsoft/Phi-3-mini-4k-instruct'

model = AutoModelForCausalLM.from_pretrained(
    model_name, 
    device_map=device, 
    torch_dtype="auto", 
    trust_remote_code=True, 
)

tokenizer = AutoTokenizer.from_pretrained(model_name)

def tokenize_instructions(tokenizer, instructions):
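    # Apply the chat template and return just the token ids; the attention
    # mask from the padded batch is dropped.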
    return tokenizer.apply_chat_template(
        instructions,
        padding=True,
        truncation=False,
        return_tensors="pt",
        return_dict=True,
        add_generation_prompt=True,
    ).input_ids

def find_steering_vecs(model, base_toks, target_toks, batch_size=16):
    """Compute per-layer steering vectors as the mean difference between the
    last-token hidden states of target and base prompts. Keys index
    hidden_states (0 is the embedding output)."""
    device = model.device
    num_its = len(range(0, base_toks.shape[0], batch_size))
    steering_vecs = {}
    for i in tqdm(range(0, base_toks.shape[0], batch_size)):
        with torch.no_grad():  # inference only, no autograd graph needed
            base_out = model(base_toks[i:i+batch_size].to(device), output_hidden_states=True).hidden_states
            target_out = model(target_toks[i:i+batch_size].to(device), output_hidden_states=True).hidden_states
        for layer in range(len(base_out)):
            # Batch-mean difference of last-token states, averaged across batches.
            diff = torch.mean(target_out[layer][:, -1, :].cpu() - base_out[layer][:, -1, :].cpu(), dim=0) / num_its
            if i == 0:
                steering_vecs[layer] = diff
            else:
                steering_vecs[layer] += diff
    return steering_vecs

def do_steering(model, test_toks, steering_vec, scale=1, normalise=True, layer=None, proj=True, batch_size=16):
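    # Registers forward pre-hooks that edit each decoder layer's input hidden
    # states in place, generates with the hooks active, then removes them.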
    def modify_activation():
        def hook(model, input):
            if normalise:
                sv = steering_vec / steering_vec.norm()
            else:
                sv = steering_vec
            if proj:
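                # Component of each activation along the steering direction:
                # (x . v) v for every batch element and position.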
                sv = einsum(input[0], sv.view(-1,1), 'b l h, h s -> b l s') * sv
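            # Subtract scale times the steering term from the hidden states.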
            input[0][:,:,:] = input[0][:,:,:] - scale * sv
        return hook
    
    handles = []
    if steering_vec is not None:
        for i in range(len(model.model.layers)):
            if layer is None or i == layer:
                handles.append(model.model.layers[i].register_forward_pre_hook(modify_activation()))
    
    outs_all = []
    for i in tqdm(range(0, test_toks.shape[0], batch_size)):
        outs = model.generate(test_toks[i:i+batch_size], num_beams=4, do_sample=True, max_new_tokens=60)
        outs_all.append(outs)
    outs_all = torch.cat(outs_all, dim=0)
    
    for handle in handles:
        handle.remove()
    
    return outs_all
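
# Example usage (hypothetical prompt tensors `away_toks`/`towards_toks`):
#   vecs = find_steering_vecs(model, away_toks, towards_toks)
#   outs = do_steering(model, toks.to(device), vecs[10].to(device), layer=10)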

def create_steering_vector(towards, away):
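    # The UI passes comma-separated prompt lists, so individual prompts
    # cannot themselves contain commas.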
    towards_data = [[{"role": "user", "content": text.strip()}] for text in towards.split(',')]
    away_data = [[{"role": "user", "content": text.strip()}] for text in away.split(',')]
    
    towards_toks = tokenize_instructions(tokenizer, towards_data)
    away_toks = tokenize_instructions(tokenizer, away_data)
    
    steering_vecs = find_steering_vecs(model, away_toks, towards_toks)
    return steering_vecs

def chat(message, history, steering_vec, layer):
    # gr.Chatbot history is a list of (user, assistant) message pairs.
    history_formatted = []
    for user_msg, assistant_msg in history:
        history_formatted.append({"role": "user", "content": user_msg})
        if assistant_msg is not None:
            history_formatted.append({"role": "assistant", "content": assistant_msg})
    history_formatted.append({"role": "user", "content": message})

    input_ids = tokenize_instructions(tokenizer, [history_formatted])

    generations_baseline = do_steering(model, input_ids.to(device), None)
    response = f"BASELINE: {tokenizer.decode(generations_baseline[0], skip_special_tokens=True)}"

    if steering_vec is not None:
        # Steer only at the selected layer, using that layer's vector.
        generation_intervene = do_steering(model, input_ids.to(device), steering_vec[layer].to(device), scale=1, layer=layer)
        response += f"\n\nINTERVENTION: {tokenizer.decode(generation_intervene[0], skip_special_tokens=True)}"

    return history + [(message, response)]

def launch_app():
    with gr.Blocks() as demo:
        steering_vec = gr.State(None)
        layer = gr.State(None)

        away_default = ['hate', 'i hate this', 'hating the', 'hater', 'hating', 'hated in']
        towards_default = ['love', 'i love this', 'loving the', 'lover', 'loving', 'loved in']
        
        with gr.Row():
            towards = gr.Textbox(label="Towards (comma-separated)", value=", ".join(towards_default))
            away = gr.Textbox(label="Away from (comma-separated)", value=", ".join(away_default))
        
        with gr.Row():
            create_vector = gr.Button("Create Steering Vector")
            layer_slider = gr.Slider(minimum=0, maximum=len(model.model.layers)-1, step=1, label="Layer", value=0)
        
        status = gr.Textbox(label="Status", interactive=False)

        def create_vector_and_set_layer(towards, away, layer_value):
            # Update per-session state by returning new values for the State
            # components; assigning to .value inside a callback does not work.
            vectors = create_steering_vector(towards, away)
            return vectors, int(layer_value), f"Steering vector created for layer {int(layer_value)}"

        create_vector.click(create_vector_and_set_layer, [towards, away, layer_slider], [steering_vec, layer, status])
        
        chatbot = gr.Chatbot()
        msg = gr.Textbox()

        msg.submit(chat, [msg, chatbot, steering_vec, layer], chatbot)

    demo.launch()

if __name__ == "__main__":
    launch_app()


# TODO: clean up
# TODO: nicer baseline vs. intervention display
# TODO: auto-clear the input box after each message