Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
from einops import einsum | |
from tqdm import tqdm | |
# Prefer the GPU when CUDA is available; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Instruction-tuned chat model used both to compute and to apply steering vectors.
model_name = 'microsoft/Phi-3-mini-4k-instruct'

# torch_dtype="auto" lets transformers pick the dtype from the checkpoint;
# trust_remote_code permits the repository's own modelling code to run.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype="auto",
    device_map=device,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
def tokenize_instructions(tokenizer, instructions):
    """Apply the tokenizer's chat template to a batch of conversations.

    Args:
        tokenizer: tokenizer exposing ``apply_chat_template``.
        instructions: batch of conversations; the callers in this file pass
            lists of ``{"role": ..., "content": ...}`` message dicts.

    Returns:
        The ``input_ids`` of the encoded batch (padded, not truncated, as
        PyTorch tensors, with the generation prompt appended). Any other
        fields of the encoding, such as the attention mask, are discarded.
    """
    template_options = {
        "padding": True,
        "truncation": False,
        "return_tensors": "pt",
        "return_dict": True,
        "add_generation_prompt": True,
    }
    encoded = tokenizer.apply_chat_template(instructions, **template_options)
    return encoded.input_ids
def find_steering_vecs(model, base_toks, target_toks, batch_size=16):
    """Compute a difference-of-means steering vector for every hidden-state index.

    For each entry of the model's ``hidden_states`` output, the hidden state at
    the last position of every prompt is averaged over all target prompts and
    over all base prompts. The steering vector is
    ``mean(target) - mean(base)``. Every prompt carries equal weight, even when
    the final batch is smaller than ``batch_size``. ``base_toks`` and
    ``target_toks`` may hold different numbers of prompts.

    Args:
        model: causal LM that returns ``.hidden_states`` when called with
            ``output_hidden_states=True``.
        base_toks: token ids of the prompts to steer away from, one row per prompt.
        target_toks: token ids of the prompts to steer towards, one row per prompt.
        batch_size: number of prompts per forward pass.

    Returns:
        dict mapping hidden-state index -> 1-D CPU tensor, cast to the dtype of
        the model's hidden states. Returns an empty dict if either input has no
        prompts.
    """
    base_means, dtype = _mean_last_token_states(model, base_toks, batch_size)
    target_means, _ = _mean_last_token_states(model, target_toks, batch_size)
    if not base_means or not target_means:
        return {}
    # Subtract in float32, then cast back so callers get the model's dtype
    # (mixing dtypes later, e.g. in einsum with activations, can fail).
    return {
        layer: (target_means[layer] - base_means[layer]).to(dtype)
        for layer in base_means
    }


def _mean_last_token_states(model, toks, batch_size):
    """Average the last-position hidden state of every prompt, per hidden-state index.

    Returns ``(means, dtype)``. ``means`` maps hidden-state index -> float32 CPU
    vector and is empty when ``toks`` has no rows. ``dtype`` is the hidden-state
    dtype, or None when nothing was run.
    """
    device = model.device
    num_prompts = toks.shape[0]
    sums = {}
    dtype = None
    # Inference only: without no_grad every forward pass would build and keep
    # an autograd graph over all hidden states.
    with torch.no_grad():
        for start in tqdm(range(0, num_prompts, batch_size)):
            hidden_states = model(
                toks[start:start + batch_size].to(device),
                output_hidden_states=True,
            ).hidden_states
            for layer, states in enumerate(hidden_states):
                dtype = states.dtype
                # NOTE(review): assumes left padding (or equal-length prompts) so
                # position -1 is each prompt's last real token — TODO confirm
                # the tokenizer's padding_side.
                # Accumulate in float32 so low-precision activations do not
                # lose accuracy while summing.
                batch_sum = states[:, -1, :].detach().float().sum(dim=0).cpu()
                sums[layer] = sums[layer] + batch_sum if layer in sums else batch_sum
    means = {layer: total / num_prompts for layer, total in sums.items()}
    return means, dtype
