Steelskull commited on
Commit
86a02df
·
verified ·
1 Parent(s): 34a4ee3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -72
app.py CHANGED
@@ -1,72 +1,72 @@
1
- from transformers import AutoModelForCausalLM, AutoTokenizer
2
- import torch
3
- import matplotlib.pyplot as plt
4
- import seaborn as sns
5
- from tqdm import tqdm
6
- import gradio as gr
7
-
8
- def calculate_weight_diff(base_weight, chat_weight):
9
- return torch.abs(base_weight - chat_weight).mean().item()
10
-
11
- def calculate_layer_diffs(base_model, chat_model):
12
- layer_diffs = []
13
- for base_layer, chat_layer in tqdm(zip(base_model.model.layers, chat_model.model.layers), total=len(base_model.model.layers)):
14
- layer_diff = {
15
- 'input_layernorm': calculate_weight_diff(base_layer.input_layernorm.weight, chat_layer.input_layernorm.weight),
16
- 'mlp_down_proj': calculate_weight_diff(base_layer.mlp.down_proj.weight, chat_layer.mlp.down_proj.weight),
17
- 'mlp_gate_proj': calculate_weight_diff(base_layer.mlp.gate_proj.weight, chat_layer.mlp.gate_proj.weight),
18
- 'mlp_up_proj': calculate_weight_diff(base_layer.mlp.up_proj.weight, chat_layer.mlp.up_proj.weight),
19
- 'post_attention_layernorm': calculate_weight_diff(base_layer.post_attention_layernorm.weight, chat_layer.post_attention_layernorm.weight),
20
- 'self_attn_q_proj': calculate_weight_diff(base_layer.self_attn.q_proj.weight, chat_layer.self_attn.q_proj.weight),
21
- 'self_attn_k_proj': calculate_weight_diff(base_layer.self_attn.k_proj.weight, chat_layer.self_attn.k_proj.weight),
22
- 'self_attn_v_proj': calculate_weight_diff(base_layer.self_attn.v_proj.weight, chat_layer.self_attn.v_proj.weight),
23
- 'self_attn_o_proj': calculate_weight_diff(base_layer.self_attn.o_proj.weight, chat_layer.self_attn.o_proj.weight)
24
- }
25
- layer_diffs.append(layer_diff)
26
-
27
- base_layer, chat_layer = None, None
28
- del base_layer, chat_layer
29
-
30
- return layer_diffs
31
-
32
- def visualize_layer_diffs(layer_diffs):
33
- num_layers = len(layer_diffs)
34
- num_components = len(layer_diffs[0])
35
-
36
- fig, axs = plt.subplots(1, num_components, figsize=(24, 8))
37
- fig.suptitle(f"{base_model_name} <> {chat_model_name}", fontsize=16)
38
-
39
- for i, component in tqdm(enumerate(layer_diffs[0].keys()), total=len(layer_diffs[0].keys())):
40
- component_diffs = [[layer_diff[component]] for layer_diff in layer_diffs]
41
- sns.heatmap(component_diffs, annot=True, fmt=".9f", cmap="YlGnBu", ax=axs[i], cbar=False)
42
- axs[i].set_title(component)
43
- axs[i].set_xlabel("Difference")
44
- axs[i].set_ylabel("Layer")
45
- axs[i].set_xticks([])
46
- axs[i].set_yticks(range(num_layers))
47
- axs[i].set_yticklabels(range(num_layers))
48
- axs[i].invert_yaxis()
49
-
50
- plt.tight_layout()
51
- return fig
52
-
53
- def gradio_interface(base_model_name, chat_model_name):
54
- base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.bfloat16)
55
- chat_model = AutoModelForCausalLM.from_pretrained(chat_model_name, torch_dtype=torch.bfloat16)
56
-
57
- layer_diffs = calculate_layer_diffs(base_model, chat_model)
58
- fig = visualize_layer_diffs(layer_diffs)
59
-
60
- return fig
61
-
62
- iface = gr.Interface(
63
- fn=gradio_interface,
64
- inputs=[
65
- gr.inputs.Textbox(lines=2, placeholder="Enter base model name"),
66
- gr.inputs.Textbox(lines=2, placeholder="Enter chat model name")
67
- ],
68
- outputs="image",
69
- title="Model Weight Difference Visualizer"
70
- )
71
-
72
- iface.launch()
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer
2
+ import torch
3
+ import matplotlib.pyplot as plt
4
+ import seaborn as sns
5
+ from tqdm import tqdm
6
+ import gradio as gr
7
+
8
+ def calculate_weight_diff(base_weight, chat_weight):
9
+ return torch.abs(base_weight - chat_weight).mean().item()
10
+
11
+ def calculate_layer_diffs(base_model, chat_model):
12
+ layer_diffs = []
13
+ for base_layer, chat_layer in tqdm(zip(base_model.model.layers, chat_model.model.layers), total=len(base_model.model.layers)):
14
+ layer_diff = {
15
+ 'input_layernorm': calculate_weight_diff(base_layer.input_layernorm.weight, chat_layer.input_layernorm.weight),
16
+ 'mlp_down_proj': calculate_weight_diff(base_layer.mlp.down_proj.weight, chat_layer.mlp.down_proj.weight),
17
+ 'mlp_gate_proj': calculate_weight_diff(base_layer.mlp.gate_proj.weight, chat_layer.mlp.gate_proj.weight),
18
+ 'mlp_up_proj': calculate_weight_diff(base_layer.mlp.up_proj.weight, chat_layer.mlp.up_proj.weight),
19
+ 'post_attention_layernorm': calculate_weight_diff(base_layer.post_attention_layernorm.weight, chat_layer.post_attention_layernorm.weight),
20
+ 'self_attn_q_proj': calculate_weight_diff(base_layer.self_attn.q_proj.weight, chat_layer.self_attn.q_proj.weight),
21
+ 'self_attn_k_proj': calculate_weight_diff(base_layer.self_attn.k_proj.weight, chat_layer.self_attn.k_proj.weight),
22
+ 'self_attn_v_proj': calculate_weight_diff(base_layer.self_attn.v_proj.weight, chat_layer.self_attn.v_proj.weight),
23
+ 'self_attn_o_proj': calculate_weight_diff(base_layer.self_attn.o_proj.weight, chat_layer.self_attn.o_proj.weight)
24
+ }
25
+ layer_diffs.append(layer_diff)
26
+
27
+ base_layer, chat_layer = None, None
28
+ del base_layer, chat_layer
29
+
30
+ return layer_diffs
31
+
32
+ def visualize_layer_diffs(layer_diffs):
33
+ num_layers = len(layer_diffs)
34
+ num_components = len(layer_diffs[0])
35
+
36
+ fig, axs = plt.subplots(1, num_components, figsize=(24, 8))
37
+ fig.suptitle(f"{base_model_name} <> {chat_model_name}", fontsize=16)
38
+
39
+ for i, component in tqdm(enumerate(layer_diffs[0].keys()), total=len(layer_diffs[0].keys())):
40
+ component_diffs = [[layer_diff[component]] for layer_diff in layer_diffs]
41
+ sns.heatmap(component_diffs, annot=True, fmt=".9f", cmap="YlGnBu", ax=axs[i], cbar=False)
42
+ axs[i].set_title(component)
43
+ axs[i].set_xlabel("Difference")
44
+ axs[i].set_ylabel("Layer")
45
+ axs[i].set_xticks([])
46
+ axs[i].set_yticks(range(num_layers))
47
+ axs[i].set_yticklabels(range(num_layers))
48
+ axs[i].invert_yaxis()
49
+
50
+ plt.tight_layout()
51
+ return fig
52
+
53
+ def gradio_interface(base_model_name, chat_model_name):
54
+ base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.bfloat16)
55
+ chat_model = AutoModelForCausalLM.from_pretrained(chat_model_name, torch_dtype=torch.bfloat16)
56
+
57
+ layer_diffs = calculate_layer_diffs(base_model, chat_model)
58
+ fig = visualize_layer_diffs(layer_diffs)
59
+
60
+ return fig
61
+
62
+ iface = gr.Interface(
63
+ fn=gradio_interface,
64
+ inputs=[
65
+ gr.Textbox(lines=2, placeholder="Enter base model name"),
66
+ gr.Textbox(lines=2, placeholder="Enter chat model name")
67
+ ],
68
+ outputs="image",
69
+ title="Model Weight Difference Visualizer"
70
+ )
71
+
72
+ iface.launch()