"""Streamlit app comparing the weights of a base LLM and its chat variant.

For every decoder layer, the mean absolute difference of selected weight
tensors (layer norms and attention projections) is computed and rendered as
one single-column heatmap per component, with layers on the y axis.
"""

import io
import operator
import os

import matplotlib.pyplot as plt
import seaborn as sns
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Weights compared in every decoder layer: result key -> dotted attribute path
# relative to the layer module. Dict order is the subplot (column) order.
_LAYER_COMPONENTS = {
    "input_layernorm": "input_layernorm.weight",
    "post_attention_layernorm": "post_attention_layernorm.weight",
    "self_attn_q_proj": "self_attn.q_proj.weight",
    "self_attn_k_proj": "self_attn.k_proj.weight",
    "self_attn_v_proj": "self_attn.v_proj.weight",
    "self_attn_o_proj": "self_attn.o_proj.weight",
}


def calculate_weight_diff(base_weight, chat_weight):
    """Return the mean absolute element-wise difference of two weight tensors.

    Both tensors are upcast to float32 before subtracting, so small
    fine-tuning deltas are not rounded away by low-precision (e.g. bfloat16)
    arithmetic. The result is a plain Python float.

    Raises:
        ValueError: if the tensors' shapes differ (subtraction would otherwise
            broadcast silently or fail with an obscure message).
    """
    if base_weight.shape != chat_weight.shape:
        raise ValueError(
            f"Weight shape mismatch: {tuple(base_weight.shape)} "
            f"vs {tuple(chat_weight.shape)}"
        )
    # Model parameters normally require grad; avoid recording an autograd graph.
    with torch.no_grad():
        base = base_weight.to(dtype=torch.float32)
        chat = chat_weight.to(device=base.device, dtype=torch.float32)
        return (base - chat).abs().mean().item()


def calculate_layer_diffs(base_model, chat_model):
    """Compute per-layer weight differences between two models.

    Reads the decoder layers from ``<model>.model.layers``. On each layer it
    reads the attributes listed in ``_LAYER_COMPONENTS``: ``input_layernorm``,
    ``post_attention_layernorm`` and ``self_attn.{q,k,v,o}_proj``.

    Returns:
        A list with one dict per layer, mapping each component key in
        ``_LAYER_COMPONENTS`` to the mean absolute difference (float).

    Raises:
        ValueError: if the models have different numbers of layers, or a
            weight pair has mismatched shapes.
    """
    base_layers = base_model.model.layers
    chat_layers = chat_model.model.layers
    # zip() would silently drop the extra layers of the deeper model.
    if len(base_layers) != len(chat_layers):
        raise ValueError(
            f"Layer count mismatch: {len(base_layers)} vs {len(chat_layers)}"
        )

    # attrgetter resolves dotted paths such as "self_attn.q_proj.weight".
    getters = {
        name: operator.attrgetter(path) for name, path in _LAYER_COMPONENTS.items()
    }
    return [
        {
            name: calculate_weight_diff(get(base_layer), get(chat_layer))
            for name, get in getters.items()
        }
        for base_layer, chat_layer in zip(base_layers, chat_layers)
    ]


def visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name):
    """Plot one heatmap column per weight component, one row per layer.

    Args:
        layer_diffs: list of per-layer dicts as returned by
            ``calculate_layer_diffs``; every dict must have the same keys.
        base_model_name: label used in the figure title.
        chat_model_name: label used in the figure title.

    Returns:
        The matplotlib Figure. The caller owns it and should close it.

    Raises:
        ValueError: if ``layer_diffs`` is empty.
    """
    if not layer_diffs:
        raise ValueError("No layer differences to visualize")

    num_layers = len(layer_diffs)
    components = list(layer_diffs[0])
    # Grow with model depth so rows stay readable; never shorter than 8 inches.
    fig_height = max(8, 0.25 * num_layers)
    # squeeze=False keeps axs 2-D even when there is a single component.
    fig, axs = plt.subplots(
        1, len(components), figsize=(24, fig_height), squeeze=False
    )
    fig.suptitle(f"{base_model_name} <> {chat_model_name}", fontsize=16)

    layer_labels = list(range(num_layers))
    # Heatmap cell i spans [i, i + 1]; put each layer label at the cell centre.
    layer_ticks = [i + 0.5 for i in layer_labels]

    for ax, component in zip(axs[0], components):
        component_diffs = [[layer_diff[component]] for layer_diff in layer_diffs]
        sns.heatmap(
            component_diffs,
            annot=True,
            fmt=".6f",
            cmap="YlGnBu",
            ax=ax,
            cbar_kws={"shrink": 0.8},
        )
        ax.set_title(component)
        # Rows are layers; the single column holds the difference values.
        ax.set_xlabel("Difference")
        ax.set_ylabel("Layer")
        ax.set_xticks([])
        ax.set_yticks(layer_ticks)
        ax.set_yticklabels(layer_labels, rotation=0)
        # seaborn draws row 0 at the top; flip so layer 0 sits at the bottom.
        ax.invert_yaxis()

    fig.tight_layout()
    return fig


def main():
    """Streamlit entry point: read model names, compare them, show and offer the plot."""
    st.set_page_config(
        page_title="Model Weight Comparator",
        layout="wide",
        initial_sidebar_state="expanded",
    )
    st.title("LLM Weight Comparator")

    # Config sidebar for input parameters
    with st.sidebar:
        st.header("Configuration")
        base_model_name = st.text_input(
            "Base Model Name",
            value="meta-llama/Llama-3.1-8B",
            help="Enter the name of the base model",
        )
        chat_model_name = st.text_input(
            "Chat Model Name",
            value="meta-llama/Llama-3.1-8B-Instruct",
            help="Enter the name of the chat model",
        )

    if not st.button("Compare Models"):
        return

    # Whitespace-only input must count as empty; stray spaces break repo ids.
    base_model_name = base_model_name.strip()
    chat_model_name = chat_model_name.strip()
    if not base_model_name or not chat_model_name:
        st.error("Please enter both model names")
        return

    # One placeholder so each progress message replaces the previous one.
    status = st.empty()
    try:
        status.info("Loading models... This might take some time.")
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_name, torch_dtype=torch.bfloat16
        )
        chat_model = AutoModelForCausalLM.from_pretrained(
            chat_model_name, torch_dtype=torch.bfloat16
        )

        status.info("Calculating weight differences...")
        layer_diffs = calculate_layer_diffs(base_model, chat_model)
        # Only the scalar diffs are needed from here on; release the models.
        del base_model, chat_model

        status.info("Generating visualization...")
        fig = visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name)
        try:
            st.pyplot(fig)
            buf = io.BytesIO()
            fig.savefig(buf, format="png", dpi=300, bbox_inches="tight")
        finally:
            # pyplot keeps figures alive until closed; without this every
            # Streamlit rerun would leak a figure.
            plt.close(fig)

        status.empty()
        st.download_button(
            label="Download Visualization",
            data=buf.getvalue(),
            file_name="model_comparison.png",
            mime="image/png",
        )
    except Exception as e:  # UI boundary: surface any failure to the user.
        status.error(f"An error occurred: {e}")


if __name__ == "__main__":
    main()