File size: 5,019 Bytes
e977050
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import torch
import os
from PIL import Image
import numpy as np
from ipycanvas import Canvas
import cv2

from visualize_attention_src.utils import get_image

exp_dir = "saved_attention_map_results"

style_name = "line_art"
src_name = "cat"
tgt_name = "dog"

steps = ["20"]
seed = "4"
# "uint8" loads ``<name>_uint8.npy`` files; any other value loads ``<name>.pt``.
saved_dtype = "tensor"


# Load one raw attention-map container per entry of ``steps``.
# Both loaders unpickle data, so only point this at results you generated yourself.
attn_map_raws = []
for step in steps:
    attn_map_name_wo_ext = f"attn_map_raw_{style_name}_src_{src_name}_tgt_{tgt_name}_activate_layer_(0, 0)(108, 140)_attn_map_step_{step}_seed_{seed}"

    if saved_dtype == 'uint8':
        attn_map_name = attn_map_name_wo_ext + '_uint8.npy'
        attn_map_path = os.path.join(exp_dir, attn_map_name)
        attn_map_raw = np.load(attn_map_path, allow_pickle=True)
        # ``np.save`` on a dict stores it as a 0-d object array; unwrap it so
        # that ``.keys()`` below works the same way as for the torch branch.
        if (
            isinstance(attn_map_raw, np.ndarray)
            and attn_map_raw.dtype == object
            and attn_map_raw.ndim == 0
        ):
            attn_map_raw = attn_map_raw.item()
    else:
        attn_map_name = attn_map_name_wo_ext + '.pt'
        attn_map_path = os.path.join(exp_dir, attn_map_name)
        # Load onto the CPU: the viewer only reads small slices and moves them
        # to the CPU anyway, and this also works on machines without the GPU
        # the maps may have been saved from.
        attn_map_raw = torch.load(attn_map_path, map_location="cpu")

    attn_map_raws.append(attn_map_raw)
    print(attn_map_path)
    print(f"{step} is on memory")

if not attn_map_raws:
    raise ValueError("`steps` is empty: no attention maps were loaded")

# Layer keys of the first loaded container; the viewer picks layers by
# position in this list.
keys = list(attn_map_raws[0].keys())


print(len(keys))
key = keys[0]

########################
# Column of the generated image in the saved image grid; on_click also uses
# it to select that image's rows of the attention tensors.
tgt_idx = 3

attn_map_paired_rgb_grid_name = f"{style_name}_src_{src_name}_tgt_{tgt_name}_scale_1.0_activate_layer_(0, 0)(108, 140)_seed_{seed}.png"

attn_map_paired_rgb_grid_path = os.path.join(exp_dir, attn_map_paired_rgb_grid_name)
print(attn_map_paired_rgb_grid_path)
attn_map_paired_rgb_grid = Image.open(attn_map_paired_rgb_grid_path)

# Source image is column 0 of the first grid row, the generated one column tgt_idx.
attn_map_src_img = get_image(attn_map_paired_rgb_grid, row=0, col=0, image_size=1024, grid_width=10)
attn_map_tgt_img = get_image(attn_map_paired_rgb_grid, row=0, col=tgt_idx, image_size=1024, grid_width=10)


# Size in pixels of one canvas panel.
h, w = 256, 256
# NOTE(review): on_click recomputes its own grid size per layer, so this
# module-level value is not read there.
num_of_grid = 64

# NOTE(review): not read anywhere in this file.
plus_50 = 0

# Positions in ``keys`` of the layers to visualize, one canvas row each
# (108 -> 0, 109 -> 1, ..., 140 -> 32).
# If Swapping Attention is applied in the (108, 140) layer range, use
# key_idx_list = [6, 28]: 6 == early up-block, 28 == late up-block.
# Alternative: key_idx_list = [0, 2, 4, 6, 8, 10]
key_idx_list = [6, 28]

# One entry per loaded step; on_click reads attn_map_raws[i] for each index i.
saved_attention_map_idx = [0]

# Force 3-channel RGB: an RGBA/palette crop would otherwise make
# cv2.addWeighted fail when it is blended with the 3-channel heatmap.
# PIL's resize takes (width, height).
source_image = np.array(attn_map_src_img.convert("RGB").resize((w, h)))
target_image = np.array(attn_map_tgt_img.convert("RGB").resize((w, h)))

# Layout per row: col 0 target, col 1 grayscale attention,
# col 2 heatmap blended over the source, col 3 source.
canvas = Canvas(width=4 * w, height=h * len(key_idx_list), sync_image_data=True)
for row in range(len(key_idx_list)):
    canvas.put_image_data(source_image, w * 3, h * row)
    canvas.put_image_data(target_image, 0, h * row)

# Display the canvas
# display(canvas)


