Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -51,22 +51,52 @@ def calculate_layer_diffs(base_model, chat_model, load_one_at_a_time=False):
|
|
51 |
def visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name):
|
52 |
num_layers = len(layer_diffs)
|
53 |
num_components = len(layer_diffs[0])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
-
fig, axs = plt.subplots(1, num_components, figsize=(24, 8))
|
56 |
fig.suptitle(f"{base_model_name} <> {chat_model_name}", fontsize=16)
|
57 |
|
|
|
|
|
|
|
|
|
58 |
for i, component in tqdm(enumerate(layer_diffs[0].keys()), total=len(layer_diffs[0].keys())):
|
59 |
component_diffs = [[layer_diff[component]] for layer_diff in layer_diffs]
|
60 |
-
sns.heatmap(component_diffs,
|
61 |
-
|
62 |
-
|
63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
axs[i].set_xticks([])
|
65 |
axs[i].set_yticks(range(num_layers))
|
66 |
-
axs[i].set_yticklabels(range(num_layers))
|
67 |
axs[i].invert_yaxis()
|
68 |
|
69 |
-
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
# Convert plot to image
|
72 |
buf = io.BytesIO()
|
@@ -76,7 +106,6 @@ def visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name):
|
|
76 |
return PIL.Image.open(buf)
|
77 |
|
78 |
def gradio_interface(base_model_name, chat_model_name, hf_token, load_one_at_a_time=False):
|
79 |
-
# Update to use 'token' instead of 'use_auth_token' to handle deprecation warning
|
80 |
base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.bfloat16, token=hf_token)
|
81 |
chat_model = AutoModelForCausalLM.from_pretrained(chat_model_name, torch_dtype=torch.bfloat16, token=hf_token)
|
82 |
|
|
|
51 |
def visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name):
|
52 |
num_layers = len(layer_diffs)
|
53 |
num_components = len(layer_diffs[0])
|
54 |
+
|
55 |
+
# Scale the figure height with the number of layers and the width with the number of components
|
56 |
+
height = max(8, num_layers / 8) # Minimum height of 8, scales up for more layers
|
57 |
+
width = max(24, num_components * 3) # Minimum width of 24, scales with components
|
58 |
+
|
59 |
+
# Use a 2-row grid of subplots when there are more than 6 components; otherwise a single row
|
60 |
+
if num_components > 6:
|
61 |
+
nrows = 2
|
62 |
+
ncols = (num_components + 1) // 2
|
63 |
+
fig, axs = plt.subplots(nrows, ncols, figsize=(width, height * 1.5))
|
64 |
+
axs = axs.flatten()
|
65 |
+
else:
|
66 |
+
nrows = 1
|
67 |
+
ncols = num_components
|
68 |
+
fig, axs = plt.subplots(1, num_components, figsize=(width, height))
|
69 |
|
|
|
70 |
fig.suptitle(f"{base_model_name} <> {chat_model_name}", fontsize=16)
|
71 |
|
72 |
+
# Shrink tick and annotation font sizes as the number of layers grows, clamped to the 6-10 range
|
73 |
+
tick_font_size = max(6, min(10, 300 / num_layers))
|
74 |
+
annot_font_size = max(6, min(10, 200 / num_layers))
|
75 |
+
|
76 |
for i, component in tqdm(enumerate(layer_diffs[0].keys()), total=len(layer_diffs[0].keys())):
|
77 |
component_diffs = [[layer_diff[component]] for layer_diff in layer_diffs]
|
78 |
+
sns.heatmap(component_diffs,
|
79 |
+
annot=True,
|
80 |
+
fmt=".9f",
|
81 |
+
cmap="YlGnBu",
|
82 |
+
ax=axs[i],
|
83 |
+
cbar=False,
|
84 |
+
annot_kws={'size': annot_font_size})
|
85 |
+
|
86 |
+
axs[i].set_title(component, fontsize=max(10, tick_font_size * 1.2))
|
87 |
+
axs[i].set_xlabel("Difference", fontsize=tick_font_size)
|
88 |
+
axs[i].set_ylabel("Layer", fontsize=tick_font_size)
|
89 |
axs[i].set_xticks([])
|
90 |
axs[i].set_yticks(range(num_layers))
|
91 |
+
axs[i].set_yticklabels(range(num_layers), fontsize=tick_font_size)
|
92 |
axs[i].invert_yaxis()
|
93 |
|
94 |
+
# Remove the leftover empty subplot when the 2-row grid has more cells than components (odd component count)
|
95 |
+
if num_components > 6:
|
96 |
+
for j in range(i + 1, len(axs)):
|
97 |
+
fig.delaxes(axs[j])
|
98 |
+
|
99 |
+
plt.tight_layout(rect=[0, 0.03, 1, 0.95]) # Adjust layout to prevent overlap
|
100 |
|
101 |
# Convert plot to image
|
102 |
buf = io.BytesIO()
|
|
|
106 |
return PIL.Image.open(buf)
|
107 |
|
108 |
def gradio_interface(base_model_name, chat_model_name, hf_token, load_one_at_a_time=False):
|
|
|
109 |
base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.bfloat16, token=hf_token)
|
110 |
chat_model = AutoModelForCausalLM.from_pretrained(chat_model_name, torch_dtype=torch.bfloat16, token=hf_token)
|
111 |
|