from transformers import AutoModelForCausalLM
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import gradio as gr
import io
import PIL.Image

def calculate_weight_diff(base_weight, chat_weight):
    """Mean absolute difference between two weight tensors, as a Python float."""
    return torch.abs(base_weight - chat_weight).mean().item()
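
# Sanity check (hypothetical toy tensors, not part of the app flow): identical
# tensors give 0.0, and tensors offset by exactly 1.0 everywhere give 1.0:
#
#   calculate_weight_diff(torch.zeros(2, 2), torch.zeros(2, 2))  # -> 0.0
#   calculate_weight_diff(torch.zeros(2, 2), torch.ones(2, 2))   # -> 1.0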

def calculate_layer_diffs(base_model, chat_model, load_one_at_a_time=False):
    """Per-layer mean absolute weight differences for every norm, MLP, and
    attention sub-module shared by the two models.

    Note: `load_one_at_a_time` is kept for interface compatibility, but both
    models are already fully materialized in memory by the caller, so the
    per-iteration `del` below only drops local names; the models still own
    the layer weights.
    """
    layer_diffs = []
    layers = zip(base_model.model.layers, chat_model.model.layers)

    for base_layer, chat_layer in tqdm(layers, total=len(base_model.model.layers)):
        layer_diff = {
            'input_layernorm': calculate_weight_diff(base_layer.input_layernorm.weight, chat_layer.input_layernorm.weight),
            'mlp_down_proj': calculate_weight_diff(base_layer.mlp.down_proj.weight, chat_layer.mlp.down_proj.weight),
            'mlp_gate_proj': calculate_weight_diff(base_layer.mlp.gate_proj.weight, chat_layer.mlp.gate_proj.weight),
            'mlp_up_proj': calculate_weight_diff(base_layer.mlp.up_proj.weight, chat_layer.mlp.up_proj.weight),
            'post_attention_layernorm': calculate_weight_diff(base_layer.post_attention_layernorm.weight, chat_layer.post_attention_layernorm.weight),
            'self_attn_q_proj': calculate_weight_diff(base_layer.self_attn.q_proj.weight, chat_layer.self_attn.q_proj.weight),
            'self_attn_k_proj': calculate_weight_diff(base_layer.self_attn.k_proj.weight, chat_layer.self_attn.k_proj.weight),
            'self_attn_v_proj': calculate_weight_diff(base_layer.self_attn.v_proj.weight, chat_layer.self_attn.v_proj.weight),
            'self_attn_o_proj': calculate_weight_diff(base_layer.self_attn.o_proj.weight, chat_layer.self_attn.o_proj.weight)
        }
        layer_diffs.append(layer_diff)

        if load_one_at_a_time:
            del base_layer, chat_layer  # drops local references only; see docstring

    return layer_diffs
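
# A genuinely memory-lean variant would avoid materializing both models at once.
# The sketch below is an assumption, not part of the original app: it streams
# tensors straight from downloaded safetensors shards via `safe_open`, so only
# the two tensors being compared are resident at any time. It assumes the
# standard Hugging Face checkpoint layout (safetensors shards, tensor names
# like "model.layers.N.self_attn.q_proj.weight").
def calculate_layer_diffs_streaming(base_repo, chat_repo, token=None):
    import glob
    import os
    import re
    from collections import defaultdict

    from huggingface_hub import snapshot_download
    from safetensors import safe_open

    def load_index(repo):
        # Map each tensor name to the shard file that stores it.
        path = snapshot_download(repo, allow_patterns=["*.safetensors"], token=token)
        index = {}
        for shard in glob.glob(os.path.join(path, "*.safetensors")):
            with safe_open(shard, framework="pt") as f:
                for name in f.keys():
                    index[name] = shard
        return index

    base_index, chat_index = load_index(base_repo), load_index(chat_repo)
    diffs = defaultdict(dict)
    for name, base_shard in base_index.items():
        match = re.match(r"model\.layers\.(\d+)\.(.+)\.weight$", name)
        if match is None or name not in chat_index:
            continue  # skip embeddings, lm_head, and tensors missing from the chat model
        layer_idx = int(match.group(1))
        component = match.group(2).replace(".", "_")  # "self_attn.q_proj" -> "self_attn_q_proj"
        with safe_open(base_shard, framework="pt") as bf, \
             safe_open(chat_index[name], framework="pt") as cf:
            diffs[layer_idx][component] = calculate_weight_diff(
                bf.get_tensor(name), cf.get_tensor(name))
    return [diffs[i] for i in sorted(diffs)]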

def visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name):
    num_layers = len(layer_diffs)
    num_components = len(layer_diffs[0])
    
    # Dynamically adjust figure size based on number of layers
    height = max(8, num_layers / 8)  # Minimum height of 8, scales up for more layers
    width = max(24, num_components * 3)  # Minimum width of 24, scales with components
    
    # Arrange subplots in 2 rows when there are many components
    if num_components > 6:
        nrows = 2
        ncols = (num_components + 1) // 2
        fig, axs = plt.subplots(nrows, ncols, figsize=(width, height * 1.5))
        axs = axs.flatten()
    else:
        fig, axs = plt.subplots(1, num_components, figsize=(width, height))

    fig.suptitle(f"{base_model_name} <> {chat_model_name}", fontsize=16)

    # Adjust font sizes based on number of layers
    tick_font_size = max(6, min(10, 300 / num_layers))
    annot_font_size = max(6, min(10, 200 / num_layers))

    for i, component in tqdm(enumerate(layer_diffs[0].keys()), total=num_components):
        # One-column matrix (num_layers x 1), so each component renders as a vertical strip
        component_diffs = [[layer_diff[component]] for layer_diff in layer_diffs]
        sns.heatmap(component_diffs, 
                    annot=True, 
                    fmt=".9f", 
                    cmap="YlGnBu", 
                    ax=axs[i], 
                    cbar=False,
                    annot_kws={'size': annot_font_size})
        
        axs[i].set_title(component, fontsize=max(10, tick_font_size * 1.2))
        axs[i].set_xlabel("Difference", fontsize=tick_font_size)
        axs[i].set_ylabel("Layer", fontsize=tick_font_size)
        axs[i].set_xticks([])
        axs[i].set_yticks(range(num_layers))
        axs[i].set_yticklabels(range(num_layers), fontsize=tick_font_size)
        axs[i].invert_yaxis()

    # Remove any empty subplots if using 2 rows
    if num_components > 6:
        for j in range(i + 1, len(axs)):
            fig.delaxes(axs[j])

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # Adjust layout to prevent overlap
    
    # Convert plot to image
    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
    buf.seek(0)
    plt.close(fig)  # Close the figure to free memory
    return PIL.Image.open(buf)

def gradio_interface(base_model_name, chat_model_name, hf_token, load_one_at_a_time=False):
    # Both checkpoints are fully loaded in bfloat16 (CPU by default), so peak
    # memory is roughly two models' worth of weights.
    token = hf_token or None  # treat an empty token textbox as "no token"
    base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.bfloat16, token=token)
    chat_model = AutoModelForCausalLM.from_pretrained(chat_model_name, torch_dtype=torch.bfloat16, token=token)

    layer_diffs = calculate_layer_diffs(base_model, chat_model, load_one_at_a_time=load_one_at_a_time)
    return visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name)
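
# Direct, non-Gradio usage sketch (the model names below are placeholders, not
# recommendations; substitute any two architecturally identical checkpoints):
#
#   base = AutoModelForCausalLM.from_pretrained("org/base-model", torch_dtype=torch.bfloat16)
#   chat = AutoModelForCausalLM.from_pretrained("org/chat-model", torch_dtype=torch.bfloat16)
#   image = visualize_layer_diffs(calculate_layer_diffs(base, chat),
#                                 "org/base-model", "org/chat-model")
#   image.save("weight_diffs.png")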

if __name__ == "__main__":
    iface = gr.Interface(
        fn=gradio_interface,
        inputs=[
            gr.Textbox(label="Base Model Name", lines=2),
            gr.Textbox(label="Chat Model Name", lines=2),
            gr.Textbox(label="Hugging Face Token", type="password", lines=2),
            gr.Checkbox(label="Load one layer at a time")
        ],
        outputs=gr.Image(type="pil", label="Weight Differences Visualization"),
        title="Model Weight Difference Visualizer",
        cache_examples=False
    )
    
    iface.launch(share=False, server_port=7860)