# Hugging Face Space: model weight-difference visualizer.
# (Non-code page residue from the HF web export — status lines, commit-hash
# gutter, and line-number gutter — removed so the file is valid Python.)
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import gradio as gr
import io
import PIL.Image
def calculate_weight_diff(base_weight, chat_weight):
    """Return the mean absolute elementwise difference between two weight tensors."""
    delta = base_weight - chat_weight
    return delta.abs().mean().item()
def calculate_layer_diffs(base_model, chat_model, load_one_at_a_time=False):
    """Compute the mean absolute weight difference of each layer component.

    Args:
        base_model: loaded causal-LM whose decoder layers live at `.model.layers`.
        chat_model: second model with the same architecture/layout.
        load_one_at_a_time: if True, drop the local layer references after each
            iteration. NOTE(review): this only deletes the loop locals — the
            models themselves still hold the layers, so the memory effect is
            limited; confirm against the intended loading strategy.

    Returns:
        A list (one dict per layer) mapping component name -> mean-abs diff.
    """
    # Attribute paths of the per-layer submodules to compare. Using a table
    # removes the previous verbatim duplication of this dict in both branches.
    components = {
        'input_layernorm': ('input_layernorm',),
        'mlp_down_proj': ('mlp', 'down_proj'),
        'mlp_gate_proj': ('mlp', 'gate_proj'),
        'mlp_up_proj': ('mlp', 'up_proj'),
        'post_attention_layernorm': ('post_attention_layernorm',),
        'self_attn_q_proj': ('self_attn', 'q_proj'),
        'self_attn_k_proj': ('self_attn', 'k_proj'),
        'self_attn_v_proj': ('self_attn', 'v_proj'),
        'self_attn_o_proj': ('self_attn', 'o_proj'),
    }

    def _weight(layer, path):
        # Walk the attribute path (e.g. ('mlp', 'down_proj')) to the submodule.
        for attr in path:
            layer = getattr(layer, attr)
        return layer.weight

    layer_diffs = []
    total = len(base_model.model.layers)
    for base_layer, chat_layer in tqdm(zip(base_model.model.layers, chat_model.model.layers), total=total):
        layer_diffs.append({
            name: torch.abs(_weight(base_layer, path) - _weight(chat_layer, path)).mean().item()
            for name, path in components.items()
        })
        if load_one_at_a_time:
            # Release the loop's references before the next iteration.
            del base_layer, chat_layer
    return layer_diffs
def visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name):
    """Render per-layer weight differences as one heatmap column per component.

    Args:
        layer_diffs: list of dicts (one per layer) mapping component name to a
            scalar difference, as produced by calculate_layer_diffs.
        base_model_name: label for the base model, shown in the title.
        chat_model_name: label for the chat model, shown in the title.

    Returns:
        A PIL.Image.Image containing the rendered figure.
    """
    num_layers = len(layer_diffs)
    num_components = len(layer_diffs[0])

    # Scale the figure with the model: taller for more layers, wider for more
    # components.
    height = max(8, num_layers / 8)
    width = max(24, num_components * 3)

    # With many components, split the heatmaps over two rows.
    if num_components > 6:
        ncols = (num_components + 1) // 2
        fig, axs = plt.subplots(2, ncols, figsize=(width, height * 1.5))
        axs = axs.flatten()
    else:
        fig, axs = plt.subplots(1, num_components, figsize=(width, height))
        if num_components == 1:
            # plt.subplots returns a bare Axes (not an array) for a 1x1 grid;
            # wrap it so axs[i] below works.
            axs = [axs]
    fig.suptitle(f"{base_model_name} <> {chat_model_name}", fontsize=16)

    # Shrink fonts as the layer count grows so annotations stay legible.
    tick_font_size = max(6, min(10, 300 / num_layers))
    annot_font_size = max(6, min(10, 200 / num_layers))

    for i, component in tqdm(enumerate(layer_diffs[0].keys()), total=num_components):
        # One-column matrix: row per layer, single value per row.
        component_diffs = [[layer_diff[component]] for layer_diff in layer_diffs]
        sns.heatmap(component_diffs,
                    annot=True,
                    fmt=".9f",
                    cmap="YlGnBu",
                    ax=axs[i],
                    cbar=False,
                    annot_kws={'size': annot_font_size})
        axs[i].set_title(component, fontsize=max(10, tick_font_size * 1.2))
        axs[i].set_xlabel("Difference", fontsize=tick_font_size)
        axs[i].set_ylabel("Layer", fontsize=tick_font_size)
        axs[i].set_xticks([])
        axs[i].set_yticks(range(num_layers))
        axs[i].set_yticklabels(range(num_layers), fontsize=tick_font_size)
        axs[i].invert_yaxis()  # layer 0 at the top

    # Delete the unused trailing axes of the 2-row grid.
    if num_components > 6:
        for j in range(i + 1, len(axs)):
            fig.delaxes(axs[j])

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # leave room for the suptitle

    # Render to an in-memory PNG and return it as a PIL image for Gradio.
    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
    buf.seek(0)
    plt.close(fig)  # free the figure's memory
    return PIL.Image.open(buf)
def gradio_interface(base_model_name, chat_model_name, hf_token, load_one_at_a_time=False):
    """Gradio entry point: load both models, diff their weights, and plot.

    Args:
        base_model_name: Hugging Face repo id of the base model.
        chat_model_name: Hugging Face repo id of the fine-tuned/chat model.
        hf_token: Hugging Face access token; a blank textbox yields "" which
            is normalized to None so public models work without a token.
        load_one_at_a_time: forwarded to calculate_layer_diffs.

    Returns:
        PIL image of the per-layer weight-difference heatmaps.
    """
    token = hf_token or None  # "" would otherwise be sent as an invalid token
    base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.bfloat16, token=token)
    chat_model = AutoModelForCausalLM.from_pretrained(chat_model_name, torch_dtype=torch.bfloat16, token=token)
    layer_diffs = calculate_layer_diffs(base_model, chat_model, load_one_at_a_time=load_one_at_a_time)
    # Drop both models before rendering so the figure doesn't add to peak memory.
    del base_model, chat_model
    return visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name)
if __name__ == "__main__":
    # Build and launch the Gradio UI. (A stray trailing "|" artifact from the
    # web export was removed from the launch line — it was a syntax error.)
    iface = gr.Interface(
        fn=gradio_interface,
        inputs=[
            gr.Textbox(label="Base Model Name", lines=2),
            gr.Textbox(label="Chat Model Name", lines=2),
            gr.Textbox(label="Hugging Face Token", type="password", lines=2),
            gr.Checkbox(label="Load one layer at a time"),
        ],
        outputs=gr.Image(type="pil", label="Weight Differences Visualization"),
        title="Model Weight Difference Visualizer",
        cache_examples=False,
    )
    iface.launch(share=False, server_port=7860)