Spaces:
Sleeping
Sleeping
Update
Browse files
app.py
CHANGED
@@ -121,7 +121,7 @@ def visualize_attention(
|
|
121 |
attentions_for_rollout = []
|
122 |
for layer_name, attn_map in attention_maps.items():
|
123 |
print(f"Attention map shape for {layer_name}: {attn_map.shape}")
|
124 |
-
attn_map = attn_map[0] # Remove batch dimension
|
125 |
|
126 |
attentions_for_rollout.append(attn_map)
|
127 |
|
@@ -148,7 +148,7 @@ def visualize_attention(
|
|
148 |
# Interpolate to match image size
|
149 |
attn_map = attn_map.unsqueeze(0).unsqueeze(0)
|
150 |
attn_map = F.interpolate(attn_map, size=(image_np.shape[0], image_np.shape[1]), mode='bilinear', align_corners=False)
|
151 |
-
attn_map = attn_map.squeeze().cpu().numpy() # Move to CPU
|
152 |
|
153 |
# Normalize attention map
|
154 |
attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min())
|
@@ -177,8 +177,8 @@ def visualize_attention(
|
|
177 |
visualizations.append(vis_image)
|
178 |
plt.close(fig)
|
179 |
|
180 |
-
# Ensure tensors are on CPU before converting to numpy
|
181 |
-
attentions_for_rollout = [attn.cpu() for attn in attentions_for_rollout]
|
182 |
|
183 |
# Calculate rollout
|
184 |
rollout_mask = rollout(attentions_for_rollout, discard_ratio, head_fusion, num_prefix_tokens)
|
|
|
121 |
attentions_for_rollout = []
|
122 |
for layer_name, attn_map in attention_maps.items():
|
123 |
print(f"Attention map shape for {layer_name}: {attn_map.shape}")
|
124 |
+
attn_map = attn_map[0].detach() # Remove batch dimension and detach
|
125 |
|
126 |
attentions_for_rollout.append(attn_map)
|
127 |
|
|
|
148 |
# Interpolate to match image size
|
149 |
attn_map = attn_map.unsqueeze(0).unsqueeze(0)
|
150 |
attn_map = F.interpolate(attn_map, size=(image_np.shape[0], image_np.shape[1]), mode='bilinear', align_corners=False)
|
151 |
+
attn_map = attn_map.squeeze().cpu().detach().numpy() # Move to CPU, detach, and convert to numpy
|
152 |
|
153 |
# Normalize attention map
|
154 |
attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min())
|
|
|
177 |
visualizations.append(vis_image)
|
178 |
plt.close(fig)
|
179 |
|
180 |
+
# Ensure tensors are on CPU and detached before converting to numpy
|
181 |
+
attentions_for_rollout = [attn.cpu().detach() for attn in attentions_for_rollout]
|
182 |
|
183 |
# Calculate rollout
|
184 |
rollout_mask = rollout(attentions_for_rollout, discard_ratio, head_fusion, num_prefix_tokens)
|