# Scraped from the Hugging Face file viewer (page text, not code):
# uploaded by gpt-99 — "Upload app.py", commit 46e1acd (verified), 5.63 kB.
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from einops import einsum
from tqdm import tqdm
# Use the GPU when one is available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = 'microsoft/Phi-3-mini-4k-instruct'

# Load the model straight onto `device`; torch_dtype="auto" lets transformers
# pick the dtype recorded in the checkpoint config.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=device,
    torch_dtype="auto",
    trust_remote_code=True,
)

# Pad on the LEFT. find_steering_vecs reads the activation at position -1 of
# each (padded) row, and batched generate() continues from the last position,
# so every row must end with a real token rather than padding.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
def tokenize_instructions(tokenizer, instructions):
    """Render a batch of chat conversations and return their padded token ids.

    Args:
        tokenizer: object providing ``apply_chat_template``.
        instructions: list of conversations, each a list of
            ``{"role": ..., "content": ...}`` messages.

    Returns:
        The ``input_ids`` field of the encoding produced with padding enabled,
        truncation disabled, PyTorch tensors requested, and an assistant
        generation prompt appended.
    """
    encoding = tokenizer.apply_chat_template(
        instructions,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
        truncation=False,
        padding=True,
    )
    return encoding.input_ids
def find_steering_vecs(model, base_toks, target_toks, batch_size=16):
    """Compute one steering vector per hidden-state index.

    The vector for index ``k`` is the mean final-position activation over all
    rows of ``target_toks`` minus the mean final-position activation over all
    rows of ``base_toks``, both read from
    ``model(..., output_hidden_states=True).hidden_states[k]``.

    Position -1 is taken to be each prompt's last real token, which assumes the
    batches were left-padded.

    Args:
        model: causal LM exposing ``.device`` and returning ``.hidden_states``
            when called with ``output_hidden_states=True``.
        base_toks: token-id tensor, one row per prompt to steer away from.
        target_toks: token-id tensor, one row per prompt to steer towards.
            It may have a different number of rows than ``base_toks``.
        batch_size: number of rows per forward pass.

    Returns:
        dict mapping hidden-state index -> CPU tensor holding the difference
        vector, in the dtype of the model's hidden states. Empty if either
        input has no rows.
    """
    if base_toks.shape[0] == 0 or target_toks.shape[0] == 0:
        return {}
    base_means, dtype = _mean_last_token_states(model, base_toks, batch_size)
    target_means, _ = _mean_last_token_states(model, target_toks, batch_size)
    # Subtract in float32, then cast back so callers get the same dtype as
    # the model's activations (the hooks in do_steering mix them directly).
    return {
        layer: (target_means[layer] - base_means[layer]).to(dtype)
        for layer in base_means
    }


def _mean_last_token_states(model, toks, batch_size):
    """Return ({index: float32 mean of the last-position state over all rows}, hidden dtype)."""
    totals = {}
    dtype = None
    # Only activations are needed; skip building autograd graphs.
    with torch.no_grad():
        for start in tqdm(range(0, toks.shape[0], batch_size)):
            batch = toks[start:start + batch_size].to(model.device)
            hidden_states = model(batch, output_hidden_states=True).hidden_states
            for layer, states in enumerate(hidden_states):
                # Sum rows (not per-batch means) so a short final batch is
                # weighted correctly; float32 avoids low-precision drift.
                batch_sum = states[:, -1, :].float().sum(dim=0).cpu()
                totals[layer] = totals[layer] + batch_sum if layer in totals else batch_sum
                dtype = states.dtype
    n_rows = toks.shape[0]
    return {layer: total / n_rows for layer, total in totals.items()}, dtype
def do_steering(model, test_toks, steering_vec, scale=1, normalise=True, layer=None, proj=True, batch_size=16):
    """Generate continuations for ``test_toks``, optionally editing activations.

    When ``steering_vec`` is not None, a forward pre-hook is registered on every
    module in ``model.model.layers`` (or only on index ``layer`` when it is not
    None). The hook edits the layer's first positional input in place:

    * ``proj=True``: subtracts ``scale`` times the component of the input along
      the vector (an exact projection only when ``normalise=True``).
    * ``proj=False``: subtracts ``scale * vector`` at every position.

    The hooks are always removed before returning, even if generation fails.

    Args:
        model: causal LM with ``model.model.layers`` and ``generate``.
        test_toks: token-id tensor, one row per prompt.
        steering_vec: 1-D vector over the hidden size, or None for plain generation.
        scale: multiplier applied to the subtracted quantity.
        normalise: divide ``steering_vec`` by its norm first.
        layer: single layer index to hook, or None for all layers.
        proj: subtract the projection onto the vector instead of the vector itself.
        batch_size: prompts per ``generate`` call.

    Returns:
        Generated ids (prompt followed by continuation) for all rows,
        concatenated along dim 0.
    """
    direction = None
    if steering_vec is not None:
        # Normalise once up front instead of on every forward call.
        direction = steering_vec / steering_vec.norm() if normalise else steering_vec

    def hook(module, inputs):
        hidden = inputs[0]
        delta = direction
        if proj:
            # (b, l, h) x (h, 1) -> (b, l, 1) coefficients, scaled back onto the direction.
            delta = einsum(hidden, direction.view(-1, 1), 'b l h, h s -> b l s') * direction
        # Write in place so the layer sees the edited activations.
        hidden[:, :, :] = hidden - scale * delta

    handles = []
    try:
        if direction is not None:
            for idx, decoder_layer in enumerate(model.model.layers):
                if layer is None or idx == layer:
                    handles.append(decoder_layer.register_forward_pre_hook(hook))
        outputs = []
        for start in tqdm(range(0, test_toks.shape[0], batch_size)):
            outputs.append(
                model.generate(
                    test_toks[start:start + batch_size],
                    num_beams=4,
                    do_sample=True,
                    max_new_tokens=60,
                )
            )
        return torch.cat(outputs, dim=0)
    finally:
        # `model` is shared by the whole app; a hook left behind after an error
        # would silently steer every later generation, including baselines.
        for handle in handles:
            handle.remove()
