gaverfraxz's picture
Update app.py
6acbf7b verified
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import os
import io
def calculate_weight_diff(base_weight, chat_weight):
    """Return the mean absolute element-wise difference between two weight tensors.

    Both tensors are cast to float32 before subtracting and averaging, so that
    small differences between low-precision (e.g. bfloat16) weights are not
    rounded to the input dtype's few significant digits before being reported.
    The computation runs under ``torch.no_grad()`` so no autograd graph is
    recorded when the inputs are model parameters.

    Args:
        base_weight: weight tensor from the base model.
        chat_weight: weight tensor from the chat model; must be
            broadcast-compatible with ``base_weight`` (normally the same shape).

    Returns:
        ``mean(|base_weight - chat_weight|)`` as a Python float.
    """
    with torch.no_grad():
        diff = base_weight.float() - chat_weight.float()
        # In-place abs avoids one extra full-size temporary.
        return diff.abs_().mean().item()
def calculate_layer_diffs(base_model, chat_model):
layer_diffs = []
for base_layer, chat_layer in zip(base_model.model.layers, chat_model.model.layers):
layer_diff = {
'input_layernorm': calculate_weight_diff(base_layer.input_layernorm.weight, chat_layer.input_layernorm.weight),
'post_attention_layernorm': calculate_weight_diff(base_layer.post_attention_layernorm.weight, chat_layer.post_attention_layernorm.weight),
'self_attn_q_proj': calculate_weight_diff(base_layer.self_attn.q_proj.weight, chat_layer.self_attn.q_proj.weight),
'self_attn_k_proj': calculate_weight_diff(base_layer.self_attn.k_proj.weight, chat_layer.self_attn.k_proj.weight),
'self_attn_v_proj': calculate_weight_diff(base_layer.self_attn.v_proj.weight, chat_layer.self_attn.v_proj.weight),
'self_attn_o_proj': calculate_weight_diff(base_layer.self_attn.o_proj.weight, chat_layer.self_attn.o_proj.weight)
}
layer_diffs.append(layer_diff)
return layer_diffs
def visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name):
    """Draw one single-column heatmap per layer component, side by side.

    Args:
        layer_diffs: list with one dict per layer, mapping component name to a
            scalar difference (as returned by ``calculate_layer_diffs``). The
            component names of the first dict define the subplots.
        base_model_name: base model label, used in the figure title.
        chat_model_name: chat model label, used in the figure title.

    Returns:
        The matplotlib Figure containing the heatmaps. The caller owns it (e.g.
        should ``plt.close`` it when done).

    Raises:
        ValueError: if ``layer_diffs`` is empty.
    """
    if not layer_diffs:
        raise ValueError("layer_diffs is empty; nothing to visualize")

    num_layers = len(layer_diffs)
    components = list(layer_diffs[0].keys())
    # squeeze=False keeps `axs` 2-D even for a single component, so axs[0] is always a row.
    fig, axs = plt.subplots(1, len(components), figsize=(24, 8), squeeze=False)
    fig.suptitle(f"{base_model_name} <> {chat_model_name}", fontsize=16)

    # seaborn draws row i between y=i and y=i+1, so the tick for layer i belongs at i + 0.5.
    tick_positions = [layer + 0.5 for layer in range(num_layers)]
    tick_labels = [str(layer) for layer in range(num_layers)]

    for ax, component in zip(axs[0], components):
        # One row per layer, a single column: shape (num_layers, 1).
        component_diffs = [[layer_diff[component]] for layer_diff in layer_diffs]
        sns.heatmap(component_diffs, annot=True, fmt=".6f", cmap="YlGnBu", ax=ax, cbar_kws={"shrink": 0.8})
        ax.set_title(component)
        # Layers run along the y axis; the lone column holds the difference value.
        ax.set_xlabel("Difference")
        ax.set_ylabel("Layer")
        ax.set_xticks([])
        ax.set_yticks(tick_positions)
        ax.set_yticklabels(tick_labels)
        # seaborn puts row 0 at the top; flip so layer 0 is at the bottom.
        ax.invert_yaxis()

    fig.tight_layout()
    return fig
def _build_comparison_figure(base_model_name, chat_model_name):
    """Load both models, compute per-layer weight differences, and return the plot.

    Progress is reported in the page with ``st.info``. Both models are loaded in
    bfloat16. The references to them are dropped as soon as the scalar
    differences are computed, so they need not stay alive while the figure is
    drawn.

    Returns:
        The matplotlib Figure produced by ``visualize_layer_diffs``.
    """
    st.info("Loading models... This might take some time.")
    base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.bfloat16)
    chat_model = AutoModelForCausalLM.from_pretrained(chat_model_name, torch_dtype=torch.bfloat16)
    st.info("Calculating weight differences...")
    layer_diffs = calculate_layer_diffs(base_model, chat_model)
    # Only the scalar diffs are needed from here on; release the (large) models.
    del base_model, chat_model
    st.info("Generating visualization...")
    return visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name)


def _offer_png_download(fig):
    """Render ``fig`` to an in-memory 300-dpi PNG and show a download button for it."""
    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
    buf.seek(0)
    st.download_button(
        label="Download Visualization",
        data=buf,
        file_name="model_comparison.png",
        mime="image/png"
    )


def main():
    """Streamlit entry point: compare the weights of two models layer by layer.

    The sidebar collects a base and a chat model name. These are passed to
    ``AutoModelForCausalLM.from_pretrained``. When "Compare Models" is pressed,
    the per-layer differences are plotted in the main area and offered as a PNG
    download. Any failure during loading, comparison or plotting is reported
    with ``st.error``.
    """
    st.set_page_config(
        page_title="Model Weight Comparator",
        layout="wide",
        initial_sidebar_state="expanded"
    )
    st.title("LLM Weight Comparator")

    # Config sidebar for input parameters
    with st.sidebar:
        st.header("Configuration")
        base_model_name = st.text_input(
            "Base Model Name",
            value="meta-llama/Llama-3.1-8B",
            help="Enter the name of the base model"
        )
        chat_model_name = st.text_input(
            "Chat Model Name",
            value="meta-llama/Llama-3.1-8B-Instruct",
            help="Enter the name of the chat model"
        )

    if not st.button("Compare Models"):
        return

    # Leading/trailing spaces are presumably typos; strip them so that
    # whitespace-only input is rejected here instead of failing inside loading.
    base_model_name = base_model_name.strip()
    chat_model_name = chat_model_name.strip()
    if not base_model_name or not chat_model_name:
        st.error("Please enter both model names")
        return

    fig = None
    try:
        fig = _build_comparison_figure(base_model_name, chat_model_name)
        st.pyplot(fig)
        _offer_png_download(fig)
    except Exception as e:
        # Top-level UI boundary: show any failure to the user instead of crashing the page.
        st.error(f"An error occurred: {str(e)}")
    finally:
        # pyplot keeps every created figure alive until closed; close it so
        # repeated comparisons across reruns don't accumulate figures in memory.
        if fig is not None:
            plt.close(fig)
# Launch the app only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()