import matplotlib.pyplot as plt
import math

import utils
from . import parse

# Module-level counter used by `display` to give saved images unique filenames.
save_ind = 0

def visualize(image, title, colorbar=False, show_plot=True, **kwargs):
    plt.title(title)
    plt.imshow(image, **kwargs)
    if colorbar:
        plt.colorbar()
    if show_plot:
        plt.show()
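
# Example (a minimal usage sketch, not part of the original code): any 2D array or
# HxWx3 image accepted by `plt.imshow` works; extra keyword arguments are passed through.
#   import numpy as np
#   visualize(np.random.rand(64, 64), "random noise", colorbar=True, cmap="viridis")
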
def visualize_arrays(image_title_pairs, colorbar_index=-1, show_plot=True, figsize=None, **kwargs):
    if figsize is not None:
        plt.figure(figsize=figsize)

    num_subplots = len(image_title_pairs)
    for idx, image_title_pair in enumerate(image_title_pairs):
        plt.subplot(1, num_subplots, idx + 1)

        if isinstance(image_title_pair, (list, tuple)):
            image, title = image_title_pair
        else:
            image, title = image_title_pair, None

        if title is not None:
            plt.title(title)
        plt.imshow(image, **kwargs)

        if idx == colorbar_index:
            plt.colorbar()

    if show_plot:
        plt.show()
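
# Example (sketch; `images` is a hypothetical list of arrays): titled and untitled
# entries can be mixed, and only the subplot at `colorbar_index` gets a colorbar.
#   visualize_arrays([(images[0], "input"), (images[1], "output"), images[2]],
#                    colorbar_index=1, figsize=(12, 4))
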
def visualize_masked_latents(latents_all, masked_latents, timestep_T=False, timestep_0=True):
    if timestep_T:
        # from T to 0
        latent_idx = 0
        plt.subplot(1, 2, 1)
        plt.title("latents_all (t=T)")
        plt.imshow((latents_all[latent_idx, 0, :3].cpu().permute(1, 2, 0).numpy().astype(float) / 1.5).clip(0., 1.), cmap="gray")

        plt.subplot(1, 2, 2)
        plt.title("mask latents (t=T)")
        plt.imshow((masked_latents[latent_idx, 0, :3].cpu().permute(1, 2, 0).numpy().astype(float) / 1.5).clip(0., 1.), cmap="gray")

        plt.show()

    if timestep_0:
        latent_idx = -1
        plt.subplot(1, 2, 1)
        plt.title("latents_all (t=0)")
        plt.imshow((latents_all[latent_idx, 0, :3].cpu().permute(1, 2, 0).numpy().astype(float) / 1.5).clip(0., 1.), cmap="gray")

        plt.subplot(1, 2, 2)
        plt.title("mask latents (t=0)")
        plt.imshow((masked_latents[latent_idx, 0, :3].cpu().permute(1, 2, 0).numpy().astype(float) / 1.5).clip(0., 1.), cmap="gray")

        plt.show()
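
# Example (sketch): from the indexing above, both tensors are assumed to be shaped
# roughly (num_saved_steps, batch, C, H, W), with index 0 at t=T and index -1 at t=0.
#   visualize_masked_latents(latents_all, masked_latents, timestep_T=True, timestep_0=True)
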
# This function has not been adapted to new `saved_attn`.
def visualize_attn(token_map, cross_attention_probs_tensors, stage_id, block_id, visualize_step_start=10, input_ca_has_condition_only=False):
    """
    Visualize cross attention for each token in `token_map`: take the `stage_id`th block of the chosen stage, average over all timesteps from `visualize_step_start` on, select the `block_id`th Transformer block, take the second item (conditioned) unless `input_ca_has_condition_only`, and average over heads.

    cross_attention_probs_tensors:
        One of `cross_attention_probs_down_tensors`, `cross_attention_probs_mid_tensors`, and `cross_attention_probs_up_tensors`
    stage_id: index of the downsampling/mid/upsampling block
    block_id: index of the Transformer block
    """
    plt.figure(figsize=(20, 8))

    for token_id in range(len(token_map)):
        token = token_map[token_id]
        plt.subplot(1, len(token_map), token_id + 1)
        plt.title(token)

        attn = cross_attention_probs_tensors[stage_id][visualize_step_start:].mean(dim=0)[block_id]
        if not input_ca_has_condition_only:
            assert attn.shape[0] == 2, f"Expect to have 2 items (uncond and cond), but found {attn.shape[0]} items"
            attn = attn[1]
        else:
            assert attn.shape[0] == 1, f"Expect to have 1 item (cond only), but found {attn.shape[0]} items"
            attn = attn[0]
        # Mean over heads, then keep the attention column for the current token.
        attn = attn.mean(dim=0)[:, token_id]
        H = W = int(math.sqrt(attn.shape[0]))
        attn = attn.reshape((H, W))
        plt.imshow(attn.cpu().numpy())

    plt.show()
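
# Example (sketch; `token_map` and the saved attention tensors come from the legacy
# attention-saving path noted above, not the new `saved_attn`):
#   visualize_attn(token_map, cross_attention_probs_down_tensors,
#                  stage_id=2, block_id=0, visualize_step_start=10)
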
# This function has not been adapted to new `saved_attn`.
def visualize_across_timesteps(token_id, cross_attention_probs_tensors, stage_id, block_id, visualize_step_start=10, input_ca_has_condition_only=False):
    """
    Visualize cross attention for one token across timesteps: take the `stage_id`th block of the chosen stage, select the `block_id`th Transformer block, take the second item (conditioned) unless `input_ca_has_condition_only`, and average over heads, with one subplot per timestep.

    cross_attention_probs_tensors:
        One of `cross_attention_probs_down_tensors`, `cross_attention_probs_mid_tensors`, and `cross_attention_probs_up_tensors`
    stage_id: index of the downsampling/mid/upsampling block
    block_id: index of the Transformer block

    `visualize_step_start` is not used. We visualize all timesteps.
    """
    plt.figure(figsize=(50, 8))

    attn_stage = cross_attention_probs_tensors[stage_id]
    num_inference_steps = attn_stage.shape[0]

    for t in range(num_inference_steps):
        plt.subplot(1, num_inference_steps, t + 1)
        plt.title(f"t: {t}")

        attn = attn_stage[t][block_id]
        if not input_ca_has_condition_only:
            assert attn.shape[0] == 2, f"Expect to have 2 items (uncond and cond), but found {attn.shape[0]} items"
            attn = attn[1]
        else:
            assert attn.shape[0] == 1, f"Expect to have 1 item (cond only), but found {attn.shape[0]} items"
            attn = attn[0]
        # Mean over heads, then keep the attention column for the given token.
        attn = attn.mean(dim=0)[:, token_id]
        H = W = int(math.sqrt(attn.shape[0]))
        attn = attn.reshape((H, W))
        plt.imshow(attn.cpu().numpy())
        plt.axis("off")

    plt.tight_layout()
    plt.show()
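
# Example (sketch; same legacy inputs as `visualize_attn`): show how attention for
# token index 1 evolves over the denoising steps of one upsampling stage.
#   visualize_across_timesteps(1, cross_attention_probs_up_tensors, stage_id=0, block_id=0)
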
def visualize_bboxes(bboxes, H, W):
    num_boxes = len(bboxes)
    for ind, bbox in enumerate(bboxes):
        plt.subplot(1, num_boxes, ind + 1)
        fg_mask = utils.proportion_to_mask(bbox, H, W)
        plt.title(f"transformed bbox ({ind})")
        plt.imshow(fg_mask.cpu().numpy())
    plt.show()
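
# Example (sketch): `utils.proportion_to_mask` is expected to rasterize a box given as
# proportions of the canvas into an HxW mask; the (x_min, y_min, x_max, y_max) fraction
# format below is an assumption for illustration, not verified here.
#   visualize_bboxes([[0.1, 0.1, 0.5, 0.5], [0.4, 0.2, 0.9, 0.8]], H=64, W=64)
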
def display(image, save_prefix="", ind=None):
    global save_ind
    if save_prefix != "":
        save_prefix = save_prefix + "_"
    ind = f"{ind}_" if ind is not None else ""
    path = f"{parse.img_dir}/{save_prefix}{ind}{save_ind}.png"

    print(f"Saved to {path}")
    image.save(path)
    save_ind = save_ind + 1
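
# Example (sketch; assumes `parse.img_dir` points to an existing directory and that
# `image` is a PIL image or any object with a `.save(path)` method):
#   from PIL import Image
#   display(Image.new("RGB", (512, 512)), save_prefix="demo", ind=0)
# Each call increments the module-level `save_ind`, so repeated calls do not overwrite files.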