def save_to_file(*_change_args, **_change_kwargs):
    """Write the current canvas contents to ``my_file1.png``.

    Used as a trait observer below; whatever change notification the
    observer machinery passes in is ignored.
    """
    output_png = "my_file1.png"
    canvas.to_file(output_png)


# Re-save the PNG every time the canvas ``image_data`` trait changes.
canvas.observe(save_to_file, names="image_data")


def _query_attention_map(attn, x, y, width, height, image_idx, rows_per_image=10):
    """Return the attention of the token under canvas pixel (x, y) as a 2-D array.

    ``attn`` is indexed as ``attn[row, query_token, key_token]``. The
    ``rows_per_image`` rows starting at ``image_idx * rows_per_image`` are
    summed (presumably the attention heads of one image in the batch --
    TODO confirm against the code that saved the maps). Both token axes are
    assumed to be a square ``n x n`` grid in raster order, with
    ``n * n == attn.shape[-1]``.

    Returns an ``(n, n)`` float array, min-max normalized to [0, 1] and then
    inverted (``1 - v``), so the most-attended key token is 0.

    Raises:
        ValueError: if the key-token axis is not a perfect square.
    """
    # Also accepts numpy arrays (e.g. from the uint8 ``.npy`` files).
    attn = torch.as_tensor(attn)
    num_tokens = attn.shape[-1]
    num_of_grid = round(num_tokens ** 0.5)
    if num_of_grid * num_of_grid != num_tokens:
        raise ValueError(f"expected a square token grid, got {num_tokens} tokens")

    # Canvas pixel -> grid cell; min() guards float rounding at the far edges.
    grid_x_idx = min(int(x * num_of_grid / width), num_of_grid - 1)
    grid_y_idx = min(int(y * num_of_grid / height), num_of_grid - 1)
    print(grid_x_idx, grid_y_idx)
    grid_idx = grid_x_idx + grid_y_idx * num_of_grid

    start = image_idx * rows_per_image
    # Upcast before reducing: if the maps are float16, the 1e-8 epsilon below
    # underflows to 0 and a constant map turns into NaNs.
    attn_map = attn[start:start + rows_per_image, grid_idx, :].float().sum(dim=0)
    attn_map = attn_map.reshape(num_of_grid, num_of_grid).detach().cpu().numpy()

    normalized_attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
    return 1.0 - normalized_attn_map


def on_click(x, y):
    """Mouse-down handler: show where the clicked target-image pixel attends.

    ``x``/``y`` are canvas pixel coordinates. Only clicks on the target image
    in the top-left panel (``0 <= x < w`` and ``0 <= y < h``) are handled;
    other clicks are reported and ignored.

    For each row ``key_i`` (layer ``keys[key_idx_list[key_i]]``) it draws the
    attention map in grayscale at column 1 and as a JET heatmap blended over
    the source image at column 2.
    """
    # Outside this panel the pixel -> token mapping selects a wrong query
    # token (x >= w) or an out-of-range one (y >= h).
    if not (0 <= x < w and 0 <= y < h):
        print(f"ignoring click at ({x}, {y}): outside the target image panel")
        return

    # Redraw the target image to erase the previous marker, then mark the click.
    canvas.put_image_data(target_image, 0, 0)

    print(x, y)
    canvas.fill_style = 'red'
    canvas.fill_circle(x, y, 4)

    for step_i in range(len(saved_attention_map_idx)):
        attn_map_raw = attn_map_raws[step_i]
        # NOTE(review): the canvas is only 4 * w wide, so anything drawn for
        # step_i > 0 lands outside it.
        x_offset = step_i * 4 * w

        for key_i, key_idx in enumerate(key_idx_list):
            normalized_attn_map = _query_attention_map(
                attn_map_raw[keys[key_idx]], x, y, w, h, tgt_idx
            )

            # cv2.applyColorMap returns BGR while the canvas expects RGB. The
            # R/B swap combined with the ``1 - v`` inversion approximately
            # reproduces a regular JET map (strongly attended tokens appear
            # red). The grayscale panel keeps the inversion, so strongly
            # attended tokens are dark there.
            heatmap = cv2.applyColorMap(np.uint8(255 * normalized_attn_map), cv2.COLORMAP_JET)
            heatmap = cv2.resize(heatmap, (w, h))

            gray_map = (normalized_attn_map * 255).astype(np.uint8)
            gray_map = cv2.cvtColor(gray_map, cv2.COLOR_GRAY2RGB)
            gray_map = cv2.resize(gray_map, (w, h))
            canvas.put_image_data(gray_map, w + x_offset, h * key_i)

            # Blend the heatmap over the *source* image (the key tokens
            # presumably index the source image -- TODO confirm).
            alpha = 0.85
            blended_image = cv2.addWeighted(source_image, 1 - alpha, heatmap, alpha, 0)
            canvas.put_image_data(blended_image, w * 2 + x_offset, h * key_i)


# Attach the event handler to the canvas.
canvas.on_mouse_down(on_click)