import enum
import os

import numpy as np
import torch
import torchvision

import matplotlib

# Select a non-interactive backend (before importing pyplot) to avoid plt.imshow crashing in headless runs.
matplotlib.use("Agg")
import matplotlib.pyplot as plt

from ..Misc import Logger as log
from ..Setting import Config


class CAttnProcChoice(enum.Enum):
    INVALID = -1
    BASIC = 0


def plot_activations(cross_attn, prompt, plot_with_trailings=False):
    """Plot the cross-attention map of every frame and save them under /tmp."""
    num_frames = cross_attn.shape[0]
    cross_attn = cross_attn.cpu()
    for i in range(num_frames):
        filename = "/tmp/out.{:04d}.jpg".format(i)
        plot_activation(cross_attn[i], prompt, filename, plot_with_trailings)


def plot_activation(cross_attn, prompt, filepath="", plot_with_trailings=False):
    """Plot the cross-attention map of each prompt token side by side."""
    splitted_prompt = prompt.split(" ")
    n = len(splitted_prompt)
    start = 0
    if plot_with_trailings:
        # Also include trailing tokens: five consecutive blocks of n token slices.
        arrs = []
        for _ in range(5):
            arr = []
            for i in range(start, start + n):
                # Offset by 1 to skip the start-of-text token.
                cross_attn_sliced = cross_attn[..., i + 1]
                arr.append(cross_attn_sliced.T)
            start += n
            arrs.append(np.hstack(arr))
        arrs = np.vstack(arrs).T.astype(np.float32)
    else:
        arr = []
        for i in range(start, start + n):
            # Offset by 1 to skip the start-of-text token.
            cross_attn_sliced = cross_attn[..., i + 1]
            arr.append(cross_attn_sliced)
        arrs = np.hstack(arr).astype(np.float32)

    plt.clf()
    # Normalize the activations to [0, 1] before plotting.
    v_min = arrs.min()
    v_max = arrs.max()
    n_min = 0.0
    n_max = 1.0
    arrs = (arrs - v_min) / (v_max - v_min)
    arrs = (arrs * (n_max - n_min)) + n_min
    plt.imshow(arrs, cmap="jet")
    plt.title(prompt)
    plt.colorbar(orientation="horizontal", pad=0.2)
    if filepath:
        plt.savefig(filepath)
        log.info(f"Saved [{filepath}]")
    else:
        plt.show()


def get_cross_attn(
    unet,
    resolution=32,
    target_size=64,
):
    """To get the cross-attention map softmax(QK^T) from the UNet.

    Args:
        unet (UNet2DConditionModel): unet
        resolution (int): the resolution of the cross-attention map.
            It only supports 64, 32, 16, and 8.
        target_size (int): the target resolution for resizing the
            cross-attention map (the resizing step is currently disabled).

    Returns:
        (torch.Tensor): the summed cross-attention map with 77 token
            channels in the last dimension.
    """
    check = [8, 16, 32, 64]
    if resolution not in check:
        raise ValueError(
            "The cross attention resolution only supports 8x8, 16x16, 32x32, and 64x64. "
            "The given resolution {}x{} is not in the list. Abort.".format(
                resolution, resolution
            )
        )

    attns = []
    for _, module in unet.named_modules():
        # NOTE: attn2 is for cross-attention while attn1 is self-attention.
        if not hasattr(module, "processor"):
            continue
        if hasattr(module.processor, "cross_attention_map"):
            attn = module.processor.cross_attention_map[None, ...]
            attns.append(attn)

    if not attns:
        log.warn("Queried attns size [{}]".format(len(attns)))
        return

    attns = torch.cat(attns, dim=0)
    attns = torch.sum(attns, dim=0)

    # resized = torch.zeros([target_size, target_size, 77])
    # f = torchvision.transforms.Resize(size=(64, 64))
    # dim = attns.shape[1]
    # for i in range(77):
    #     attn_slice = attns[..., i].view(1, dim, dim)
    #     resized[..., i] = f(attn_slice)[0]
    return attns
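

# The helpers in this module duck-type against a custom cross-attention
# processor, presumably registered elsewhere in this repo. A minimal sketch of
# the attribute interface they assume; the attribute names are the ones used in
# this module, but the class itself is only illustrative:
#
#   class BasicCAttnProcessor:
#       def __init__(self):
#           self.use_dd = False              # toggled by use_dd()
#           self.use_dd_temporal = False     # toggled by use_dd_temporal()
#           self.loss = 0.0                  # read by get_loss()
#           self.parameters = []             # read by get_params()
#           self.cross_attention_map = None  # softmax(QK^T), read by get_cross_attn()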


def get_avg_cross_attn(unet, resolutions, resize):
    """To get the average cross-attention map across the given resolutions.

    Args:
        unet (UNet2DConditionModel): unet
        resolutions (list): a list of resolutions to query.
            It only supports 64, 32, 16, and 8.
        resize (int): the target resolution for resizing the cross-attention map.

    Returns:
        (torch.Tensor): the cross-attention maps averaged over the resolutions.
    """
    cross_attns = []
    for resolution in resolutions:
        try:
            cross_attns.append(get_cross_attn(unet, resolution, resize))
        except Exception:
            log.warn(f"No cross-attention map with resolution [{resolution}]")

    if cross_attns:
        cross_attns = torch.stack(cross_attns).mean(0)
    return cross_attns


def save_cross_attn(unet):
    """TODO: to save cross attn"""
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == "CrossAttention" and "attn2" in name:
            folder = "/tmp"
            filepath = os.path.join(folder, name + ".pt")
            torch.save(module.attn, filepath)
            print(filepath)


def use_dd(unet, use=True):
    """Set the `use_dd` flag on every cross-attention processor."""
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == "CrossAttention" and "attn2" in name:
            module.processor.use_dd = use


def use_dd_temporal(unet, use=True):
    """Set the `use_dd_temporal` flag on every cross-attention processor."""
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == "CrossAttention" and "attn2" in name:
            module.processor.use_dd_temporal = use


def get_loss(unet):
    """Average the `loss` recorded by the cross-attention processors."""
    loss = 0
    total = 0
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == "CrossAttention" and "attn2" in name:
            loss += module.processor.loss
            total += 1
    return loss / total


def get_params(unet):
    """Collect the `parameters` stored on the cross-attention processors."""
    parameters = []
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == "CrossAttention" and "attn2" in name:
            parameters.append(module.processor.parameters)
    return parameters
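

# Illustrative sketch (not called by the library): one possible way to combine
# the helpers above. It assumes `pipe` is a diffusers pipeline whose UNet
# cross-attention processors expose the attributes sketched earlier; the
# prompt, resolutions, and output path are arbitrary examples.
def _example_usage(pipe, prompt="a cat surfing a wave"):
    use_dd(pipe.unet, use=True)
    _ = pipe(prompt)  # one generation pass populates the cross-attention maps
    attn = get_avg_cross_attn(pipe.unet, resolutions=[16, 32], resize=64)
    use_dd(pipe.unet, use=False)
    if torch.is_tensor(attn):
        plot_activation(attn.float().cpu(), prompt, filepath="/tmp/attn_example.jpg")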