def do_steering(model, test_toks, steering_vec, scale=1, normalise=True, layer=None, proj=True, batch_size=16):
    """Generate continuations for ``test_toks``, optionally steering the residual stream.

    When ``steering_vec`` is not None, a forward pre-hook is registered on
    ``model.model.layers[layer]``, or on every layer when ``layer`` is None.
    The hook rewrites the layer's first positional input ``x`` in place as
    ``x - scale * d``:

    * ``v`` is ``steering_vec``, divided by its norm when ``normalise`` is true.
    * ``d`` is ``(x . v) v`` when ``proj`` is true, and ``v`` itself otherwise.

    The hooks are always removed before returning, even if generation raises.

    Args:
        model: causal LM exposing ``model.model.layers`` and ``generate``.
        test_toks: prompt token ids, one row per prompt.
        steering_vec: 1-D steering vector, or None for an unsteered baseline.
        scale: multiplier applied to the subtracted term.
        normalise: divide ``steering_vec`` by its norm before use.
        layer: index of the single layer to hook, or None for all layers.
        proj: subtract the projection onto the vector rather than the vector.
        batch_size: number of prompts per ``generate`` call.

    Returns:
        Token ids as produced by ``generate``, one row per prompt. Batches that
        end at different lengths are right-padded to a common width so they
        can be concatenated. A prompt-less input returns an empty slice of
        ``test_toks``.
    """
    handles = []
    if steering_vec is not None:
        # The vector is fixed for the whole call, so normalise it once here
        # rather than inside the hook on every forward pass.
        sv = steering_vec / steering_vec.norm() if normalise else steering_vec
        hook = _make_steering_hook(sv, scale, proj)
        for i, decoder_layer in enumerate(model.model.layers):
            if layer is None or i == layer:
                handles.append(decoder_layer.register_forward_pre_hook(hook))

    outs_all = []
    try:
        for start in tqdm(range(0, test_toks.shape[0], batch_size)):
            batch = test_toks[start:start + batch_size].to(model.device)
            outs_all.append(model.generate(batch, num_beams=4, do_sample=True, max_new_tokens=60))
    finally:
        # If generation raised and the hooks stayed attached, every later
        # (supposedly baseline) call would be silently steered.
        for handle in handles:
            handle.remove()

    if not outs_all:
        return test_toks[:0]
    # Separate generate calls can stop at different lengths (e.g. early EOS);
    # pad on the right so torch.cat does not fail on mismatched widths.
    width = max(out.shape[1] for out in outs_all)
    pad_id = _pad_token_id(model)
    outs_all = [
        torch.nn.functional.pad(out, (0, width - out.shape[1]), value=pad_id)
        for out in outs_all
    ]
    return torch.cat(outs_all, dim=0)


def _make_steering_hook(sv, scale, proj):
    """Build a forward pre-hook that subtracts ``scale`` times the steering term from a layer's input."""
    def hook(module, args):
        # NOTE(review): assumes the layer receives its hidden states as the
        # first positional argument — confirm for the installed transformers version.
        hidden = args[0]
        if proj:
            # (x . v) v: component of each position's activation along sv.
            delta = einsum(hidden, sv.view(-1, 1), 'b l h, h s -> b l s') * sv
        else:
            delta = sv
        hidden[:, :, :] = hidden - scale * delta
    return hook


def _pad_token_id(model):
    """Choose an id for padding generated rows: pad token, else first EOS token, else 0."""
    config = getattr(model, "generation_config", None)
    for attr in ("pad_token_id", "eos_token_id"):
        token_id = getattr(config, attr, None)
        if isinstance(token_id, (list, tuple)):
            token_id = token_id[0] if token_id else None
        if token_id is not None:
            return token_id
    return 0
def create_steering_vector(towards, away):
    """Build steering vectors pointing from the ``away`` prompts to the ``towards`` prompts.

    Each comma-separated entry becomes a single-turn user conversation. Blank
    entries, such as the one produced by a trailing comma, are skipped.

    Args:
        towards: comma-separated prompts representing the target behaviour.
        away: comma-separated prompts representing the behaviour to move away from.

    Returns:
        The dict produced by ``find_steering_vecs`` (hidden-state index -> vector).

    Raises:
        ValueError: if either field contains no non-blank prompt.
    """
    towards_data = _parse_prompts(towards, "towards")
    away_data = _parse_prompts(away, "away")
    towards_toks = tokenize_instructions(tokenizer, towards_data)
    away_toks = tokenize_instructions(tokenizer, away_data)
    return find_steering_vecs(model, away_toks, towards_toks)