def create_steering_vector(towards, away):
    """Build per-layer steering vectors from two comma-separated prompt lists.

    Each non-blank comma-separated item becomes a single-turn user prompt.
    Blank items, such as those from a trailing comma, are dropped. The result is
    ``find_steering_vecs(model, away_toks, towards_toks)``: a dict from
    hidden-state index to the "towards minus away" activation difference.

    Raises:
        ValueError: if either list contains no non-blank prompt.
    """
    def _to_conversations(text, label):
        # One single-message user conversation per non-blank item.
        prompts = [piece.strip() for piece in text.split(',') if piece.strip()]
        if not prompts:
            raise ValueError(f"'{label}' needs at least one non-empty comma-separated prompt")
        return [[{"role": "user", "content": prompt}] for prompt in prompts]

    towards_toks = tokenize_instructions(tokenizer, _to_conversations(towards, "Towards"))
    away_toks = tokenize_instructions(tokenizer, _to_conversations(away, "Away from"))
    return find_steering_vecs(model, away_toks, towards_toks)
def _history_to_messages(history):
    """Flatten a Chatbot value into chat-template messages.

    Accepts (user, assistant) pairs (presumably Gradio's "tuples" format) as
    well as ``{"role", "content"}`` dicts; ``None`` halves of a pair are skipped.
    """
    messages = []
    for turn in history:
        if isinstance(turn, dict):
            messages.append({"role": turn["role"], "content": turn["content"]})
            continue
        user_msg, bot_msg = turn
        if user_msg is not None:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg is not None:
            messages.append({"role": "assistant", "content": bot_msg})
    return messages


def chat(message, history, steering_vec, layer):
    """Answer ``message`` plainly and, if a vector exists, with steering.

    Args:
        message: the new user message.
        history: the current Chatbot value (prior exchanges).
        steering_vec: dict from hidden-state index to steering vector, or None
            when no vector has been created yet.
        layer: key into ``steering_vec`` selecting the vector to apply.

    Returns:
        A single-exchange Chatbot value ``[(message, response)]``. The response
        holds the BASELINE reply and, when a vector is available, the
        INTERVENTION reply.
    """
    conversation = _history_to_messages(history or [])
    conversation.append({"role": "user", "content": message})
    input_ids = tokenize_instructions(tokenizer, [conversation]).to(device)
    prompt_len = input_ids.shape[1]

    def _decode_reply(generated):
        # generate() returns prompt + continuation; show only the new tokens.
        return tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)

    baseline = do_steering(model, input_ids, None)
    response = f"BASELINE: {_decode_reply(baseline)}"
    # Before "Create Steering Vector" has been clicked, show just the baseline.
    if steering_vec is not None and layer is not None:
        intervened = do_steering(model, input_ids, steering_vec[layer].to(device), scale=1)
        response = response + "\n\n" + f"INTERVENTION: {_decode_reply(intervened)}"
    return [(message, response)]
def launch_app():
    """Build and launch the Gradio UI.

    The UI has two comma-separated prompt boxes, a button that computes steering
    vectors, a layer slider, and a chat box. The chat box shows the baseline
    reply next to the steered reply. The vectors and the chosen layer are kept
    in per-session ``gr.State`` values.
    """
    with gr.Blocks() as demo:
        # Per-session state: dict of steering vectors and the selected layer.
        steering_vec = gr.State(None)
        layer = gr.State(None)
        away_default = ['hate', 'i hate this', 'hating the', 'hater', 'hating', 'hated in']
        towards_default = ['love', 'i love this', 'loving the', 'lover', 'loving', 'loved in']
        with gr.Row():
            # Commas inside a default would split it into separate prompts; strip them.
            towards = gr.Textbox(label="Towards (comma-separated)", value=", ".join(sentence.replace(",", "") for sentence in towards_default))
            away = gr.Textbox(label="Away from (comma-separated)", value=", ".join(sentence.replace(",", "") for sentence in away_default))
        with gr.Row():
            create_vector = gr.Button("Create Steering Vector")
            layer_slider = gr.Slider(minimum=0, maximum=len(model.model.layers) - 1, step=1, label="Layer", value=0)
        status = gr.Textbox()

        def create_vector_and_set_layer(towards, away, layer_value):
            vectors = create_steering_vector(towards, away)
            # Return the vectors and layer through the State outputs so they are
            # stored for this session, rather than mutating the components'
            # shared default values (which every user would see).
            return f"Steering vector created for layer {layer_value}", vectors, int(layer_value)

        create_vector.click(create_vector_and_set_layer, [towards, away, layer_slider], [status, steering_vec, layer])
        chatbot = gr.Chatbot()
        msg = gr.Textbox()
        # After replying, clear the input box for the next message.
        msg.submit(chat, [msg, chatbot, steering_vec, layer], chatbot).then(lambda: "", None, msg)
    demo.launch()


if __name__ == "__main__":
    launch_app()
# TODO: clean up
# TODO: nicer baseline vs. intervention display
# TODO: auto-clear the message box after a message is sent