import numpy as np
from PIL import Image, ImageDraw, ImageFont
import cv2
from sklearn.decomposition import PCA
from torchvision import transforms
import matplotlib.pyplot as plt
import torch
import os
def display_attention_maps(
# Purpose (from the visible statements): split a batch of attention maps into
# per-sample maps, build a directory per sample, and render each map with
# show_cross_attention / show_self_attention, plus optional per-point views
# via show_individual_self_attention.
#
# NOTE(review): the parameter list and the closing "):" are missing from this
# view, and all indentation has been lost. The names read below are
# attention_maps, num_heads, is_cross, is_query, is_key, dir_name, resolution,
# layer, tokenizer, prompts, points, step and image_path; their order and
# defaults cannot be determined from here - confirm against the callers.
#
# Regroup the leading dimensions so that axis 1 holds `num_heads` entries; the
# last two axes are kept as they are.
attention_maps = attention_maps.reshape(-1, num_heads, attention_maps.size(-2), attention_maps.size(-1))
# Half of the regrouped batch: indices below this are labelled 'uncond', the
# rest 'cond' (see `cond` below).
num_samples = len(attention_maps) // 2
attention_type = 'cross' if is_cross else 'self'
for i, attention_map in enumerate(attention_maps):
# NOTE(review): if this suffixing really sits inside the loop (it follows the
# `for` header here), the suffix is appended again on every iteration
# ('self_queries', 'self_queries_queries', ...) - confirm and hoist if so.
if is_query:
attention_type = f'{attention_type}_queries'
elif is_key:
attention_type = f'{attention_type}_keys'
cond = 'uncond' if i < num_samples else 'cond'
# Fold the index back into the range of one half so both halves share names.
i = i % num_samples
cur_dir_name = f'{dir_name}/{resolution}/{attention_type}/{layer}/{cond}/{i}'
os.makedirs(cur_dir_name, exist_ok=True)
# NOTE(review): `fig` is assigned twice in a row below; presumably an `else:`
# belongs before the second assignment - as written the cross-attention result
# is always overwritten.
if is_cross and not is_query:
fig = show_cross_attention(attention_map, tokenizer, prompts[i % num_samples])
fig = show_self_attention(attention_map)
# NOTE(review): `fig` is not used after this point in the visible lines (no
# save call is present) - confirm whether a line was lost.
if points is not None:
point_dir_name = f'{cur_dir_name}/points'
os.makedirs(point_dir_name, exist_ok=True)
for j, point in enumerate(points):
# One sub-directory per point index; the file name is derived from `step`.
specific_point_dir_name = f'{point_dir_name}/{j}'
os.makedirs(specific_point_dir_name, exist_ok=True)
point_path = f'{specific_point_dir_name}/{step}.png'
# NOTE(review): `point_fig` and `point_path` are computed but never used in
# the visible lines - a save call is presumably missing.
point_fig = show_individual_self_attention(attention_map, point, image_path=image_path)
def text_under_image(image: np.ndarray, text: str, text_color: tuple[int, int, int] = (0, 0, 0)):
    """Return ``image`` stacked on top of a white strip holding centred text.

    Args:
        image: Array of shape ``(h, w, c)``; copied into the top of the result.
        text: Caption to render. Every entry of ``text.splitlines()`` is drawn
            on its own row and centred horizontally.
        text_color: Colour tuple forwarded to ``cv2.putText``.

    Returns:
        A new ``uint8`` array with the same width and channel count as
        ``image`` and extra white rows below it for the caption.
    """
    h, w, c = image.shape
    offset = int(h * .2)
    # The drawing calls below need a font; it was referenced without ever
    # being bound, which raised NameError on the first call.
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale, thickness = 1, 2

    lines = text.splitlines()
    # Row height used to size the canvas is measured on the whole caption.
    text_size = cv2.getTextSize(text, font, font_scale, thickness)[0]
    canvas_height = h + offset + (text_size[1] + 2) * len(lines) - 2
    img = np.ones((canvas_height, w, c), dtype=np.uint8) * 255
    img[:h, :w] = image

    for i, line in enumerate(lines):
        # Re-measure each line so it is centred on its own width.
        text_size = cv2.getTextSize(line, font, font_scale, thickness)[0]
        text_x = (w - text_size[0]) // 2
        text_y = h + offset + i * (text_size[1] + 2)
        cv2.putText(img, line, (text_x, text_y), font, font_scale, text_color, thickness)
    return img
def view_images(images, num_rows=1, offset_ratio=0.02):
    """Tile images into a grid with ``num_rows`` rows and return a PIL image.

    Args:
        images: A list of ``(h, w, c)`` arrays, a 4-D array whose first axis
            enumerates images, or a single ``(h, w, c)`` array.
        num_rows: Number of grid rows. When the image count is not a multiple
            of it, blank white tiles are appended to complete the last row.
        offset_ratio: Gap between tiles, as a fraction of the tile height.

    Returns:
        ``PIL.Image.Image`` built from a white 3-channel ``uint8`` canvas with
        the tiles laid out row by row.
    """
    if isinstance(images, list) or images.ndim == 4:
        tiles = [image.astype(np.uint8) for image in images]
    else:
        # A single image: treat it as a one-element collection.
        tiles = [images.astype(np.uint8)]

    # Pad up to the next multiple of num_rows. The former `len % num_rows`
    # could under-pad (e.g. 7 images in 3 rows) and silently drop real images.
    num_empty = -len(tiles) % num_rows
    blank = np.ones(tiles[0].shape, dtype=np.uint8) * 255
    tiles += [blank] * num_empty

    h, w, c = tiles[0].shape
    offset = int(h * offset_ratio)
    num_cols = len(tiles) // num_rows
    canvas = np.ones(
        (h * num_rows + offset * (num_rows - 1),
         w * num_cols + offset * (num_cols - 1),
         3),
        dtype=np.uint8,
    ) * 255
    for i in range(num_rows):
        top = i * (h + offset)
        for j in range(num_cols):
            left = j * (w + offset)
            canvas[top: top + h, left: left + w] = tiles[i * num_cols + j]
    return Image.fromarray(canvas)
def show_cross_attention(attention_maps, tokenizer, prompt, k_norms=None, v_norms=None):
    """Render one captioned heat-map tile per prompt token and tile them.

    Args:
        attention_maps: Tensor that is averaged over dim 0 and whose
            second-to-last axis is reshaped into a square ``res x res`` grid;
            the last axis is indexed per token.
        tokenizer: Object exposing ``encode(prompt)`` and ``decode(token_id)``.
        prompt: Text whose tokens label the tiles.
        k_norms: Optional per-token values, rounded to one decimal and added
            to the caption (only when ``v_norms`` is also given).
        v_norms: Optional per-token values, handled like ``k_norms``.

    Returns:
        The PIL image produced by ``view_images`` for the stacked tiles.
    """
    attention_maps = attention_maps.mean(dim=0)
    res = int(attention_maps.size(-2) ** 0.5)
    attention_maps = attention_maps.reshape(res, res, -1)
    tokens = tokenizer.encode(prompt)
    decoder = tokenizer.decode
    if k_norms is not None:
        k_norms = k_norms.round(decimals=1)
    if v_norms is not None:
        v_norms = v_norms.round(decimals=1)

    # Show five slots past the prompt's tokens, but never index beyond the
    # token axis that the attention map actually has.
    num_shown = min(len(tokens) + 5, attention_maps.size(-1))
    images = []
    for i in range(num_shown):
        image = attention_maps[:, :, i]
        peak = image.max()
        # An all-zero map would otherwise become NaN before the uint8 cast.
        image = 255 * image / peak if peak > 0 else torch.zeros_like(image)
        image = image.unsqueeze(-1).expand(*image.shape, 3)
        image = image.detach().cpu().numpy().astype(np.uint8)
        image = np.array(Image.fromarray(image).resize((256, 256)))
        # Slots past the prompt reuse the last token's text as their caption.
        token = tokens[i] if i < len(tokens) else tokens[-1]
        text = decoder(int(token))
        if k_norms is not None and v_norms is not None:
            text += f'\n{k_norms[i]}\n{v_norms[i]}'
        # Collect the captioned tile; without this the stack below is empty.
        images.append(text_under_image(image, text))
    return view_images(np.stack(images, axis=0))
def show_queries_keys(queries, keys, colors, labels):  # [h ni d]
    """Scatter-plot a joint 2-D PCA projection of grouped queries and keys.

    Args:
        queries: Sequence of tensors, one per label. The first one is unpacked
            as ``(h, n_i, d)`` and all are concatenated along dim 1.
        keys: Sequence of tensors concatenated after the queries along dim 1.
            May be empty, in which case only queries are drawn.
        colors: One Matplotlib colour per label.
        labels: One name per group; used in the printed counts and as the
            artists' legend labels.

    Returns:
        The Matplotlib figure. Legend labels are attached to the scatter
        artists, but no legend is drawn here.
    """
    num_queries = [query.size(1) for query in queries]
    num_keys = [key.size(1) for key in keys]
    h, _, d = queries[0].shape

    data = torch.cat((*queries, *keys), dim=1)  # h n d
    data = data.permute(1, 0, 2)  # n h d
    data = data.reshape(-1, h * d).detach().cpu().numpy()
    data = PCA(n_components=2).fit_transform(data)  # n 2

    # Cut the projected rows back into the per-label groups they came from.
    query_indices = np.array(num_queries).cumsum()
    total_num_queries = query_indices[-1]
    queries = np.split(data[:total_num_queries], query_indices[:-1])
    if not num_keys:
        keys = [None] * len(labels)
    else:
        key_indices = np.array(num_keys).cumsum()
        keys = np.split(data[total_num_queries:], key_indices[:-1])

    fig, ax = plt.subplots()
    marker_size = plt.rcParams['lines.markersize'] ** 2
    query_size = int(1.25 * marker_size)
    key_size = int(2 * marker_size)
    for query, key, color, label in zip(queries, keys, colors, labels):
        print(f'# queries of {label}', query.shape[0])
        ax.scatter(query[:, 0], query[:, 1], s=query_size, color=color, marker='o', label=f'"{label}" queries')
        if key is None:
            # No keys were supplied for this group; nothing more to draw.
            continue
        print(f'# keys of {label}', key.shape[0])
        keys_label = f'"{label}" key'
        if key.shape[0] > 1:
            keys_label += 's'
        ax.scatter(key[:, 0], key[:, 1], s=key_size, color=color, marker='x', label=keys_label)
    return fig
def show_self_attention(attention_maps):  # h n m
    """Visualise self-attention as an RGB image of its first three PCA components.

    Args:
        attention_maps: Tensor indexed as ``(h, n, m)``. Dims 0 and 1 are
            swapped and the trailing dims flattened, giving one feature row
            per position ``n``. ``n`` must be a perfect square.

    Returns:
        A PIL image resized with nearest-neighbour interpolation (size
        argument 256), where each channel is one PCA component min-max scaled
        to 0..255.

    Raises:
        ValueError: If the number of positions is not a perfect square.
    """
    features = attention_maps.transpose(0, 1).flatten(start_dim=1).detach().cpu().numpy()
    num_positions = features.shape[0]
    side = int(num_positions ** 0.5)
    if side * side != num_positions:
        raise ValueError(f'expected a square number of positions, got {num_positions}')

    pca_img = PCA(n_components=3).fit_transform(features)  # N X 3
    pca_img = pca_img.reshape(side, side, 3)

    # Min-max scale every component independently.
    pca_img_min = pca_img.min(axis=(0, 1))
    span = pca_img.max(axis=(0, 1)) - pca_img_min
    # A constant component would divide 0 by 0 and turn into NaN before the
    # uint8 cast; map it to 0 instead.
    span[span == 0] = 1
    pca_img = (pca_img - pca_img_min) / span

    pca_img = Image.fromarray((pca_img * 255).astype(np.uint8))
    return transforms.Resize(256, interpolation=transforms.InterpolationMode.NEAREST)(pca_img)
def draw_box(pil_img, bboxes, colors=None, width=5):
    """Outline boxes on ``pil_img`` in place and return the same image.

    Args:
        pil_img: PIL image to draw on; it is modified directly.
        bboxes: Boxes whose first four entries are ``x0, y0, x1, y1`` given as
            fractions of the image width / height.
        colors: One outline colour per box; every box is red when omitted.
        width: Outline thickness passed to ``ImageDraw.rectangle``.

    Returns:
        ``pil_img`` itself, with the rectangles drawn.
    """
    canvas = ImageDraw.Draw(pil_img)
    img_w, img_h = pil_img.size
    if colors is None:
        colors = ['red'] * len(bboxes)

    for box, outline in zip(bboxes, colors):
        left, top, right, bottom = box[0], box[1], box[2], box[3]
        # Scale the fractional corners to pixel coordinates.
        corners = [int(left * img_w), int(top * img_h), int(right * img_w), int(bottom * img_h)]
        canvas.rectangle(corners, outline=outline, width=width)
    return pil_img
def show_individual_self_attention(attn, point, image_path=None):
    """Render the self-attention of a single query location as an image.

    Args:
        attn: Tensor that is averaged over dim 0 and then reshaped to
            ``(res, res, res, res)``, with ``res`` the integer square root of
            the last dimension.
        point: ``(x, y)`` pair scaled by ``res`` to pick the query cell;
            ``point[1]`` selects the first grid index and ``point[0]`` the
            second.
        image_path: Optional path of an image to blend the map onto. It is
            opened and converted to RGB.

    Returns:
        A PIL image built from the array returned by ``show_image_relevance``.
    """
    resolution = int(attn.size(-1) ** 0.5)
    attn = attn.mean(dim=0).reshape(resolution, resolution, resolution, resolution)

    # round() can land on `resolution` itself (e.g. a coordinate of 1.0),
    # which is one past the last valid index, so clamp into the grid.
    row = min(max(round(point[1] * resolution), 0), resolution - 1)
    col = min(max(round(point[0] * resolution), 0), resolution - 1)
    attn = attn[row, col]

    low = attn.min()
    span = attn.max() - low
    # A constant map has zero range; avoid dividing 0 by 0.
    attn = (attn - low) / span if span > 0 else torch.zeros_like(attn)

    if image_path is None:
        image = None
    else:
        # convert() returns an independent image, so the file can be closed.
        with Image.open(image_path) as source:
            image = source.convert('RGB')

    image = show_image_relevance(attn, image=image)
    return Image.fromarray(image)
def show_image_relevance(image_relevance, image: Image.Image = None, relevnace_res=16):
    """Upsample a square relevance map and optionally blend it onto an image.

    Args:
        image_relevance: Tensor reshaped to ``(1, 1, s, s)`` with ``s`` its
            last dimension, i.e. it must hold ``s * s`` values.
        image: Optional PIL image. When given, it is resized to the output
            size and a JET heat-map of the relevance is added on top of it.
        relevnace_res: The output side length is ``relevnace_res ** 2``.

    Returns:
        ``uint8`` array of shape ``(out, out, 3)``, passed through an
        OpenCV colour conversion before being returned.
    """
    out_size = relevnace_res ** 2

    # create heatmap from mask on image
    def show_cam_on_image(img, mask):
        img = np.array(img.resize((out_size, out_size)))
        low = img.min()
        span = img.max() - low
        # A uniformly coloured image has zero range; avoid dividing 0 by 0.
        img = (img - low) / span if span > 0 else np.zeros(img.shape, dtype=np.float32)
        heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
        heatmap = np.float32(heatmap) / 255
        # NOTE(review): applyColorMap output is BGR while `img` comes from a
        # PIL image; the RGB2BGR swap at the end therefore also flips the
        # photo's channels - confirm this is the intended colour handling.
        cam = heatmap + np.float32(img)
        return cam / np.max(cam)

    side = image_relevance.shape[-1]
    # Work in float32 on whatever device the tensor already lives on. The
    # previous unconditional .cuda() (a workaround for half-precision
    # interpolation) crashed on machines without a GPU.
    relevance = image_relevance.detach().float().reshape(1, 1, side, side)
    relevance = torch.nn.functional.interpolate(relevance, size=out_size, mode='bilinear')
    relevance = relevance.cpu()

    low = relevance.min()
    span = relevance.max() - low
    # A constant map has zero range; avoid dividing 0 by 0.
    relevance = (relevance - low) / span if span > 0 else torch.zeros_like(relevance)
    relevance = relevance.reshape(out_size, out_size).numpy()

    vis = relevance if image is None else show_cam_on_image(image, relevance)
    vis = np.uint8(255 * vis)
    # Without a background image the map is single-channel, which RGB2BGR
    # rejects; expand it to three channels instead.
    conversion = cv2.COLOR_GRAY2BGR if vis.ndim == 2 else cv2.COLOR_RGB2BGR
    return cv2.cvtColor(vis, conversion)