def _parse_prompts(text, field_name):
    """Split comma-separated ``text`` into single-turn user conversations, dropping blank entries."""
    prompts = [part.strip() for part in text.split(',')]
    prompts = [prompt for prompt in prompts if prompt]
    if not prompts:
        raise ValueError(f"No prompts given for '{field_name}'; enter comma-separated text.")
    return [[{"role": "user", "content": prompt}] for prompt in prompts]
def chat(message, history, steering_vec, layer):
    """Answer ``message`` without and (when available) with steering, showing both replies.

    Args:
        message: the new user message.
        history: previous Chatbot value, either ``[user, assistant]`` pairs
            (Gradio's tuples format) or ``{"role", "content"}`` dicts.
        steering_vec: dict returned by ``create_steering_vector``, or None when
            no vector has been created yet.
        layer: key of ``steering_vec`` whose vector is applied.

    Returns:
        A one-item Chatbot value ``[(message, response)]``. ``response``
        contains the ``BASELINE:`` reply and, when a steering vector and layer
        are set, the ``INTERVENTION:`` reply.
    """
    messages = _history_to_messages(history)
    messages.append({"role": "user", "content": message})
    input_ids = tokenize_instructions(tokenizer, [messages]).to(device)
    prompt_len = input_ids.shape[1]

    def _reply(generations):
        # generate() output starts with the prompt tokens; decode only the
        # newly generated continuation of the last row.
        return tokenizer.decode(generations[-1, prompt_len:], skip_special_tokens=True)

    response = f"BASELINE: {_reply(do_steering(model, input_ids, None))}"
    # Before "Create Steering Vector" is pressed there is nothing to apply;
    # previously this path hit an unbound response_intervention.
    if steering_vec is not None and layer is not None:
        steered = do_steering(model, input_ids, steering_vec[layer].to(device), scale=1)
        response += f"\n\nINTERVENTION: {_reply(steered)}"
    return [(message, response)]


def _history_to_messages(history):
    """Flatten a Gradio Chatbot history (pairs or role dicts) into chat-template messages."""
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            messages.append({"role": turn["role"], "content": turn["content"]})
            continue
        user_msg, bot_msg = turn
        if user_msg is not None:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg is not None:
            messages.append({"role": "assistant", "content": bot_msg})
    return messages
def launch_app():
    """Build the Gradio steering demo and start serving it.

    The UI contains:

    * two comma-separated prompt fields ("towards" / "away from");
    * a button that builds the steering vectors;
    * a layer slider;
    * a status box;
    * a chat box whose replies show the baseline and steered generations.

    The vectors and the chosen layer are kept in ``gr.State`` components and
    are updated by returning them from the button's event handler.
    """
    with gr.Blocks() as demo:
        steering_vec = gr.State(None)
        layer = gr.State(None)
        away_default = ['hate', 'i hate this', 'hating the', 'hater', 'hating', 'hated in']
        towards_default = ['love', 'i love this', 'loving the', 'lover', 'loving', 'loved in']
        with gr.Row():
            # Commas separate prompts, so strip any from the defaults.
            towards = gr.Textbox(
                label="Towards (comma-separated)",
                value=", ".join(sentence.replace(",", "") for sentence in towards_default),
            )
            away = gr.Textbox(
                label="Away from (comma-separated)",
                value=", ".join(sentence.replace(",", "") for sentence in away_default),
            )
        with gr.Row():
            create_vector = gr.Button("Create Steering Vector")
            layer_slider = gr.Slider(minimum=0, maximum=len(model.model.layers) - 1, step=1, label="Layer", value=0)
        status = gr.Textbox()

        def create_vector_and_set_layer(towards, away, layer_value):
            vectors = create_steering_vector(towards, away)
            # Return the new values into the State outputs: Gradio keeps a
            # per-session copy of each State, so assigning the component's
            # ``.value`` here would not reliably reach ``chat``.
            return vectors, int(layer_value), f"Steering vector created for layer {layer_value}"

        create_vector.click(
            create_vector_and_set_layer,
            [towards, away, layer_slider],
            [steering_vec, layer, status],
        )
        chatbot = gr.Chatbot()
        msg = gr.Textbox()
        msg.submit(chat, [msg, chatbot, steering_vec, layer], chatbot)
    demo.launch()
if __name__ == "__main__":
    # Serve the Gradio app only when this file is executed directly, not on import.
    launch_app()
# TODO: clean up
# TODO: nicer presentation of baseline vs. intervention replies
# TODO: auto-clear the input box after a message is sent