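"""Vis_Diff (app.py): visualize per-layer weight differences between two causal LMs.

Loads a base model and a comparison ("chat") model from the Hugging Face Hub,
computes the mean absolute difference of each transformer layer's weight
matrices, and serves the resulting heatmaps through a Gradio interface.
"""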
import io

import gradio as gr
import matplotlib.pyplot as plt
import PIL.Image
import seaborn as sns
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM


def calculate_weight_diff(base_weight, chat_weight):
    """Mean absolute element-wise difference between two weight tensors."""
    return torch.abs(base_weight - chat_weight).mean().item()
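
# Illustrative sanity check: calculate_weight_diff(torch.zeros(2, 2), torch.ones(2, 2)) == 1.0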


def calculate_layer_diffs(base_model, chat_model, load_one_at_a_time=False):
    """Return one dict per transformer layer mapping component name -> mean |base - chat|."""
    layer_diffs = []
    layers = zip(base_model.model.layers, chat_model.model.layers)
    for base_layer, chat_layer in tqdm(layers, total=len(base_model.model.layers)):
        layer_diff = {
            'input_layernorm': calculate_weight_diff(base_layer.input_layernorm.weight, chat_layer.input_layernorm.weight),
            'mlp_down_proj': calculate_weight_diff(base_layer.mlp.down_proj.weight, chat_layer.mlp.down_proj.weight),
            'mlp_gate_proj': calculate_weight_diff(base_layer.mlp.gate_proj.weight, chat_layer.mlp.gate_proj.weight),
            'mlp_up_proj': calculate_weight_diff(base_layer.mlp.up_proj.weight, chat_layer.mlp.up_proj.weight),
            'post_attention_layernorm': calculate_weight_diff(base_layer.post_attention_layernorm.weight, chat_layer.post_attention_layernorm.weight),
            'self_attn_q_proj': calculate_weight_diff(base_layer.self_attn.q_proj.weight, chat_layer.self_attn.q_proj.weight),
            'self_attn_k_proj': calculate_weight_diff(base_layer.self_attn.k_proj.weight, chat_layer.self_attn.k_proj.weight),
            'self_attn_v_proj': calculate_weight_diff(base_layer.self_attn.v_proj.weight, chat_layer.self_attn.v_proj.weight),
            'self_attn_o_proj': calculate_weight_diff(base_layer.self_attn.o_proj.weight, chat_layer.self_attn.o_proj.weight)
        }
        layer_diffs.append(layer_diff)
        if load_one_at_a_time:
            # Drop the loop-local references between iterations. The models still
            # hold their layers, so this does not reduce peak memory usage.
            del base_layer, chat_layer
    return layer_diffs
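
# The result is a list with one dict per transformer block, e.g. (value illustrative):
#   layer_diffs[0]['self_attn_q_proj'] -> 0.00123  # mean |base - chat| for layer 0's Q projection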


def visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name):
    """Render the per-layer differences as one heatmap column per weight component."""
    num_layers = len(layer_diffs)
    num_components = len(layer_diffs[0])

    # Dynamically adjust figure size based on number of layers
    height = max(8, num_layers / 8)  # Minimum height of 8, scales up for more layers
    width = max(24, num_components * 3)  # Minimum width of 24, scales with components

    # Arrange subplots in 2 rows if there are many components
    if num_components > 6:
        nrows = 2
        ncols = (num_components + 1) // 2
        fig, axs = plt.subplots(nrows, ncols, figsize=(width, height * 1.5))
        axs = axs.flatten()
    else:
        nrows, ncols = 1, num_components
        fig, axs = plt.subplots(nrows, ncols, figsize=(width, height))
    fig.suptitle(f"{base_model_name} <> {chat_model_name}", fontsize=16)

    # Adjust font sizes based on number of layers
    tick_font_size = max(6, min(10, 300 / num_layers))
    annot_font_size = max(6, min(10, 200 / num_layers))

    for i, component in tqdm(enumerate(layer_diffs[0].keys()), total=len(layer_diffs[0].keys())):
        component_diffs = [[layer_diff[component]] for layer_diff in layer_diffs]
        sns.heatmap(component_diffs,
                    annot=True,
                    fmt=".9f",
                    cmap="YlGnBu",
                    ax=axs[i],
                    cbar=False,
                    annot_kws={'size': annot_font_size})
        axs[i].set_title(component, fontsize=max(10, tick_font_size * 1.2))
        axs[i].set_xlabel("Difference", fontsize=tick_font_size)
        axs[i].set_ylabel("Layer", fontsize=tick_font_size)
        axs[i].set_xticks([])
        axs[i].set_yticks(range(num_layers))
        axs[i].set_yticklabels(range(num_layers), fontsize=tick_font_size)
        axs[i].invert_yaxis()

    # Remove any empty subplots if using 2 rows
    if num_components > 6:
        for j in range(i + 1, len(axs)):
            fig.delaxes(axs[j])

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # Adjust layout to prevent overlap

    # Convert the plot to a PIL image
    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
    buf.seek(0)
    plt.close(fig)  # Close the figure to free memory
    return PIL.Image.open(buf)
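
# Standalone usage sketch (hypothetical model names):
#   img = visualize_layer_diffs(layer_diffs, "base-model", "chat-model")
#   img.save("weight_diffs.png")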


def gradio_interface(base_model_name, chat_model_name, hf_token, load_one_at_a_time=False):
    """Load both models, compute the per-layer diffs, and return the heatmap image."""
    # An empty token from the textbox is treated as "no token".
    token = hf_token or None
    base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.bfloat16, token=token)
    chat_model = AutoModelForCausalLM.from_pretrained(chat_model_name, torch_dtype=torch.bfloat16, token=token)
    layer_diffs = calculate_layer_diffs(base_model, chat_model, load_one_at_a_time=load_one_at_a_time)
    return visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name)
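
# Memory note: both checkpoints are held in memory at once in bfloat16 (~2 bytes per
# parameter), so comparing two 7B-parameter models needs roughly 28 GB of RAM.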


if __name__ == "__main__":
    iface = gr.Interface(
        fn=gradio_interface,
        inputs=[
            gr.Textbox(label="Base Model Name", lines=2),
            gr.Textbox(label="Chat Model Name", lines=2),
            gr.Textbox(label="Hugging Face Token", type="password", lines=2),
            gr.Checkbox(label="Load one layer at a time")
        ],
        outputs=gr.Image(type="pil", label="Weight Differences Visualization"),
        title="Model Weight Difference Visualizer",
        cache_examples=False
    )
    iface.launch(share=False, server_port=7860)
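
# To run locally (assuming the dependencies above are installed):
#   python app.py
# then open http://localhost:7860 in a browser.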