Steelskull committed on
Commit 7be2dce · verified · 1 Parent(s): be02445

Update app.py

Files changed (1)
  1. app.py +37 -8
app.py CHANGED
@@ -51,22 +51,52 @@ def calculate_layer_diffs(base_model, chat_model, load_one_at_a_time=False):
 def visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name):
     num_layers = len(layer_diffs)
     num_components = len(layer_diffs[0])
+
+    # Dynamically adjust figure size based on number of layers
+    height = max(8, num_layers / 8)  # Minimum height of 8, scales up for more layers
+    width = max(24, num_components * 3)  # Minimum width of 24, scales with components
+
+    # Create figure with subplots arranged in 2 rows if there are many components
+    if num_components > 6:
+        nrows = 2
+        ncols = (num_components + 1) // 2
+        fig, axs = plt.subplots(nrows, ncols, figsize=(width, height * 1.5))
+        axs = axs.flatten()
+    else:
+        nrows = 1
+        ncols = num_components
+        fig, axs = plt.subplots(1, num_components, figsize=(width, height))
 
-    fig, axs = plt.subplots(1, num_components, figsize=(24, 8))
     fig.suptitle(f"{base_model_name} <> {chat_model_name}", fontsize=16)
 
+    # Adjust font sizes based on number of layers
+    tick_font_size = max(6, min(10, 300 / num_layers))
+    annot_font_size = max(6, min(10, 200 / num_layers))
+
     for i, component in tqdm(enumerate(layer_diffs[0].keys()), total=len(layer_diffs[0].keys())):
         component_diffs = [[layer_diff[component]] for layer_diff in layer_diffs]
-        sns.heatmap(component_diffs, annot=True, fmt=".9f", cmap="YlGnBu", ax=axs[i], cbar=False)
-        axs[i].set_title(component)
-        axs[i].set_xlabel("Difference")
-        axs[i].set_ylabel("Layer")
+        sns.heatmap(component_diffs,
+                    annot=True,
+                    fmt=".9f",
+                    cmap="YlGnBu",
+                    ax=axs[i],
+                    cbar=False,
+                    annot_kws={'size': annot_font_size})
+
+        axs[i].set_title(component, fontsize=max(10, tick_font_size * 1.2))
+        axs[i].set_xlabel("Difference", fontsize=tick_font_size)
+        axs[i].set_ylabel("Layer", fontsize=tick_font_size)
         axs[i].set_xticks([])
         axs[i].set_yticks(range(num_layers))
-        axs[i].set_yticklabels(range(num_layers))
+        axs[i].set_yticklabels(range(num_layers), fontsize=tick_font_size)
         axs[i].invert_yaxis()
 
-    plt.tight_layout()
+    # Remove any empty subplots if using 2 rows
+    if num_components > 6:
+        for j in range(i + 1, len(axs)):
+            fig.delaxes(axs[j])
+
+    plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # Adjust layout to prevent overlap
 
     # Convert plot to image
     buf = io.BytesIO()
@@ -76,7 +106,6 @@ def visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name):
     return PIL.Image.open(buf)
 
 def gradio_interface(base_model_name, chat_model_name, hf_token, load_one_at_a_time=False):
-    # Update to use 'token' instead of 'use_auth_token' to handle deprecation warning
     base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.bfloat16, token=hf_token)
     chat_model = AutoModelForCausalLM.from_pretrained(chat_model_name, torch_dtype=torch.bfloat16, token=hf_token)
 
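For context, a minimal standalone sketch of the dynamic figure-sizing logic added in this commit; the helper name figure_layout and the example values are illustrative only and do not appear in app.py:

# Sketch of the sizing rules introduced above (names/values here are hypothetical).
def figure_layout(num_layers, num_components):
    height = max(8, num_layers / 8)       # minimum height of 8, grows with layer count
    width = max(24, num_components * 3)   # minimum width of 24, grows with component count
    if num_components > 6:                # many components -> wrap onto two rows
        nrows, ncols = 2, (num_components + 1) // 2
        return nrows, ncols, (width, height * 1.5)
    return 1, num_components, (width, height)

# Example: an 80-layer model with 7 tracked components
print(figure_layout(80, 7))  # -> (2, 4, (24, 15.0))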