init
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .gitattributes +1 -0
- .gitignore +3 -0
- app.py +397 -0
- color_palette/._dataset_processing.py +0 -0
- color_palette/__pycache__/cnn_dataset.cpython-39.pyc +0 -0
- color_palette/__pycache__/config.cpython-38.pyc +0 -0
- color_palette/__pycache__/config.cpython-39.pyc +0 -0
- color_palette/__pycache__/dataset.cpython-38.pyc +0 -0
- color_palette/__pycache__/dataset.cpython-39.pyc +0 -0
- color_palette/__pycache__/dataset_processing.cpython-39.pyc +0 -0
- color_palette/__pycache__/train_CNN.cpython-39.pyc +0 -0
- color_palette/__pycache__/utils.cpython-38.pyc +0 -0
- color_palette/__pycache__/utils.cpython-39.pyc +0 -0
- color_palette/all_one_hot_LR/test_gt.npy +3 -0
- color_palette/all_one_hot_LR/test_preds.npy +3 -0
- color_palette/all_one_hot_LR/test_preds_graph.npy +3 -0
- color_palette/all_one_hot_LR/test_rgb_colors.npy +3 -0
- color_palette/all_one_hot_LR_sequential/new_palettes.npy +3 -0
- color_palette/all_one_hot_LR_sequential/original_palettes.npy +3 -0
- color_palette/all_one_hot_LR_sequential/test_gt.npy +3 -0
- color_palette/all_one_hot_LR_sequential/test_preds.npy +3 -0
- color_palette/app copy.py +326 -0
- color_palette/bash_scripts/training.sh +24 -0
- color_palette/cnn_dataset.py +104 -0
- color_palette/colorCNN.py +238 -0
- color_palette/config.py +29 -0
- color_palette/config/conf.yaml +15 -0
- color_palette/config/confCNN.yaml +16 -0
- color_palette/config/grid_search_conf_generator.py +24 -0
- color_palette/cube_num_one_hot_LR/test_gt.npy +3 -0
- color_palette/cube_num_one_hot_LR/test_preds.npy +3 -0
- color_palette/cube_num_one_hot_LR/test_preds_graph.npy +3 -0
- color_palette/cube_num_one_hot_LR/test_rgb_colors.npy +3 -0
- color_palette/cube_num_one_hot_LR_sequential/new_palettes.npy +3 -0
- color_palette/cube_num_one_hot_LR_sequential/new_palettes_purple.npy +3 -0
- color_palette/cube_num_one_hot_LR_sequential/original_palettes.npy +3 -0
- color_palette/cube_num_one_hot_LR_sequential/original_palettes_purple.npy +3 -0
- color_palette/cube_num_one_hot_LR_sequential/test_gt.npy +3 -0
- color_palette/cube_num_one_hot_LR_sequential/test_preds.npy +3 -0
- color_palette/dataset.py +215 -0
- color_palette/dataset_processing.py +505 -0
- color_palette/deneme.png +0 -0
- color_palette/deneme.py +107 -0
- color_palette/denemeler.ipynb +0 -0
- color_palette/dist.png +0 -0
- color_palette/evaluate.py +188 -0
- color_palette/evaluate_CNN.py +180 -0
- color_palette/evaluate_classification.py +217 -0
- color_palette/evaluate_recommend.py +72 -0
- color_palette/model/CNN.py +209 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+color_palette/regressor/colorLoversData.mat filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,3 @@
/destijl_dataset
*.jpg
*.log
app.py
ADDED
@@ -0,0 +1,397 @@
from color_palette.regressor.config import config_to_use
import numpy as np
from PIL import Image, ImageFont, ImageDraw
from sklearn.linear_model import LinearRegression
from color_palette.model.GNN import ColorAttentionClassification
from color_palette.regressor.model import Color2CubeDataset
from color_palette.regressor.config import *
from torch.utils.data import Dataset, DataLoader
from color_palette.dataset import GraphDestijlDataset
from color_palette.config import DataConfig
import random
import os
import torch.nn.functional as F
import torch
import gradio as gr

config = DataConfig()
model_name = config.model_name
dataset_root = config.dataset
feature_size = config.feature_size
device = config.device
image_folder = "img_folder"

if not os.path.exists(image_folder):
    os.mkdir(image_folder)

def train_regressor(train_loader):
    X = []
    y = []
    for i, (input_data, target) in enumerate(train_loader):
        input_data = np.squeeze(input_data)
        target = np.squeeze(target)
        X.append(input_data)
        y.append(target)

    X = np.stack(X, axis=0)
    y = np.squeeze(np.stack(y, axis=0))

    print("Before regressor train!\n")

    reg = LinearRegression().fit(X, y)

    return reg

model_weight_path = "models/" + model_name + "/weights/best.pth"

# palettes = np.load(config_to_use.save_folder+'/new_palettes_purple.npy')
# original_palettes = np.load(config_to_use.save_folder+'/original_palettes_purple.npy')

graph_test_dataset = GraphDestijlDataset(root=dataset_root, test=True, cube_mapping=True)
model = ColorAttentionClassification(feature_size).to(device)
model.load_state_dict(torch.load(model_weight_path)["state_dict"])

dataset = Color2CubeDataset(config=config_to_use)
train_loader = DataLoader(dataset, batch_size=1, shuffle=False)
regressor = train_regressor(train_loader=train_loader)

palette_of_the_design = [[0, 0, 0] for i in range(5)]
all_node_colors = None
class Demo:
    def __init__(self, graph_dataset):
        self.dataset = graph_dataset

        first_sample_idx = random.randint(0, len(self.dataset)-1)
        self.input_data, self.target_color, node_to_mask, also_normal_values = self.dataset.get(first_sample_idx)
        global all_node_colors
        all_node_colors = also_normal_values
        self.same_indices = None
        self.generate_img_from_palette([color.detach().numpy()*255 for color in also_normal_values], is_first=True)

    def demo_reset(self):
        first_sample_idx = random.randint(0, len(self.dataset))
        self.input_data, self.target_color, node_to_mask, also_normal_values = self.dataset.get(first_sample_idx)
        global all_node_colors
        all_node_colors = also_normal_values
        self.generate_img_from_palette([color.detach().numpy()*255 for color in also_normal_values], is_first=True)

    def generate_img_from_palette(self, palette, canvas_size=512, is_first=False):
        palette = np.array(palette).astype('int')
        rgb_bg, rgb_text, rgb_text, rgb_circle, rgb_main_img, rgb_img1, rgb_img2, rgb_img3 = [tuple(color) for color in palette]
        if is_first:
            self.same_indices, unique_colors, _ = self.return_all_same_colors(palette=palette)
        else:
            _, unique_colors, _ = self.return_all_same_colors(palette=palette)

        # assign the current palette using global keyword
        global palette_of_the_design
        palette_of_the_design = unique_colors

        # Set the background color and create an empty PIL Image to fill with shapes and text
        image = Image.new("RGB", (canvas_size, canvas_size), color=rgb_bg)
        # Save background image

        title = "Lorem Ipsum Dolor"
        undertitle = "Neque porro quisquam est qui dolorem ipsum quia dolor sit amet, \n consectetur, adipisci velit..."

        draw = ImageDraw.Draw(image)
        # Set settings for the fonts
        font_title = ImageFont.truetype("Arial.ttf", 32)
        title_width, title_height = draw.textsize(title, font=font_title)
        title_x = (canvas_size - title_width) // 2
        title_y = (canvas_size - title_height) // 2 - 100

        font_undertitle = ImageFont.truetype("Arial.ttf", 15)
        text_width, text_height = draw.textsize(undertitle, font=font_undertitle)
        undertitle_x = (canvas_size - text_width) // 2
        undertitle_y = (canvas_size - text_height) // 2 - 50
        # Draw titles
        draw.text((title_x, title_y), title, fill=rgb_text, font=font_title)
        draw.text((undertitle_x, undertitle_y), undertitle, fill=rgb_text, font=font_undertitle)

        # Draw the circle
        rad = random.randint(30, 70)
        x = random.randint(400, 512-(rad+10))
        y = random.randint(10, title_y-(rad+10))

        draw.ellipse((x, y, x+rad, y+rad), fill=rgb_circle)

        # Draw the image
        for j, color in enumerate([rgb_main_img, rgb_img1, rgb_img2, rgb_img3]):
            x = 512-((j+1)*60)
            y = 512-((j+1)*60)

            if j == 0:
                rad = 80
                draw.rectangle((x, y, x+rad, y+rad), fill=color)
            else:
                rad = 40
                draw.rectangle((x, y, x+rad, y+rad), fill=color)

        image.save(os.path.join("deneme.png"))

    def run_model(self, input_data, target_color, node_to_mask, updated_color):

        global all_node_colors
        palette = np.array([color.detach().numpy()*255 for color in all_node_colors]).astype('int')
        same_indices_list, unique_colors, first_indices = self.return_all_same_colors(palette)

        unique_colors = unique_colors/255

        selected_color = torch.Tensor(updated_color)/255
        map_node_to_mask = -1
        print("same indices list")
        print(self.same_indices)
        print("node to mask")
        print(node_to_mask)
        for i, idxs in enumerate(self.same_indices):
            if node_to_mask in idxs:
                map_node_to_mask = i
                print("map node to mask: ", i)

        for i, indices in enumerate(self.same_indices):

            if i == 0:
                # update the color [0.15, 0.4908, 0.73]
                cube_num_of_selected = self.rgb2cube(selected_color*255)
                one_hot = np.zeros((64,))
                one_hot[int(cube_num_of_selected)] = 1.0
                node_to_recommend = map_node_to_mask
                input_data.x[same_indices_list[map_node_to_mask], 4:] = torch.Tensor(one_hot)
                unique_colors[map_node_to_mask] = selected_color
            else:
                if i == map_node_to_mask:
                    zeroth_bin = 0
                    indices = same_indices_list[zeroth_bin]
                    node_to_recommend = 0
                    input_data.x[indices[0], 4:] = torch.zeros((input_data.x.shape[1]-4))
                    node_to_mask = indices[0]
                else:
                    node_to_recommend = i
                    input_data.x[indices[0], 4:] = torch.zeros((input_data.x.shape[1]-4))
                    node_to_mask = indices[0]

            out = self.forward_pass(model, input_data) # input data has one-hot color features
            if torch.is_tensor(node_to_mask):
                node_to_mask = node_to_mask.item()
            values, values_indices = torch.topk(F.softmax(out[node_to_mask, :], dim=0), k=3, dim=0) # predict the color cube of the recommendation
            prediction = values_indices.detach().numpy()[2]
            # construct a palette using unique RGB palette and one-hot representation of the prediction cube.
            feature_vector = self.create_rgb_and_one_hot_cube_vector(unique_colors, prediction, node_to_recommend)
            # map cube to rgb color space using the regressor
            recommendation = regressor.predict(feature_vector)[0]
            # we now have the first set of recommendations. Now, we need to update the colors and input_data to propagate information.
            # update the color in the palette and run the algorithm for rest of the palette.
            # for that, first map the color to cube and convert to one_hot
            input_data, unique_colors = self.update_palette(input_data, unique_colors, recommendation, same_indices_list, node_to_recommend)
            # recursively do this here.
        # save the results.
        return np.array(unique_colors*255).astype(int)

    def rgb2cube(self, color):
        intervals = np.arange(0, 256, 256//4)
        cube_coordinates = []
        for channel in color:
            i = 0
            for j, value in enumerate(intervals):
                if value < channel:
                    i = j
            cube_coordinates.append(i)

        cube_num = cube_coordinates[0]*1 + cube_coordinates[1]*4 + cube_coordinates[2]*4*4
        return cube_num

    def cube2rgb(self, cube_num):
        """
        Return the start of the ranges
        """
        cube_num = int(cube_num)
        intervals = np.arange(0, 256, 256//4)
        coor2 = cube_num // 16
        coor1 = (cube_num - coor2*4*4) // 4
        coor0 = cube_num - coor2*4*4 - coor1*4
        return [intervals[coor0], intervals[coor1], intervals[coor2]]

    def return_all_same_colors(self, palette):

        indices_list = [[],[],[],[],[]]
        unique_colors, first_indices = np.unique(palette, axis=0, return_index=True)

        unique_colors = np.array(unique_colors)
        all_colors = np.array(palette)

        for idx, color in enumerate(unique_colors):
            for node_num, element in enumerate(all_colors):
                if np.equal(color, element).all():
                    indices_list[idx].append(node_num)

        # these palettes and indices also include the masked color
        return indices_list, unique_colors, first_indices

    def update_palette(self, input_data, unique_rgb_palette, recommendation, indices_list, idx_to_idxs):
        # convert prediction to one-hot vector
        cube_num_of_the_changed_color = self.rgb2cube(recommendation*255)
        one_hot = np.zeros((64,))
        one_hot[int(cube_num_of_the_changed_color)] = 1.0

        # update the feature vector accordingly for all the same colors
        for idx in indices_list[idx_to_idxs]:
            input_data.x[idx, 4:] = torch.Tensor(one_hot)

        # update the unique color vector
        unique_rgb_palette[idx_to_idxs] = recommendation
        return input_data, unique_rgb_palette


    def create_rgb_and_one_hot_cube_vector(self, rgb_palette, cube_num, node_to_mask):
        one_hot = np.zeros((64,))
        one_hot[int(cube_num)] = 1.0
        removed_palette = np.delete(rgb_palette, node_to_mask, axis=0)
        feature_vector = np.concatenate((removed_palette.flatten(), one_hot), axis=0)
        return feature_vector.reshape(1, -1)

    def create_all_one_hot_vector(self, rgb_palette, cube_num, node_to_mask):
        one_hot = np.zeros((64,))
        one_hot[int(cube_num)] = 1.0
        removed_palette = np.delete(rgb_palette, node_to_mask, axis=0)
        new_input_data = []
        for color in removed_palette:
            color_cube_num = self.rgb2cube(color*255)
            empty_arr = np.zeros((64,))
            empty_arr[int(color_cube_num)] = 1.0
            new_input_data.append(empty_arr)

        feature_vector = np.concatenate((np.array(new_input_data).flatten(), one_hot), axis=0)
        return feature_vector.reshape(1, -1)

    def forward_pass(self, model, data):
        model.eval()
        out = model(data.x, data.edge_index.long(), data.edge_weight)
        return out

    def rearrange_indices_list(self, indices_list, node_to_mask, unique_rgb_palette):
        # take the node_to_mask indices to the beginning of the list
        for i in range(len(indices_list)):
            if node_to_mask in indices_list[i]:
                index_to_pop = i

        idxs = indices_list.pop(index_to_pop)
        palette = unique_rgb_palette[index_to_pop]
        temp_palette = np.delete(unique_rgb_palette, index_to_pop, axis=0)
        unique_rgb_palette = np.concatenate(([palette], temp_palette), axis=0)
        return [idxs] + indices_list, unique_rgb_palette

    def update_color(self, updated_color, idx):
        """
        Takes a color and assigns it to the palette and the image.
        """
        idx = int(idx)
        color = updated_color[1:-1].split(",")
        color = [int(num) for num in color]

        index_list = self.same_indices[idx]
        which_one = random.randint(0, len(index_list)-1)
        idx_to_update = index_list[which_one]

        unique_colors = self.run_model(self.input_data, self.target_color, idx_to_update, color)

        global palette_of_the_design
        palette_of_the_design = unique_colors

        global all_node_colors
        if torch.is_tensor(all_node_colors):
            all_node_colors = all_node_colors.detach().numpy()
        for i, index_list in enumerate(self.same_indices):
            for index in index_list:
                all_node_colors[index] = unique_colors[i]

        self.generate_img_from_palette(palette=[color for color in all_node_colors])
        main_image = Image.open("deneme.png")
        gradio_elements = []
        gradio_elements.append(gr.Image(main_image, height=256, width=256))
        for i in range(len(self.same_indices)):
            color = unique_colors[i]
            image = Image.new("RGB", (512, 512), color=tuple(color))
            gradio_elements.append(gr.Image(image, height=64, width=64))
            string_version = "["+str(color[0])+", "+ str(color[1])+", " + str(color[2])+"]"
            gradio_elements.append(gr.Textbox(value=string_version, min_width=64))

        all_node_colors = torch.Tensor(all_node_colors) / 255
        return tuple(gradio_elements)

def perform_reset(button_input):
    global demo
    global all_node_colors
    gradio_elements = []

    demo.demo_reset()
    main_image = Image.open("deneme.png")
    gradio_elements = []
    gradio_elements.append(gr.Image(main_image, height=256, width=256))

    for color in palette_of_the_design:
        image = Image.new("RGB", (512, 512), color=tuple(color))
        gradio_elements.append(gr.Image(image, height=64, width=64))
        string_version = "["+str(color[0])+", "+ str(color[1])+", " + str(color[2])+"]"
        gradio_elements.append(gr.Textbox(value=string_version, min_width=64))

    return tuple(gradio_elements)


demo = Demo(graph_dataset=graph_test_dataset)

# Form a gradio template to display images and update the colors.

with gr.Blocks() as project_demo:
    with gr.Row():
        image = Image.open("deneme.png")
        design = gr.Image(image, height=256, width=256)

    with gr.Row():
        with gr.Column(min_width=100):
            image1 = Image.new("RGB", (512, 512), color=tuple(palette_of_the_design[0]))
            image1_gr = gr.Image(image1, height=64, width=64)
            string1 = "["+str(palette_of_the_design[0][0])+", "+ str(palette_of_the_design[0][1])+", " + str(palette_of_the_design[0][2])+"]"
            color1_update = gr.Textbox(value=string1, min_width=64)
            color1_button = gr.Button(value="Update Color 1", min_width=64)
        with gr.Column(min_width=100):
            image2 = Image.new("RGB", (512, 512), color=tuple(palette_of_the_design[1]))
            image2_gr = gr.Image(image2, height=64, width=64)
            string2 = "["+str(palette_of_the_design[1][0])+", "+ str(palette_of_the_design[1][1])+", " + str(palette_of_the_design[1][2])+"]"
            color2_update = gr.Textbox(value=string2, min_width=64)
            color2_button = gr.Button(value="Update Color 2", min_width=64)
        with gr.Column(min_width=100):
            image3 = Image.new("RGB", (512, 512), color=tuple(palette_of_the_design[2]))
            image3_gr = gr.Image(image3, height=64, width=64)
            string3 = "["+str(palette_of_the_design[2][0])+", "+ str(palette_of_the_design[2][1])+", " + str(palette_of_the_design[2][2])+"]"
            color3_update = gr.Textbox(value=string3, min_width=64)
            color3_button = gr.Button(value="Update Color 3", min_width=64)
        with gr.Column(min_width=100):
            image4 = Image.new("RGB", (512, 512), color=tuple(palette_of_the_design[3]))
            image4_gr = gr.Image(image4, height=64, width=64)
            string4 = "["+str(palette_of_the_design[3][0])+", "+ str(palette_of_the_design[3][1])+", " + str(palette_of_the_design[3][2])+"]"
            color4_update = gr.Textbox(value=string4, min_width=64)
            color4_button = gr.Button(value="Update Color 4", min_width=64)
        with gr.Column(min_width=100):
            image5 = Image.new("RGB", (512, 512), color=tuple(palette_of_the_design[4]))
            image5_gr = gr.Image(image5, height=64, width=64)
            string5 = "["+str(palette_of_the_design[4][0])+", "+ str(palette_of_the_design[4][1])+", " + str(palette_of_the_design[4][2])+"]"
            color5_update = gr.Textbox(value=string5, min_width=64)
            color5_button = gr.Button(value="Update Color 5", min_width=64)

    with gr.Row():
        reset_button = gr.Button(value="Reset the palette", min_width=64)

    zero = gr.Number(value=0, visible=False)
    one = gr.Number(value=1, visible=False)
    two = gr.Number(value=2, visible=False)
    three = gr.Number(value=3, visible=False)
    four = gr.Number(value=4, visible=False)
    color1_button.click(fn=demo.update_color, inputs=[color1_update, zero], outputs=[design, image1_gr, color1_update, image2_gr, color2_update, image3_gr, color3_update, image4_gr, color4_update, image5_gr, color5_update])
    color2_button.click(fn=demo.update_color, inputs=[color2_update, one], outputs=[design, image1_gr, color1_update, image2_gr, color2_update, image3_gr, color3_update, image4_gr, color4_update, image5_gr, color5_update])
    color3_button.click(fn=demo.update_color, inputs=[color3_update, two], outputs=[design, image1_gr, color1_update, image2_gr, color2_update, image3_gr, color3_update, image4_gr, color4_update, image5_gr, color5_update])
    color4_button.click(fn=demo.update_color, inputs=[color4_update, three], outputs=[design, image1_gr, color1_update, image2_gr, color2_update, image3_gr, color3_update, image4_gr, color4_update, image5_gr, color5_update])
    color5_button.click(fn=demo.update_color, inputs=[color5_update, four], outputs=[design, image1_gr, color1_update, image2_gr, color2_update, image3_gr, color3_update, image4_gr, color4_update, image5_gr, color5_update])
    reset_button.click(fn=perform_reset, inputs=[reset_button], outputs=[design, image1_gr, color1_update, image2_gr, color2_update, image3_gr, color3_update, image4_gr, color4_update, image5_gr, color5_update])

project_demo.launch()
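Note: app.py quantizes palette colors by splitting each RGB channel into four 64-wide intervals, giving 4 x 4 x 4 = 64 color cubes; rgb2cube flattens the per-channel bins into a single index (r_bin + 4*g_bin + 16*b_bin) and cube2rgb returns the lower corner of a cube. Below is a minimal standalone sketch of that scheme for reference only; the function names are illustrative, not part of the repository, and the boundary handling (integer division) may differ slightly from the strict-inequality loop used in app.py.

import numpy as np

def rgb_to_cube(color):
    # Bin each 0-255 channel into four 64-wide intervals, then flatten the
    # (r_bin, g_bin, b_bin) coordinate into one index 0..63, mirroring
    # rgb2cube in app.py (index = r_bin + 4*g_bin + 16*b_bin).
    bins = np.clip(np.asarray(color) // 64, 0, 3)
    return int(bins[0] + 4 * bins[1] + 16 * bins[2])

def cube_to_rgb(cube_num):
    # Invert the flattening and return the lower corner of the cube,
    # as cube2rgb does ("the start of the ranges").
    b, rem = divmod(int(cube_num), 16)
    g, r = divmod(rem, 4)
    return [r * 64, g * 64, b * 64]

if __name__ == "__main__":
    color = [200, 30, 120]
    idx = rgb_to_cube(color)        # 3 + 4*0 + 16*1 = 19
    print(idx, cube_to_rgb(idx))    # 19 [192, 0, 64]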
color_palette/._dataset_processing.py
ADDED
Binary file (4.1 kB).

color_palette/__pycache__/cnn_dataset.cpython-39.pyc
ADDED
Binary file (2.76 kB).

color_palette/__pycache__/config.cpython-38.pyc
ADDED
Binary file (813 Bytes).

color_palette/__pycache__/config.cpython-39.pyc
ADDED
Binary file (734 Bytes).

color_palette/__pycache__/dataset.cpython-38.pyc
ADDED
Binary file (5.84 kB).

color_palette/__pycache__/dataset.cpython-39.pyc
ADDED
Binary file (5.24 kB).

color_palette/__pycache__/dataset_processing.cpython-39.pyc
ADDED
Binary file (12.9 kB).

color_palette/__pycache__/train_CNN.cpython-39.pyc
ADDED
Binary file (4.42 kB).

color_palette/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (9.73 kB).

color_palette/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (9.68 kB).
color_palette/all_one_hot_LR/test_gt.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6e5d5394bf0dcce227956a9a66b46263a0102492f132f6c731ee020e703ec3c8
size 48128
color_palette/all_one_hot_LR/test_preds.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b00a13c1ee009c195b1ffd4de3c502b4d34a72d54634e99a4e73a8d861a01292
size 48128
color_palette/all_one_hot_LR/test_preds_graph.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f9b938717cae80d2fc604e38be8ffaec15bd95fd30c8294aec0753e79b51ffce
size 2312
color_palette/all_one_hot_LR/test_rgb_colors.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:56dff61afc05392ad11c82c3e3d4809a1fd9c413196b4384bfade653d075423e
size 7772
color_palette/all_one_hot_LR_sequential/new_palettes.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b6fcb136714cb725feccb5b6f1c66d219bdbd39ad4abcd31570665571edf5de3
size 12128
color_palette/all_one_hot_LR_sequential/original_palettes.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9342282a3717f08e759cce63bf37852392a5996b879a878aab12de1ad089081b
size 12128
color_palette/all_one_hot_LR_sequential/test_gt.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6e5d5394bf0dcce227956a9a66b46263a0102492f132f6c731ee020e703ec3c8
size 48128
color_palette/all_one_hot_LR_sequential/test_preds.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b00a13c1ee009c195b1ffd4de3c502b4d34a72d54634e99a4e73a8d861a01292
size 48128
color_palette/app copy.py
ADDED
@@ -0,0 +1,326 @@
from regressor.config import config_to_use
import numpy as np
from PIL import Image, ImageFont, ImageDraw
from sklearn.linear_model import LinearRegression
from dataset import GraphDestijlDataset
from config import DataConfig
import random
import os
import torch
import gradio as gr

config = DataConfig()
model_name = config.model_name
dataset_root = config.dataset
image_folder = "img_folder"

if not os.path.exists(image_folder):
    os.mkdir(image_folder)

model_weight_path = "../models/" + model_name + "/weights/best.pth"

palettes = np.load(config_to_use.save_folder+'/new_palettes_purple.npy')
original_palettes = np.load(config_to_use.save_folder+'/original_palettes_purple.npy')

graph_test_dataset = GraphDestijlDataset(root=dataset_root, test=True, cube_mapping=True)

palette_of_the_design = [[0, 0, 0] for i in range(5)]

class Demo:
    def __init__(self, graph_dataset):
        self.dataset = graph_dataset

        first_sample_idx = random.randint(0, len(self.dataset))
        new_data, target_color, node_to_mask, also_normal_values = self.dataset.get(first_sample_idx)
        self.all_node_colors = also_normal_values
        self.same_indices = None
        self.generate_img_from_palette([color.detach().numpy()*255 for color in also_normal_values], is_first=True)

    def generate_img_from_palette(self, palette, canvas_size=512, is_first=False):
        palette = np.array(palette).astype('int')
        rgb_bg, rgb_text, rgb_text, rgb_circle, rgb_main_img, rgb_img1, rgb_img2, rgb_img3 = [tuple(color) for color in palette]

        if is_first:
            self.same_indices, unique_colors, _ = self.return_all_same_colors(palette=palette)
        else:
            _, unique_colors, _ = self.return_all_same_colors(palette=palette)

        # assign the current palette using global keyword
        global palette_of_the_design
        palette_of_the_design = unique_colors

        # Set the background color and create an empty PIL Image to fill with shapes and text
        image = Image.new("RGB", (canvas_size, canvas_size), color=rgb_bg)
        # Save background image

        title = "Lorem Ipsum Dolor"
        undertitle = "Neque porro quisquam est qui dolorem ipsum quia dolor sit amet, \n consectetur, adipisci velit..."

        draw = ImageDraw.Draw(image)
        # Set settings for the fonts
        font_title = ImageFont.truetype("Arial.ttf", 32)
        title_width, title_height = draw.textsize(title, font=font_title)
        title_x = (canvas_size - title_width) // 2
        title_y = (canvas_size - title_height) // 2 - 100

        font_undertitle = ImageFont.truetype("Arial.ttf", 15)
        text_width, text_height = draw.textsize(undertitle, font=font_undertitle)
        undertitle_x = (canvas_size - text_width) // 2
        undertitle_y = (canvas_size - text_height) // 2 - 50
        # Draw titles
        draw.text((title_x, title_y), title, fill=rgb_text, font=font_title)
        draw.text((undertitle_x, undertitle_y), undertitle, fill=rgb_text, font=font_undertitle)

        # Draw the circle
        rad = random.randint(30, 70)
        x = random.randint(400, 512-(rad+10))
        y = random.randint(10, title_y-(rad+10))

        draw.ellipse((x, y, x+rad, y+rad), fill=rgb_circle)

        # Draw the image
        for j, color in enumerate([rgb_main_img, rgb_img1, rgb_img2, rgb_img3]):
            x = 512-((j+1)*60)
            y = 512-((j+1)*60)

            if j == 0:
                rad = 80
                draw.rectangle((x, y, x+rad, y+rad), fill=color)
            else:
                rad = 40
                draw.rectangle((x, y, x+rad, y+rad), fill=color)

        image.save(os.path.join("deneme.png"))

    def run_model(self):
        also_normal_values = np.squeeze(np.stack(also_normal_values, axis=0)) # rgb
        same_indices_list, unique_colors, first_indices = return_all_same_colors(also_normal_values)

        # move node_to_mask to first place in input_data and unique colors
        # same_indices_list, unique_colors = rearrange_indices_list(same_indices_list, node_to_mask, unique_colors)
        map_node_to_mask = -1
        for i, idxs in enumerate(same_indices_list):
            if node_to_mask in idxs:
                map_node_to_mask = i

        original_palettes.append(unique_colors.copy())
        for i, indices in enumerate(same_indices_list):

            if i == 0:
                # update the color [0.15, 0.4908, 0.73]
                selected_color = torch.Tensor([254/255, 254/255, 224/255])
                cube_num_of_selected = rgb2cube(selected_color*255)
                one_hot = np.zeros((64,))
                one_hot[int(cube_num_of_selected)] = 1.0
                node_to_recommend = map_node_to_mask
                input_data.x[same_indices_list[map_node_to_mask], 4:] = torch.Tensor(one_hot)
                unique_colors[map_node_to_mask] = selected_color
            else:
                if i == map_node_to_mask:
                    zeroth_bin = 0
                    indices = same_indices_list[zeroth_bin]
                    node_to_recommend = 0
                    input_data.x[indices[0], 4:] = torch.zeros((input_data.x.shape[1]-4))
                    node_to_mask = indices[0]
                else:
                    node_to_recommend = i
                    input_data.x[indices[0], 4:] = torch.zeros((input_data.x.shape[1]-4))
                    node_to_mask = indices[0]

            out = forward_pass(model, input_data) # input data has one-hot color features
            if torch.is_tensor(node_to_mask):
                node_to_mask = node_to_mask.item()
            values, values_indices = torch.topk(F.softmax(out[node_to_mask, :], dim=0), k=3, dim=0) # predict the color cube of the recommendation
            prediction = values_indices.detach().numpy()[2]
            # construct a palette using unique RGB palette and one-hot representation of the prediction cube.
            feature_vector = create_rgb_and_one_hot_cube_vector(unique_colors, prediction, node_to_recommend)
            # map cube to rgb color space using the regressor
            recommendation = regressor.predict(feature_vector)[0]
            # we now have the first set of recommendations. Now, we need to update the colors and input_data to propagate information.
            # update the color in the palette and run the algorithm for rest of the palette.
            # for that, first map the color to cube and convert to one_hot
            input_data, unique_colors = update_palette(input_data, unique_colors, recommendation, same_indices_list, node_to_recommend)
            # recursively do this here.
        # save the results.

    def rgb2cube(color):
        intervals = np.arange(0, 256, 256//4)
        cube_coordinates = []
        for channel in color:
            i = 0
            for j, value in enumerate(intervals):
                if value < channel:
                    i = j
            cube_coordinates.append(i)

        cube_num = cube_coordinates[0]*1 + cube_coordinates[1]*4 + cube_coordinates[2]*4*4
        return cube_num

    def cube2rgb(self, cube_num):
        """
        Return the start of the ranges
        """
        cube_num = int(cube_num)
        intervals = np.arange(0, 256, 256//4)
        coor2 = cube_num // 16
        coor1 = (cube_num - coor2*4*4) // 4
        coor0 = cube_num - coor2*4*4 - coor1*4
        return [intervals[coor0], intervals[coor1], intervals[coor2]]

    def return_all_same_colors(self, palette):

        indices_list = [[],[],[],[],[]]
        unique_colors, first_indices = np.unique(palette, axis=0, return_index=True)

        unique_colors = np.array(unique_colors)
        all_colors = np.array(palette)

        for idx, color in enumerate(unique_colors):
            for node_num, element in enumerate(all_colors):
                if np.equal(color, element).all():
                    indices_list[idx].append(node_num)

        # these palettes and indices also include the masked color
        return indices_list, unique_colors, first_indices

    def update_palette(self, input_data, unique_rgb_palette, recommendation, indices_list, idx_to_idxs):
        # convert prediction to one-hot vector
        cube_num_of_the_changed_color = self.rgb2cube(recommendation*255)
        one_hot = np.zeros((64,))
        one_hot[int(cube_num_of_the_changed_color)] = 1.0

        # update the feature vector accordingly for all the same colors
        for idx in indices_list[idx_to_idxs]:
            input_data.x[idx, 4:] = torch.Tensor(one_hot)

        # update the unique color vector
        unique_rgb_palette[idx_to_idxs] = recommendation
        return input_data, unique_rgb_palette


    def create_rgb_and_one_hot_cube_vector(self, rgb_palette, cube_num, node_to_mask):
        one_hot = np.zeros((64,))
        one_hot[int(cube_num)] = 1.0
        removed_palette = np.delete(rgb_palette, node_to_mask, axis=0)
        feature_vector = np.concatenate((removed_palette.flatten(), one_hot), axis=0)
        return feature_vector.reshape(1, -1)

    def create_all_one_hot_vector(self, rgb_palette, cube_num, node_to_mask):
        one_hot = np.zeros((64,))
        one_hot[int(cube_num)] = 1.0
        removed_palette = np.delete(rgb_palette, node_to_mask, axis=0)
        new_input_data = []
        for color in removed_palette:
            color_cube_num = self.rgb2cube(color*255)
            empty_arr = np.zeros((64,))
            empty_arr[int(color_cube_num)] = 1.0
            new_input_data.append(empty_arr)

        feature_vector = np.concatenate((np.array(new_input_data).flatten(), one_hot), axis=0)
        return feature_vector.reshape(1, -1)

    def forward_pass(self, model, data):
        model.eval()
        out = model(data.x, data.edge_index.long(), data.edge_weight)
        return out

    def train_regressor(self, train_loader):
        X = []
        y = []
        for i, (input_data, target) in enumerate(train_loader):
            input_data = np.squeeze(input_data)
            target = np.squeeze(target)
            X.append(input_data)
            y.append(target)

        X = np.stack(X, axis=0)
        y = np.squeeze(np.stack(y, axis=0))

        print("Before regressor train!\n")

        reg = LinearRegression().fit(X, y)

        return reg

    def rearrange_indices_list(self, indices_list, node_to_mask, unique_rgb_palette):
        # take the node_to_mask indices to the beginning of the list
        for i in range(len(indices_list)):
            if node_to_mask in indices_list[i]:
                index_to_pop = i

        idxs = indices_list.pop(index_to_pop)
        palette = unique_rgb_palette[index_to_pop]
        temp_palette = np.delete(unique_rgb_palette, index_to_pop, axis=0)
        unique_rgb_palette = np.concatenate(([palette], temp_palette), axis=0)
        return [idxs] + indices_list, unique_rgb_palette

    def update_color(self, updated_color, idx):
        """
        Takes a color and assigns it to the palette and the image.
        """
        idx = int(idx)
        color = updated_color[1:-1].split(",")
        color = [int(num) for num in color]

        global palette_of_the_design
        palette_of_the_design[idx] = color

        idxs_to_change = self.same_indices[idx]
        for index in idxs_to_change:
            self.all_node_colors[index] = torch.Tensor(color)/255

        self.generate_img_from_palette(palette=[color.detach().numpy()*255 for color in self.all_node_colors])

        image = Image.new("RGB", (512, 512), color=tuple(color))
        main_image = Image.open("deneme.png")
        return gr.Image(main_image, height=256, width=256), gr.Image(image, height=64, width=64), gr.Textbox(value=str(color), min_width=64)

if __name__ == "__main__":
    demo = Demo(graph_dataset=graph_test_dataset)

    # Form a gradio template to display images and update the colors.

    with gr.Blocks() as project_demo:
        with gr.Row():
            image = Image.open("deneme.png")
            design = gr.Image(image, height=256, width=256)

        with gr.Row():
            with gr.Column(min_width=100):
                image1 = Image.new("RGB", (512, 512), color=tuple(palette_of_the_design[0]))
                image1_gr = gr.Image(image1, height=64, width=64)
                color1_update = gr.Textbox(value=str(palette_of_the_design[0]).replace(" ", ","), min_width=64)
                color1_button = gr.Button(value="Update Color 1", min_width=64)
            with gr.Column(min_width=100):
                image2 = Image.new("RGB", (512, 512), color=tuple(palette_of_the_design[1]))
                image2_gr = gr.Image(image2, height=64, width=64)
                color2_update = gr.Textbox(value=str(palette_of_the_design[1]).replace(" ", ","), min_width=64)
                color2_button = gr.Button(value="Update Color 2", min_width=64)
            with gr.Column(min_width=100):
                image3 = Image.new("RGB", (512, 512), color=tuple(palette_of_the_design[2]))
                image3_gr = gr.Image(image3, height=64, width=64)
                color3_update = gr.Textbox(value=str(palette_of_the_design[2]).replace(" ", ","), min_width=64)
                color3_button = gr.Button(value="Update Color 3", min_width=64)
            with gr.Column(min_width=100):
                image4 = Image.new("RGB", (512, 512), color=tuple(palette_of_the_design[3]))
                image4_gr = gr.Image(image4, height=64, width=64)
                color4_update = gr.Textbox(value=str(palette_of_the_design[3]).replace(" ", ","), min_width=64)
                color4_button = gr.Button(value="Update Color 4", min_width=64)
            with gr.Column(min_width=100):
                image5 = Image.new("RGB", (512, 512), color=tuple(palette_of_the_design[4]))
                image5_gr = gr.Image(image5, height=64, width=64)
                color5_update = gr.Textbox(value=str(palette_of_the_design[4]).replace(" ", ","), min_width=64)
                color5_button = gr.Button(value="Update Color 5", min_width=64)

        zero = gr.Number(value=0, visible=False)
        one = gr.Number(value=1, visible=False)
        two = gr.Number(value=2, visible=False)
        three = gr.Number(value=3, visible=False)
        four = gr.Number(value=4, visible=False)
        color1_button.click(fn=demo.update_color, inputs=[color1_update, zero], outputs=[design, image1_gr, color1_update])
        color2_button.click(fn=demo.update_color, inputs=[color2_update, one], outputs=[design, image2_gr, color2_update])
        color3_button.click(fn=demo.update_color, inputs=[color3_update, two], outputs=[design, image3_gr, color3_update])
        color4_button.click(fn=demo.update_color, inputs=[color4_update, three], outputs=[design, image4_gr, color4_update])
        color5_button.click(fn=demo.update_color, inputs=[color5_update, four], outputs=[design, image5_gr, color5_update])

    project_demo.launch()
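Note: both app.py and app copy.py drive the UI through the same Gradio Blocks pattern: each "Update Color N" button's click() call maps input components to the callback arguments and routes the returned component instances back into the declared outputs. The snippet below is a minimal, self-contained sketch of that wiring for reference; the component and function names here are illustrative, not part of the repository.

from PIL import Image
import gradio as gr

def update_swatch(rgb_text):
    # Parse a "[r, g, b]" textbox value and return an updated swatch image
    # plus a normalized textbox, the same return-updated-components pattern
    # that demo.update_color follows.
    r, g, b = [int(v) for v in rgb_text.strip("[] ").split(",")]
    swatch = Image.new("RGB", (64, 64), color=(r, g, b))
    return gr.Image(swatch, height=64, width=64), gr.Textbox(value=f"[{r}, {g}, {b}]")

with gr.Blocks() as sketch:
    swatch_gr = gr.Image(height=64, width=64)
    rgb_box = gr.Textbox(value="[255, 0, 0]")
    button = gr.Button("Update Color")
    # click() maps the input components to the callback and routes its return
    # values back into the listed output components.
    button.click(fn=update_swatch, inputs=[rgb_box], outputs=[swatch_gr, rgb_box])

if __name__ == "__main__":
    sketch.launch()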
color_palette/bash_scripts/training.sh
ADDED
@@ -0,0 +1,24 @@
FROM_FILE=0
TO_FILE=11

file=$FROM_FILE
device_idx=2

free_mem=$(nvidia-smi --query-gpu=memory.free --format=csv -i $device_idx | grep -Eo [0-9]+)

while [ $file -le $TO_FILE ]
do
    # if [ $free_mem -lt 13000 ]; then
    #     while [ $free_mem -lt 13000 ]; do
    #         sleep 10
    #         free_mem=$(nvidia-smi --query-gpu=memory.free --format=csv -i $device_idx | grep -Eo [0-9]+)
    #     done
    # fi

    echo "Running experiment for conf$file.yaml"
    #nohup python train.py --config_file config/conf$file.yaml > "config/out_$file.txt" &
    python evaluate.py --config_file config/conf$file.yaml
    file=$(($file+1))
    #sleep 10

done
color_palette/cnn_dataset.py
ADDED
@@ -0,0 +1,104 @@
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import skimage.color as scicolor
from utils import *

class PreviewDataset(Dataset):
    def __init__(self, root="../destijl_dataset/rgba_dataset/",
                 transform=None, test=False, color_space="RGB",
                 input_color_space="RGB",
                 is_classification=False,
                 normalize_cielab=True,
                 normalize_rgb=True):

        self.test = test
        self.sample_filenames = os.listdir(root+"00_preview_cropped")
        self.transform = transform
        self.img_dir = root
        self.color_space = color_space
        self.is_classification = is_classification
        self.input_color_space = input_color_space
        self.normalize_cielab = normalize_cielab
        self.normalize_rgb = normalize_rgb

        self.train_filenames, self.test_filenames = train_test_split(self.sample_filenames,
                                                                     test_size=0.2,
                                                                     random_state=42)
    def __len__(self):
        if self.test:
            return len(self.test_filenames)
        else:
            return len(self.train_filenames)

    def __getitem__(self, idx):

        path_idx = "{:04d}".format(idx)
        img_path = os.path.join(self.img_dir, "00_preview_cropped/" + self.sample_filenames[idx])

        image = np.array(Image.open(img_path))
        # Convert image to lab if the input space is CIELab.
        # Image is a numpy array always. Convert to tensor at the end.
        if self.input_color_space == "CIELab":
            image = scicolor.rgb2lab(image)
            image = torch.from_numpy(image)
            # if self.normalize_cielab:
            #     image = torch.from_numpy(image)
            #     image = normalize_CIELab(image)
        else:
            image = torch.from_numpy(image)


        # Apply kmeans on RGB image always.
        bg_path = os.path.join("../destijl_dataset/01_background/" + self.sample_filenames[idx])
        # Most dominant color in RGB.
        color = self.kmeans_for_bg(bg_path)[0]

        # If output is in CIELab space but input is in RGB, convert target to CIELab also.
        if self.color_space == "CIELab":
            target_color = torch.squeeze(torch.tensor(RGB2CIELab(color.astype(np.int32))))
            # if self.normalize_cielab:
            #     target_color = normalize_CIELab(target_color)
        # Input and output is in RGB space or input and output is in CIELab space.
        # If Input is in CIELab and output is in RGB, than this is also valid since dataset is in RGB.
        else:
            target_color = torch.squeeze(torch.tensor(color))

        if self.is_classification:
            target_color = [torch.zeros(256), torch.zeros(256), torch.zeros(256)]
            target_color[0][color[0]] = 1
            target_color[1][color[1]] = 1
            target_color[2][color[2]] = 1

        if self.transform:
            # Reshape the image if not in (C, H, W) form.
            if image.shape[0] != 3:
                image = image.reshape(-1, image.shape[0], image.shape[1]).type("torch.FloatTensor")
            # Apply the transformation
            image = self.transform(image)

        if self.normalize_rgb:
            image /= 255

        if self.color_space == "RGB" and self.normalize_rgb:
            target_color /= 255

        if self.normalize_cielab:
            # we will only use lightness
            target_color /= 100

        return image, target_color

    def kmeans_for_bg(self, bg_path):
        image = cv2.imread(bg_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        n_colors = 1

        # Apply KMeans to the text area
        pixels = np.float32(image.reshape(-1, 3))
        criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 200, .1)
        flags = cv2.KMEANS_RANDOM_CENTERS

        _, labels, palette = cv2.kmeans(pixels, n_colors, None, criteria, 10, flags)
        palette = np.asarray(palette, dtype=np.int64) # RGB

        return palette
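Note: PreviewDataset pairs each cropped preview image with the dominant background color found by kmeans_for_bg, so it can be consumed directly by a DataLoader. Below is a usage sketch, assuming it is run from inside color_palette/ (where cnn_dataset.py and utils.py live); the resize size and batch size are illustrative assumptions, not values taken from the repository.

from torch.utils.data import DataLoader
from torchvision import transforms
from cnn_dataset import PreviewDataset

# Resize previews to a fixed CNN input size; 224x224 is an assumed example value.
transform = transforms.Resize((224, 224))

train_set = PreviewDataset(root="../destijl_dataset/rgba_dataset/",
                           transform=transform,
                           test=False,
                           color_space="RGB",
                           input_color_space="RGB")
train_loader = DataLoader(train_set, batch_size=16, shuffle=True)

for images, target_colors in train_loader:
    # images: batched preview tensors; target_colors: (B, 3) dominant
    # background colors produced by kmeans_for_bg.
    print(images.shape, target_colors.shape)
    break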
color_palette/colorCNN.py
ADDED
@@ -0,0 +1,238 @@
"""

TODO:
    * Make the white backgrounds transparent.
    * Locate the images in the bounding boxes in preview images in white background.
      Cut the images and paste to the location on the decoration layaer.
    * The pasting order: Decoration will have white bg. Paste image. Paste text.
    * When the white bg images are done, feed them to CNN.
    * The output will be the missing color which is the background color.
    * Use CIELab distances to train

"""

import cv2
import numpy as np
from utils import *
from dataset_processing import ProcessedDeStijl
from PIL import Image

class DestijlProcessorCNN():

    def __init__(self, data_path):
        self.path_dict = {
            'preview': data_path + '/00_preview/',
            'image': data_path + '/02_image/',
            'decoration': data_path + '/03_decoration/',
            'text': data_path + '/04_text/',
        }

        self.rgba_path_dict = {
            'preview': data_path + '/rgba_dataset/00_preview/',
            'image': data_path + '/rgba_dataset/02_image/',
            'decoration': data_path + '/rgba_dataset/03_decoration/',
            'text': data_path + '/rgba_dataset/04_text/',
            'temporary': data_path + '/rgba_dataset/05_temporary/',
            'cropped_preview': data_path + '/rgba_dataset/00_preview_cropped/',
        }

        self.xml_path_dict = {
            'preview': data_path + '/xmls/00_preview/',
            'image': data_path + '/xmls/02_image/',
            'decoration': data_path + '/xmls/03_decoration/',
            'text': data_path + '/xmls/04_text/',
        }

        self.processed_dataset = ProcessedDeStijl("../destijl_dataset")

    def whitebg_to_transparent(self, img_path, layer):

        """
        WORKING
        """
        image_bgr = cv2.imread(img_path)
        image_num = img_path[-8:]
        # get the image dimensions (height, width and channels)
        h, w, c = image_bgr.shape
        # append Alpha channel -- required for BGRA (Blue, Green, Red, Alpha)
        image_bgra = np.concatenate([image_bgr, np.full((h, w, 1), 255, dtype=np.uint8)], axis=-1)
        # create a mask where white pixels ([255, 255, 255]) are True
        white = np.all(image_bgr == [255, 255, 255], axis=-1)
        # change the values of Alpha to 0 for all the white pixels
        image_bgra[white, -1] = 0
        # save the image
        cv2.imwrite(self.rgba_path_dict[layer]+image_num, image_bgra)

    def locate_images_in_image_layer(self, idx):

        method = cv2.TM_SQDIFF_NORMED
        path_idx = "{:04d}".format(idx)
        preview_img = cv2.imread(self.path_dict["preview"]+path_idx+".png")
        preview_bboxes = VOC2bbox(self.xml_path_dict["image"]+path_idx+".xml")[1]
        image_img = cv2.imread(self.path_dict["image"]+path_idx+".png")

        boxes = []
        design_boxes = []
        for box in preview_bboxes:
            xmin = box[0][0]
            xmax = box[1][0]
            ymin = box[0][1]
            ymax = box[2][1]
            cropped_img = preview_img[ymin:ymax, xmin:xmax]

            if(cropped_img.shape[0] > image_img.shape[0]):
                diff_x = abs(cropped_img.shape[0] - image_img.shape[0])
                image_img = cv2.copyMakeBorder(image_img, diff_x//2+5, diff_x//2+5, 0, 0, cv2.BORDER_CONSTANT, value=[255, 255, 255])

            if(cropped_img.shape[1] > image_img.shape[1]):
                diff_y = abs(cropped_img.shape[1] - image_img.shape[1])
                image_img = cv2.copyMakeBorder(image_img, 0, 0, diff_y//2+5, diff_y//2+5, cv2.BORDER_CONSTANT, value=[255, 255, 255])

            result = cv2.matchTemplate(cropped_img, image_img, method)
            mn,_,mnLoc,_ = cv2.minMaxLoc(result)
            MPx,MPy = mnLoc
            trows,tcols = cropped_img.shape[:2]
            boxes.append([MPx, MPx+tcols, MPy, MPy+trows])
            design_boxes.append([xmin, xmax, ymin, ymax])

        self.check_boxes(design_boxes, idx)
        return boxes, design_boxes

    def check_boxes(self, bboxes, idx):
        path_idx = "{:04d}".format(idx)
        im = cv2.imread("../destijl_dataset/02_image/" + path_idx + ".png")
        for box in bboxes:
            # [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]]
            xmin = box[0]
            xmax = box[1]
            ymin = box[2]
            ymax = box[3]
            cv2.rectangle(im,(xmin, ymin),(xmax, ymax),(255,0,0),2)
        cv2.imwrite("check_boxes.jpg", im)

    def map_image_coordinates(self, text_coordinate, design_text_coordinate, design_img_coordinates, design_size, text_size):
        prev_x, prev_y = design_size
        text_x, text_y = text_size

        design_x, design_y = design_text_coordinate[0]
        text_x, text_y = text_coordinate[0]

        diff_x = text_x - design_x
        diff_y = text_y - design_y

        new_coordinates = []
        for coordinate in design_img_coordinates:
            for i in range(len(coordinate)):
                if i < 2:
                    coordinate[i] = int(coordinate[i] + diff_x)
                else:
                    coordinate[i] = int(coordinate[i] + diff_y)

                if coordinate[i] < 0:
                    coordinate[i] *= -1
            new_coordinates.append(coordinate)

        return new_coordinates

    def convert_to_min_max_coordinate(self, box):
        xmin, ymin = np.min(box, axis=0)
        xmax, ymax = np.max(box, axis=0)
        return [int(xmin), int(xmax), int(ymin), int(ymax)]

    def paste_onto_decoration_layer(self, idx):
        path_idx = "{:04d}".format(idx)
        preview_path = self.path_dict["preview"] + path_idx + ".png"
        img_path = self.path_dict["image"] + path_idx + ".png"
        decoration_path = self.path_dict["decoration"] + path_idx + ".png"
        text_path = self.path_dict["text"] + path_idx + ".png"
        white_bg_text_path = self.rgba_path_dict["text"] + path_idx + ".png"
        white_bg_img_path = self.rgba_path_dict["image"] + path_idx + ".png"

        img = cv2.imread(white_bg_img_path)
        decoration_img = cv2.imread(decoration_path)
        prev = cv2.imread(preview_path)

        design_size = (prev.shape[0], prev.shape[1])
        text_size = (decoration_img.shape[0], decoration_img.shape[1])

        image_boxes, design_image_boxes = self.locate_images_in_image_layer(idx)

        text_bboxes, white_bg_text_boxes, texts = self.processed_dataset.extract_text_bbox(text_path, preview_path)
        text_bboxes_from_design, composed_text_palettes = self.processed_dataset.extract_text_directly(preview_path, texts)

        if not text_bboxes_from_design or not white_bg_text_boxes:
            pass
        else:
            design_text_coordinate = text_bboxes_from_design[0]
            text_coordinate = white_bg_text_boxes[0]
            new_image_boxes = self.map_image_coordinates(text_coordinate, design_text_coordinate, design_image_boxes, design_size, text_size)

            white_bg = np.zeros( [decoration_img.shape[0], decoration_img.shape[1], 3] ,dtype=np.uint8)
            white_bg.fill(255)

            cv2.imwrite('bg.jpg', white_bg)

            white_bg = Image.open('bg.jpg')

            decoration_overlay = Image.open(self.rgba_path_dict["decoration"] + path_idx + ".png")
            text_overlay = Image.open(white_bg_text_path)
            white_bg.paste(decoration_overlay, mask=decoration_overlay)

            for j, box in enumerate(new_image_boxes):
                xmin1, xmax1, ymin1, ymax1 = box # box place on decoration
                xmin2, xmax2, ymin2, ymax2 = image_boxes[j] # box place on image
                cropped_img = img[ymin2:ymax2, xmin2:xmax2]

                cv2.imwrite(self.rgba_path_dict["temporary"] + path_idx + ".png", cropped_img)
                self.whitebg_to_transparent(self.rgba_path_dict["temporary"] + path_idx + ".png", "temporary")
                cropped_img = Image.open(self.rgba_path_dict["temporary"] + path_idx + ".png")

                offset = (xmin1, ymin1)
                white_bg.paste(cropped_img, offset, mask=cropped_img)

            white_bg.paste(text_overlay, mask=text_overlay)
            white_bg.save(self.rgba_path_dict["preview"] + path_idx + ".png")

    def pipeline(self):
        for idx in range(550, 706):
            print("Sample: ", idx)
            path_idx = "{:04d}".format(idx)

            img_path = self.path_dict["image"]+path_idx+".png"
            text_path = self.path_dict["text"]+path_idx+".png"
            decoration_path = self.path_dict["decoration"]+path_idx+".png"

            self.whitebg_to_transparent(img_path, "image")
            self.whitebg_to_transparent(text_path, "text")
            self.whitebg_to_transparent(decoration_path, "decoration")

            self.paste_onto_decoration_layer(idx)

    def resize_images(self):
        for idx in range(84, 705):
            path_idx = "{:04d}".format(idx)

            if os.path.exists(self.rgba_path_dict["preview"]+path_idx+".png"):
if os.path.exists(self.rgba_path_dict["preview"]+path_idx+".png"):
|
216 |
+
print(idx)
|
217 |
+
|
218 |
+
method = cv2.TM_SQDIFF_NORMED
|
219 |
+
big_image = cv2.imread(self.rgba_path_dict["preview"]+path_idx+".png")
|
220 |
+
small_image = cv2.imread(self.path_dict["preview"]+path_idx+".png")
|
221 |
+
if(small_image.shape[0] > big_image.shape[0]):
|
222 |
+
diff_x = abs(big_image.shape[0] - small_image.shape[0])
|
223 |
+
big_image = cv2.copyMakeBorder(big_image, diff_x//2+5, diff_x//2+5, 0, 0, cv2.BORDER_CONSTANT, value=[255, 255, 255])
|
224 |
+
|
225 |
+
if(small_image.shape[1] > big_image.shape[1]):
|
226 |
+
diff_y = abs(big_image.shape[1] - small_image.shape[1])
|
227 |
+
big_image = cv2.copyMakeBorder(big_image, 0, 0, diff_y//2+5, diff_y//2+5, cv2.BORDER_CONSTANT, value=[255, 255, 255])
|
228 |
+
|
229 |
+
result = cv2.matchTemplate(small_image, big_image, method)
|
230 |
+
mn,_,mnLoc,_ = cv2.minMaxLoc(result)
|
231 |
+
MPx,MPy = mnLoc
|
232 |
+
trows,tcols = small_image.shape[:2]
|
233 |
+
box = [MPx, MPx+tcols, MPy, MPy+trows]
|
234 |
+
new_img = big_image[box[2]:box[3], box[0]:box[1]]
|
235 |
+
cv2.imwrite(self.rgba_path_dict["cropped_preview"]+path_idx+".png", new_img)
|
236 |
+
|
237 |
+
processor = DestijlProcessorCNN("../destijl_dataset")
|
238 |
+
processor.resize_images()
|
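For reference, the white-background-to-transparent step used throughout this file comes down to the following minimal standalone sketch of the same numpy alpha-masking idea; the file names are placeholders, not files from this repo.

import cv2
import numpy as np

def whitebg_to_transparent(in_path, out_path):
    """Turn pure-white pixels of a BGR image into fully transparent pixels."""
    image_bgr = cv2.imread(in_path)                      # H x W x 3, BGR
    h, w = image_bgr.shape[:2]
    # Append an opaque alpha channel.
    alpha = np.full((h, w, 1), 255, dtype=np.uint8)
    image_bgra = np.concatenate([image_bgr, alpha], axis=-1)
    # Mask of pixels that are exactly white, then zero their alpha.
    white = np.all(image_bgr == [255, 255, 255], axis=-1)
    image_bgra[white, -1] = 0
    cv2.imwrite(out_path, image_bgra)                    # PNG keeps the alpha channel

# whitebg_to_transparent("layer.png", "layer_rgba.png")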
color_palette/config.py
ADDED
@@ -0,0 +1,29 @@
+from dataclasses import dataclass
+
+@dataclass
+class DataConfig:
+
+    dataset = "shape_dataset_circle_image/"
+    model_name = "ColorAttentionCircleImageLayer_random_mask"
+    data_type = "processed_rgb_toy_dataset_circle_image_color_and_layer"
+
+    # dataset = "../shape_dataset_lightness_circle/"
+    # model_name = "ColorAttentionLightnessCircle_random_mask_class"
+    # data_type = "processed_rgb_toy_dataset_lightness_circle_color_and_layer"
+
+    # dataset = "../destijl_dataset/"
+    # model_name = "ColorGNN_random_mask_new"
+    # data_type = "processed_rgb_color_and_layer"
+
+    feature_size = 68
+    loss_function = "CrossEntropy"
+
+    device = "cpu"
+    lr = 0.005
+    batch_size = 1
+    weight_decay = 0
+    num_epoch = 300
+
+    node_to_mask = -1
+    is_classification = True
+    layers_to_consider = ["background", "text", "image"]
color_palette/config/conf.yaml
ADDED
@@ -0,0 +1,15 @@
+model_name: "ColorAttentionColor_Layer"
+data_type: "processed_rgb_toy_dataset_color_and_layer"
+feature_size: 4
+
+threshold_for_neighbours: 1
+
+loss_function: "MSE"
+
+device: "cuda:2"
+lr: 0.01
+batch_size: 16
+weight_decay: 0
+num_epoch: 300
+
+node_to_mask: -1
color_palette/config/confCNN.yaml
ADDED
@@ -0,0 +1,16 @@
+batch_size: 16
+color_space: CIELab
+device: cuda:1
+input_color_space: RGB
+input_size: 512
+is_classification: false
+loss_function: MSE
+out_features: 1
+lr: 0.05
+map_outputs: true
+model_name: ColorCNN_lightness_RGB_normalized
+num_epoch: 150
+step_size: 20
+weight_decay: 0
+normalize_rgb: True
+normalize_cielab: True
color_palette/config/grid_search_conf_generator.py
ADDED
@@ -0,0 +1,24 @@
+import yaml
+
+lr_list = [0.01, 0.005, 0.001]
+weight_decay = [0, 0.01, 0.001, 0.0001]
+
+count = 0
+for i, lr in enumerate(lr_list):
+    for j, wd in enumerate(weight_decay):
+        with open("conf"+str(count)+".yaml", "w") as file:
+            data = {
+                "model_name": "ColorGNNEmbedding_lr"+str(lr)+"_wd"+str(wd),
+                "data_type": "processed_rgb",
+                "feature_size": 1005,
+
+                "device": "cuda:2",
+                "lr": lr,
+                "batch_size": 1,
+                "weight_decay": wd,
+                "num_epoch": 150,
+
+                "dataset_root": "../destijl_dataset",
+            }
+            documents = yaml.dump(data, file)
+            count+=1
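As a quick sanity check, each generated file can be read back with PyYAML. The snippet below is only an illustration; "conf0.yaml" is just the first file the loop above writes.

import yaml

with open("conf0.yaml") as f:
    conf = yaml.safe_load(f)

print(conf["model_name"], conf["lr"], conf["weight_decay"])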
color_palette/cube_num_one_hot_LR/test_gt.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6e5d5394bf0dcce227956a9a66b46263a0102492f132f6c731ee020e703ec3c8
+size 48128
color_palette/cube_num_one_hot_LR/test_preds.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:afaa3745f1579f54e69178c3d46aba4a9c23625107f65d0fb175e91e0bc6c243
+size 48128
color_palette/cube_num_one_hot_LR/test_preds_graph.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:73bc6274b537420ae855cc9c5c0feeeeafebbbe52f0ad2db96be9304760e41d8
+size 4184
color_palette/cube_num_one_hot_LR/test_rgb_colors.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eb3b9f43eca49006a8d6a6c9b3943a545673ce5bf0b1e9c7dcf8ea3c8526fd4d
+size 14324
color_palette/cube_num_one_hot_LR_sequential/new_palettes.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aca40e932ed33182fc3cbf5248dc03c08fc4081781529d462cdede5c1b21b106
+size 12128
color_palette/cube_num_one_hot_LR_sequential/new_palettes_purple.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7540033b68848a3a910a55cb1896055f087171ebbddea84cd7024dca6b8f3a6e
+size 12128
color_palette/cube_num_one_hot_LR_sequential/original_palettes.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9342282a3717f08e759cce63bf37852392a5996b879a878aab12de1ad089081b
+size 12128
color_palette/cube_num_one_hot_LR_sequential/original_palettes_purple.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9342282a3717f08e759cce63bf37852392a5996b879a878aab12de1ad089081b
+size 12128
color_palette/cube_num_one_hot_LR_sequential/test_gt.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6e5d5394bf0dcce227956a9a66b46263a0102492f132f6c731ee020e703ec3c8
+size 48128
color_palette/cube_num_one_hot_LR_sequential/test_preds.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:afaa3745f1579f54e69178c3d46aba4a9c23625107f65d0fb175e91e0bc6c243
+size 48128
color_palette/dataset.py
ADDED
@@ -0,0 +1,215 @@
+import torch
+import torch_geometric
+from torch_geometric.data import Dataset, Data
+from sklearn.model_selection import train_test_split
+import os
+import random
+import numpy as np
+import yaml
+from color_palette.utils import *
+from skimage.color import rgb2lab, lab2rgb
+from color_palette.config import *
+import math
+
+# with open("config/conf.yaml", 'r') as f:
+#     config = yaml.load(f, Loader=yaml.FullLoader)
+
+config = DataConfig()
+
+data_type = config.data_type
+mask_type = -1
+dataset_root = config.dataset
+
+print(f"Torch version: {torch.__version__}")
+print(f"Cuda available: {torch.cuda.is_available()}")
+print(f"Torch geometric version: {torch_geometric.__version__}")
+
+class GraphDestijlDataset(Dataset):
+    def __init__(self, root, test=False, transform=None, pre_transform=None, cube_mapping=False, square_label=False):
+        """
+        root = Where the dataset should be stored. This folder is split
+        into raw_dir (downloaded dataset) and processed_dir (processed data).
+        """
+        self.test = test
+        self.sample_filenames = os.listdir(root + data_type +'/')
+        self.processed_data_dir = root + data_type + '/'
+        self.square_label = square_label
+        self.map_to_cube = cube_mapping
+        self.cube_size = 4
+
+        # If you want to use less data than the whole dataset, you can specify the range here.
+        # Then it loads only samples up to that sample.
+        self.sample_filenames = ["data_{:04d}.pt".format(idx) for idx in range(0, 1000)]
+
+        # Train test filenames.
+        self.train_filenames, self.test_filenames = train_test_split(self.sample_filenames,
+                                                                     test_size=0.2,
+                                                                     random_state=42)
+        super().__init__(root, transform, pre_transform)
+
+    @property
+    def raw_file_names(self):
+        return "empty"
+
+    @property
+    def processed_file_names(self):
+        """ If these files are found in raw_dir, processing is skipped"""
+
+        if self.test:
+            return self.test_filenames
+        else:
+            return self.train_filenames
+
+    def download(self):
+        pass
+
+    def process(self):
+        pass
+
+    def return_all_same_colors(self, input_data):
+        indices_list = [[],[],[],[],[]]
+        unique_colors = torch.from_numpy(np.unique(input_data.x[:, 4:], axis=0))
+        all_colors = input_data.x[:, 4:]
+        for idx, color in enumerate(unique_colors):
+            for node_num, element in enumerate(all_colors):
+                if torch.equal(color, element):
+                    indices_list[idx].append(node_num)
+
+        return indices_list, unique_colors
+
+    def test_train_mask(self, data):
+
+        '''
+        Input: graph data
+
+        Mask the color of one node. The ground truth color is the last 3 dimensions of the feature vector.
+        Data is saved as RGB.
+        If you want you can convert unnormalized RGB ground truth color to Lab.
+        (Conversion is done using COLORMATH)
+        Put mask on a random node's color information by setting that color to [0, 0, 0].
+
+        Return: new_data with masked RGB colors, color_to_hide in lab, node_to_mask scalar
+        '''
+
+        # Take number of nodes
+        n_nodes = len(data.x)
+        if mask_type == -1:
+            # Choose the color to mask randomly
+            node_to_mask = random.randint(0, n_nodes-1)
+            #node_to_mask = n_nodes-1
+        else:
+            # Mask the red each time
+            node_to_mask = n_nodes-2
+        # If you chose a folder with processed_rgb in its name, then all the colors are stored in (0, 255) RGB.
+        feature_vector = data.x
+        # This is our target.
+
+        also_normal_values = []
+
+        for color in enumerate(feature_vector[:, -3:].clone()):
+            also_normal_values.append(color[1])
+
+        color_to_hide = feature_vector[node_to_mask, -3:].clone()
+
+        if self.map_to_cube:
+            # print("RGB")
+            # print(color_to_hide*255)
+            color_to_hide = self.cube_mapping(color_to_hide*255)
+            # print("CONVERSION")
+            # print(color_to_hide)
+
+        # Conversion to cielab. I do not use it anymore.
+        #color_to_hide = torch.tensor(RGB2CIELab(color_to_hide.numpy().astype(np.int32)))
+        #color_to_hide = torch.tensor(rgb2lab(color_to_hide.numpy().astype(np.int32)))
+        # Set node to mask in feature vector to zero.
+        feature_vector[node_to_mask, -3:] = torch.Tensor([0, 0, 0]) #torch.Tensor([0.9, 0.1, 0.1])
+
+        # Separate square and circle labels.
+
+        if self.square_label:
+
+            add_label = torch.zeros([5,1])
+            add_label[4, 0] = 1
+            labels = feature_vector[:, :3]
+            labels[4, 2] = 0
+            labels = torch.cat((labels, add_label), dim=1)
+            feature_vector = torch.cat((labels, feature_vector[:, -3:]), dim=1)
+
+        # Assign the new feature vector to the graph.
+
+        new_data = data.clone()
+        new_data.x = feature_vector
+
+        # This code below is used if we want to apply a threshold while adding edges.
+        # It just removes the edges with a higher distance than the threshold.
+
+        # new_edge_weight = []
+        # new_edge_index = []
+        # for k, edge in enumerate(new_data.edge_weight):
+        #     if edge.item() < threshold:
+        #         new_edge_weight.append(edge.item())
+        #         new_edge_index.append([new_data.edge_index[0][k], new_data.edge_index[1][k]])
+
+        # new_data.edge_index = torch.Tensor(new_edge_index).T
+        # new_data.edge_weight = torch.Tensor(new_edge_weight)
+        #print("New calculations")
+        #print(new_data.edge_index, new_data.edge_weight)
+
+        if self.map_to_cube:
+            cube_num_list = []
+            cube_colors = torch.zeros((feature_vector.shape[0], int(math.pow(self.cube_size, 3))))
+            for j, color in enumerate(feature_vector[:, -3:]):
+                if j != node_to_mask:
+                    cube_num = self.cube_mapping(color*255)
+                    cube_colors[j][cube_num] = 1
+                    cube_num_list.append(cube_num)
+
+            new_data.x = torch.cat((feature_vector[:, :-3], cube_colors), dim=1)
+            new_data.y = cube_num_list
+
+        return new_data, color_to_hide, node_to_mask, also_normal_values
+
+    def len(self):
+        if self.test:
+            return len(self.test_filenames)
+        else:
+            return len(self.train_filenames)
+
+    def get(self, idx):
+        if self.test:
+            data = torch.load(self.processed_data_dir+self.test_filenames[idx])
+            new_data, target_color, node_to_mask, also_normal_values = self.test_train_mask(data)
+        else:
+            data = torch.load(self.processed_data_dir+self.train_filenames[idx])
+            new_data, target_color, node_to_mask, also_normal_values = self.test_train_mask(data)
+        return new_data, target_color, node_to_mask, also_normal_values
+
+    def cube_mapping(self, color):
+        intervals = np.arange(0, 256, 256//4)
+        cube_coordinates = []
+        for channel in color:
+            i = 0
+            for j, value in enumerate(intervals):
+                if value < channel:
+                    i = j
+            cube_coordinates.append(i)
+
+        cube_num = cube_coordinates[0]*1 + cube_coordinates[1]*self.cube_size + cube_coordinates[2]*self.cube_size*self.cube_size
+        return cube_num
+
+    def cube2rgb(self, cube_num):
+        """
+        Return the start of the ranges
+        """
+        cube_num = int(cube_num)
+        intervals = np.arange(0, 256, 256//4)
+        coor2 = cube_num // 16
+        coor1 = (cube_num - coor2*self.cube_size*self.cube_size) // 4
+        coor0 = cube_num - coor2*self.cube_size*self.cube_size - coor1*self.cube_size
+        return [intervals[coor0], intervals[coor1], intervals[coor2]]
+
+if __name__ == '__main__':
+    dataset_obj = GraphDestijlDataset(root=dataset_root)
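The cube_mapping / cube2rgb pair above quantizes each RGB channel into 4 bins of width 64, giving 4^3 = 64 cube indices. A self-contained sketch of the same round trip, written here only for illustration and kept independent of the Dataset class:

import numpy as np

CUBE_SIZE = 4
INTERVALS = np.arange(0, 256, 256 // CUBE_SIZE)   # [0, 64, 128, 192]

def rgb_to_cube(color):
    """Map an (R, G, B) triple in [0, 255] to one of 64 cube indices (same loop logic as cube_mapping)."""
    coords = []
    for channel in color:
        i = 0
        for j, value in enumerate(INTERVALS):
            if value < channel:
                i = j
        coords.append(i)
    return coords[0] + coords[1] * CUBE_SIZE + coords[2] * CUBE_SIZE ** 2

def cube_to_rgb(cube_num):
    """Return the lower corner of the cube, mirroring cube2rgb above."""
    c2 = cube_num // (CUBE_SIZE ** 2)
    c1 = (cube_num - c2 * CUBE_SIZE ** 2) // CUBE_SIZE
    c0 = cube_num - c2 * CUBE_SIZE ** 2 - c1 * CUBE_SIZE
    return [int(INTERVALS[c0]), int(INTERVALS[c1]), int(INTERVALS[c2])]

print(rgb_to_cube((200, 30, 90)))   # 3 + 0*4 + 1*16 = 19
print(cube_to_rgb(19))              # [192, 0, 64]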
color_palette/dataset_processing.py
ADDED
@@ -0,0 +1,505 @@
+import torch
+import torch.nn as nn
+from torch_geometric.data import Dataset
+
+import os
+from utils import *
+
+from paddleocr import PaddleOCR, draw_ocr
+
+from PIL import Image, ImageFont
+import numpy as np
+import cv2
+from sklearn.cluster import KMeans
+from sklearn.metrics import silhouette_score
+from matplotlib.colors import hsv_to_rgb, rgb_to_hsv
+
+from torchvision.models import resnet50, ResNet50_Weights
+from model.CNN import Autoencoder
+
+from model.graph import DesignGraph
+from config import *
+
+class ProcessedDeStijl(Dataset):
+    def __init__(self, data_path):
+        self.path = data_path
+        self.path_dict = {
+            'preview': data_path + '/00_preview/',
+            'background': data_path + '/01_background/',
+            'image': data_path + '/02_image/',
+            'decoration': data_path + '/03_decoration/',
+            'text': data_path + '/04_text/',
+            'theme': data_path + '/05_theme/'
+        }
+
+        self.data_path = data_path
+        self.dataset_size = len(next(os.walk(self.path_dict['preview']))[2])
+        self.ocr = PaddleOCR(use_angle_cls=True, lang='en', use_gpu=False)
+
+        self.criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 200, .1)
+        self.flags = cv2.KMEANS_RANDOM_CENTERS
+
+        #self.layers = ['background', 'text', "decoration"] # Take this from config file later
+        self.layers = ['background', 'text', "decoration"]
+
+        self.pretrained_model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
+        # self.pretrained_model = Autoencoder()
+        # self.pretrained_model.load_state_dict(torch.load("../CNN_models/CNNAutoencoder/weights/best.pth")["state_dict"])
+
+    def len(self):
+        return self.dataset_size
+
+    def get(self, idx):
+        '''
+        Return a graph object based on the information
+        '''
+        pass
+
+    ######### RUNTIME EXTRACTION ###########
+
+    ######### PROCESSING AND ANNOTATING THE DATASET ###########
+
+    def extract_text_bbox(self, img_path, preview_image_path):
+        '''
+        Input: path to the text image
+        Extract text using paddleOCR.
+        Crop text from bounding box.
+        Extract colors using Kmeans inside the bbox.
+        Return the dominant color and the position.
+
+        DONE: Try to combine very close lines as paragraph bbox.
+        If the distance between two bboxes is smaller than the bbox height and the color is the same,
+        we can group them as paragraphs.
+
+        Return: text color palettes, dominant colors for each text and position list (as bboxes).
+        '''
+        # Parameters for KMeans.
+        n_colors = 3
+
+        result = self.ocr.ocr(img_path, cls=True)[0]
+
+        image = Image.open(img_path).convert('RGB')
+        boxes = [line[0] for line in result]
+        texts = [line[1][0] for line in result]
+        image = cv2.imread(img_path)
+        preview_image = cv2.imread(preview_image_path)
+
+        palettes = []
+        dominants = []
+        new_bboxes = []
+
+        # Run KMeans for each text object
+        for bbox in boxes:
+            # Crop the text area
+            x, y = int(bbox[0][0]), int(bbox[0][1])
+            z, t = int(bbox[2][0]), int(bbox[2][1])
+            cropped_image = image[y:t, x:z]
+
+            # Do template matching to find the places at the actual image because not every image has the same size.
+            method = cv2.TM_SQDIFF_NORMED
+            result = cv2.matchTemplate(cropped_image, preview_image, method)
+            mn,_,mnLoc,_ = cv2.minMaxLoc(result)
+            MPx,MPy = mnLoc
+            trows,tcols = cropped_image.shape[:2]
+            # --> left top, right top, right bottom, left bottom
+            bbox = [[MPx,MPy], [MPx+tcols, MPy], [MPx+tcols, MPy+trows], [MPx, MPy+trows]]
+            new_bboxes.append(bbox)
+
+        return new_bboxes, boxes, texts
+
+    def compose_paragraphs(self, text_bboxes, text_palettes):
+
+        '''
+        Compose text data into paragraphs.
+        Return: Grouped indices of detected text elements.
+        '''
+
+        num_text_boxes = len(text_bboxes)
+        if num_text_boxes == 0:
+            return False
+        composed_text_idxs = [[0]]
+        for i in range(num_text_boxes-1):
+            palette1 = text_palettes[i]
+            palette2 = text_palettes[i+1]
+            if np.array_equal(palette1, palette2):
+                bbox1 = text_bboxes[i]
+                bbox2 = text_bboxes[i+1]
+                height1 = bbox1[0][1] - bbox1[3][1]
+                height2 = bbox2[0][1] - bbox2[3][1]
+                if abs(bbox1[0][1]-bbox2[0][1]) <= abs(height1)+30:
+                    if i != 0 and i not in composed_text_idxs[-1]:
+                        composed_text_idxs.append([i])
+                    composed_text_idxs[-1].append(i+1)
+                else:
+                    if i != 0 and i not in composed_text_idxs[-1]:
+                        composed_text_idxs.append([i])
+                    if i == num_text_boxes-2:
+                        composed_text_idxs.append([i+1])
+            else:
+                if i != 0 and i not in composed_text_idxs[-1]:
+                    composed_text_idxs.append([i])
+                if i == (num_text_boxes-2):
+                    composed_text_idxs.append([i+1])
+
+        return composed_text_idxs
+
+    def merge_bounding_boxes(self, composed_text_idxs, bboxes):
+        '''
+        openCV --> x: left-to-right, y: top-to-bottom
+        bbox coordinates --> [[256.0, 1105.0], [1027.0, 1105.0], [1027.0, 1142.0], [256.0, 1142.0]]
+        --> left top, right top, right bottom, left bottom
+
+        TODO: Also return color palettes for each merged box.
+        '''
+
+        biggest_borders = []
+        if len(bboxes) == 0:
+            return biggest_borders
+        for idxs in composed_text_idxs:
+            smallest_x = smallest_y = 10000
+            biggest_y = biggest_x = 0
+            if len(idxs) > 1:
+                for idx in idxs:
+                    bbox = bboxes[idx]
+                    bbox_smallest_x, bbox_smallest_y = np.min(bbox, axis=0)
+                    bbox_biggest_x, bbox_biggest_y = np.max(bbox, axis=0)
+
+                    if smallest_x > bbox_smallest_x:
+                        smallest_x = bbox_smallest_x
+                    if smallest_y > bbox_smallest_y:
+                        smallest_y = bbox_smallest_y
+                    if biggest_x < bbox_biggest_x:
+                        biggest_x = bbox_biggest_x
+                    if biggest_y < bbox_biggest_y:
+                        biggest_y = bbox_biggest_y
+
+                biggest_border = [[smallest_x, smallest_y], [biggest_x, smallest_y], [biggest_x, biggest_y], [smallest_x, biggest_y]]
+                biggest_borders.append(biggest_border)
+            else:
+                biggest_borders.append(bboxes[idxs[0]])
+        return biggest_borders
+
+    def mini_kmeans(self, biggest_border, n_colors, image):
+        # for text
+        x, y = int(biggest_border[0][0]), int(biggest_border[0][1])
+        z, t = int(biggest_border[2][0]), int(biggest_border[2][1])
+        cropped_image = image[y:t, x:z]
+        pixels = np.float32(cropped_image.reshape(-1, 3))
+        _, labels, palette = cv2.kmeans(pixels, n_colors, None, self.criteria, 10, self.flags)
+        palette = np.asarray(palette, dtype=np.int64)
+
+        _, counts = np.unique(labels, return_counts=True)
+        color = palette[np.argmin(counts)]
+        return color
+
+    def extract_text_directly(self, img_path, white_bg_texts):
+        n_colors = 2
+
+        result = self.ocr.ocr(img_path, cls=True)[0]
+
+        image = Image.open(img_path).convert('RGB')
+        boxes = [line[0] for line in result]
+        texts = [line[1][0].replace(" ", "").lower() for line in result]
+        white_bg_texts = [elem.replace(" ", "").lower() for elem in white_bg_texts]
+        image = cv2.imread(img_path)
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        same_idxs = []
+        new_boxes = []
+
+        composed_text_palettes = []
+        for j, elem in enumerate(white_bg_texts):
+            for i, text in enumerate(texts):
+                if similar(elem, text) > 0.85:
+                    new_boxes.append(boxes[i])
+                    biggest_border = boxes[i]
+                    composed_text_palettes.append(self.mini_kmeans(biggest_border, n_colors, image))
+                elif i+1 != len(texts):
+                    if similar(elem, text + texts[i+1]) > 0.85:
+                        # merge boxes
+                        bboxes = [boxes[i], boxes[i+1]]
+                        smallest_x = smallest_y = 10000
+                        biggest_y = biggest_x = 0
+                        for idx in [0, 1]:
+                            bbox = bboxes[idx]
+                            bbox_smallest_x, bbox_smallest_y = np.min(bbox, axis=0)
+                            bbox_biggest_x, bbox_biggest_y = np.max(bbox, axis=0)
+
+                            if smallest_x > bbox_smallest_x:
+                                smallest_x = bbox_smallest_x
+                            if smallest_y > bbox_smallest_y:
+                                smallest_y = bbox_smallest_y
+                            if biggest_x < bbox_biggest_x:
+                                biggest_x = bbox_biggest_x
+                            if biggest_y < bbox_biggest_y:
+                                biggest_y = bbox_biggest_y
+
+                        biggest_border = [[smallest_x, smallest_y], [biggest_x, smallest_y], [biggest_x, biggest_y], [smallest_x, biggest_y]]
+                        new_boxes.append(biggest_border)
+                        composed_text_palettes.append(self.mini_kmeans(biggest_border, n_colors, image))
+
+        return new_boxes, composed_text_palettes
+
+    def extract_decor_elements(self, decoration_path, preview_path):
+        # Determine the number of dominant colors
+        num_colors = 6
+
+        # Load the image
+        image = cv2.imread(decoration_path)
+
+        # Convert the image to the RGB color space
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        image2 = image.copy()
+
+        # Reshape the image to a 2D array of pixels
+        pixels = image.reshape(-1, 3)
+
+        # Apply K-means clustering with the determined number of colors
+        kmeans = KMeans(n_clusters=num_colors)
+        kmeans.fit(pixels)
+
+        # Get the RGB values of the dominant colors
+        colors = kmeans.cluster_centers_.astype(int)
+
+        # Convert the colors to the HSV color space
+        hsv_colors = []
+
+        for i, color in enumerate(colors):
+            x, y, z = color
+            if not (252 < x < 256 and 252 < y < 256 and 252 < z < 256):
+                x, y, z = rgb_to_hsv([x/255, y/255, z/255])
+                hsv_colors.append([x*180, y*255, z*255])
+        # Convert the image to the HSV color space
+        hsv_image = cv2.cvtColor(image2, cv2.COLOR_RGB2HSV)
+
+        # Create masks for each dominant color
+        masks = []
+        hsv_colors = np.asarray(hsv_colors, dtype=np.int32)
+
+        colors = []
+        for i in range(len(hsv_colors)):
+
+            h, s, v = hsv_colors[i, :]
+            lower_color = hsv_colors[i, :] - np.array([10, 50, 50])
+            upper_color = hsv_colors[i, :] + np.array([10, 255, 255])
+            mask = cv2.inRange(hsv_image, lower_color, upper_color)
+            colors.append([h,s,v])
+            masks.append(mask)
+
+        # Find contours in each mask
+        contours = []
+        for mask in masks:
+            contours_color, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
+            contours.append(contours_color)
+
+        # Draw bounding boxes around the shapes
+        image_with_boxes = image.copy()
+        bboxes = []
+        for i, contour_color in enumerate(contours):
+            for contour in contour_color:
+                x, y, w, h = cv2.boundingRect(contour)
+                # left top, right top, right bottom, left bottom
+                bboxes.append([[x,y], [x+w, y], [x+w, y+h], [x,y+h]])
+
+        new_bboxes = delete_too_small_bboxes(np.asarray(bboxes))
+        return colors, new_bboxes
+
+    def map_decoration_coordinates(self, design_text_coordinate, text_coordinate, decoration_coordinates, prev_size, text_size):
+        # --> [[256.0, 1105.0], [1027.0, 1105.0], [1027.0, 1142.0], [256.0, 1142.0]]
+        # --> left top, right top, right bottom, left bottom
+
+        prev_x, prev_y = prev_size
+        text_x, text_y = text_size
+
+        design_x, design_y = design_text_coordinate[0]
+        text_x, text_y = text_coordinate[0]
+
+        diff_x = text_x - design_x
+        diff_y = text_y - design_y
+
+        new_coordinates = []
+        for coordinate in decoration_coordinates:
+            new_coor = []
+            for elem in coordinate:
+                new_coor.append([elem[0]-diff_x, elem[1]-diff_y])
+            new_coordinates.append(new_coor)
+
+        return new_coordinates
+
+    def extract_image(self, preview_path, image_path):
+        '''
+        Use template matching to put a bounding box around the main image. Use it as the position.
+        Extract colors using KMeans.
+        Return: image color palettes and position list (as bboxes).
+        '''
+
+        preview_image = cv2.imread(preview_path)
+        preview_image = cv2.cvtColor(preview_image, cv2.COLOR_BGR2RGB)
+        cropped_image_path = trim_image(image_path, "02_image")
+        image = cv2.imread(cropped_image_path)
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+        if image.shape[0] > preview_image.shape[0]:
+            diff_x = image.shape[0] - preview_image.shape[0]
+            image = image[(diff_x//2+1):image.shape[0]-(diff_x//2+1), :]
+
+        if image.shape[1] > preview_image.shape[1]:
+            diff_y = image.shape[1] - preview_image.shape[1]
+            image = image[:, (diff_y//2+1):image.shape[1]-(diff_y//2+1)]
+
+        method = cv2.TM_SQDIFF_NORMED
+        result = cv2.matchTemplate(image, preview_image, method)
+        mn,_,mnLoc,_ = cv2.minMaxLoc(result)
+        MPx,MPy = mnLoc
+        trows,tcols = image.shape[:2]
+        bbox = [[MPx,MPy], [MPx+tcols,MPy+trows]]
+        cropped_image = preview_image[MPx:MPx+tcols, MPy:MPy+trows]
+
+        pixels = np.float32(image.reshape(-1, 3))
+
+        n_colors = 6
+
+        _, labels, palette = cv2.kmeans(pixels, n_colors, None, self.criteria, 10, self.flags)
+        palette = np.asarray(palette, dtype=np.int64)
+
+        return [bbox], palette
+
+    def annotate_dataset(self):
+
+        for idx in range(388, self.dataset_size):
+            print("CURRENTLY AT: ", idx)
+            path_idx = "{:04d}".format(idx)
+            preview = self.path_dict['preview'] + path_idx + '.png'
+            decoration = self.path_dict['decoration'] + path_idx + '.png'
+            image = self.path_dict['image'] + path_idx + '.png'
+            text = self.path_dict['text'] + path_idx + '.png'
+
+            text_bboxes, white_bg_text_boxes, texts = self.extract_text_bbox(text, preview)
+            text_bboxes_from_design, composed_text_palettes = self.extract_text_directly(preview, texts)
+            composed_text_idxs = self.compose_paragraphs(text_bboxes_from_design, composed_text_palettes)
+            merged_bboxes = []
+            if composed_text_idxs != False:
+                merged_bboxes = self.merge_bounding_boxes(composed_text_idxs, text_bboxes_from_design)
+            image_bboxes, image_palette = self.extract_image(preview, image)
+
+            #decoration_hsv_xpalettes, decoration_bboxes = self.extract_decor_elements(decoration, preview)
+            # image_prev = cv2.imread(preview)
+            # image_text = cv2.imread(text)
+            #mapped_decoration_bboxes = self.map_decoration_coordinates(text_bboxes_from_design[0], white_bg_text_boxes[0], decoration_bboxes, (image_prev.shape[0], image_prev.shape[1]), (image_text.shape[0], image_text.shape[1]))
+
+            #create_xml("../destijl_dataset/xmls/03_decoration", path_idx+".xml", mapped_decoration_bboxes)
+            if len(merged_bboxes) == 0:
+                create_xml(self.data_path+"/xmls/04_text", path_idx+".xml", [[[0,0],[0,0],[0,0],[0,0]]])
+            else:
+                create_xml(self.data_path+"/xmls/04_text", path_idx+".xml", merged_bboxes)
+            create_xml(self.data_path+"/xmls/02_image", path_idx+".xml", image_bboxes)
+
+    def process_dataset(self, idx):
+        '''
+        Process each node. Construct graph features and save the features as pt files.
+        This code should be used after we have an annotated dataset.
+        '''
+        path_idx = "{:04d}".format(idx)
+
+        img_path_dict = {
+            'preview': self.data_path + '/00_preview/' + path_idx + '.png',
+            'background': self.data_path + '/01_background/' + path_idx + '.png',
+            'image': self.data_path + '/02_image/' + path_idx + '.png',
+            'text': self.data_path + '/04_text/' + path_idx + '.png',
+            'decoration': self.data_path + '/03_decoration/' + path_idx + '.png',
+        }
+
+        annotation_path_dict = {
+            'preview': self.data_path + '/xmls' +'/00_preview/' + path_idx + '.xml',
+            'image': self.data_path + '/xmls' + '/02_image/' + path_idx + '.xml',
+            'text': self.data_path + '/xmls' + '/04_text/' + path_idx + '.xml',
+            'decoration': self.data_path + '/xmls' + '/03_decoration/' + path_idx + '.xml',
+        }
+
+        all_bboxes = {
+            'image':[],
+            'background':[],
+            "decoration":[],
+            'text':[]
+        }
+        all_images = {
+            'image':[],
+            'background':[],
+            "decoration":[],
+            'text':[]
+        }
+
+        # For each layer:
+        #   * save all bounding boxes to all_bboxes
+        #   * save all paths of images in which we extract the objects from --> to all_images
+        #   * We generally extract all images from the preview image so that path is preview image path
+        #   * We save this information to use in DesignGraph. It extracts colors from bounding boxes
+        #     using the image we want to extract them from.
+
+        for i, layer in enumerate(self.layers):
+            # Check what the layer is, save information accordingly to dictionaries
+            if layer == "background":
+                # load image as CV image
+                self.preview_img = cv2.imread(img_path_dict[layer])
+                # save the layer image path
+                img = img_path_dict[layer]
+                all_images[layer] = img
+                # since it is background, just add a trivial bbox. This is not used
+                all_bboxes[layer] = [[[0, 0], [self.preview_img.shape[0], 0], [self.preview_img.shape[0], self.preview_img.shape[1]], [0, self.preview_img.shape[1]]]]
+            else:
+                if layer == 'text':
+                    # get the preview path since we extract text directly from preview image
+                    img_path = img_path_dict['preview']
+                    # get annotations from xml
+                    filename, bboxes = VOC2bbox(annotation_path_dict[layer])
+                    # assign bounding boxes and image path to extract colors later
+                    all_bboxes[layer] = bboxes
+                    all_images[layer] = img_path
+                elif layer == 'image' or layer == "decoration":
+                    # same logic is applied as text
+                    img_path = img_path_dict['preview']
+                    self.img_img = cv2.imread(img_path)
+                    filename, bboxes = VOC2bbox(annotation_path_dict[layer])
+                    """
+                    The comment below is for checking whether the annotation boxes work for
+                    each bounding box. It saves the annotated image.
+                    """
+                    # for k, box in enumerate(bboxes):
+                    #     im = cv2.imread("../destijl_dataset/00_preview/"+path_idx+".png")
+                    #     # [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]]
+                    #     xmin = box[0][0]
+                    #     xmax = box[1][0]
+                    #     ymin = box[0][1]
+                    #     ymax = box[2][1]
+                    #     print(xmin, xmax, ymin, ymax)
+                    #     cv2.rectangle(im,(xmin, ymin),(xmax, ymax),(255,0,0),2)
+                    #     cv2.imwrite("check_bboxes"+str(k)+".jpg", im)
+                    all_bboxes[layer] = bboxes
+                    all_images[layer] = img_path
+
+        # Design graph constructs the graph object and saves it as a pt file.
+        design_graph = DesignGraph(self.pretrained_model, all_images, all_bboxes, self.layers, img_path_dict['preview'], idx)
+        return design_graph.get_all_colors_in_design()
+
+    def trial(self):
+        # Save the samples in the range (0, n)
+        for i in range(0, 3100):
+            print("Sample: ", i)
+            all_colors = self.process_dataset(i)
+            for nested_list in all_colors:
+                color = nested_list[0]
+                # Fix if it has an unnecessary extra dimension
+                if len(color.shape) == 2:
+                    color = color[0]
+                color = color.tolist()
+
+if __name__ == "__main__":
+    config = DataConfig()
+    dataset_root = config.dataset
+    dataset = ProcessedDeStijl(data_path=dataset_root)
+    dataset.trial()
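Several methods above rely on the same OpenCV template-matching pattern (TM_SQDIFF_NORMED plus minMaxLoc) to find where a cropped layer sits inside the preview. Below is a minimal standalone sketch of that pattern; the file names are placeholders, and the argument order follows the usual (scene, template) convention.

import cv2

def locate_patch(patch_path, scene_path):
    """Return (x, y, w, h) of the best match of the patch inside the scene."""
    patch = cv2.imread(patch_path)
    scene = cv2.imread(scene_path)
    # TM_SQDIFF_NORMED: lower is better, so take the location of the minimum.
    result = cv2.matchTemplate(scene, patch, cv2.TM_SQDIFF_NORMED)
    _, _, min_loc, _ = cv2.minMaxLoc(result)
    x, y = min_loc
    h, w = patch.shape[:2]
    return x, y, w, h

# x, y, w, h = locate_patch("cropped_text.png", "preview.png")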
color_palette/deneme.png
ADDED
color_palette/deneme.py
ADDED
@@ -0,0 +1,107 @@
+import numpy as np
+import matplotlib.pyplot as plt
+from dataset import *
+from torch_geometric.loader import DataLoader
+from config import DataConfig
+
+config = DataConfig()
+model_name = config.model_name
+
+test_dataset = GraphDestijlDataset(root=dataset_root, test=True, cube_mapping=True)
+test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
+num_of_plots = len(test_loader)
+
+def my_palplot(pal, size=1, ax=None):
+    """Plot the values in a color palette as a horizontal array.
+    Parameters
+    ----------
+    pal : sequence of matplotlib colors
+        colors, i.e. as returned by seaborn.color_palette()
+    size :
+        scaling factor for size of plot
+    ax :
+        an existing axes to use
+    """
+
+    import numpy as np
+    import matplotlib as mpl
+    import matplotlib.pyplot as plt
+    import matplotlib.ticker as ticker
+
+    n = len(pal)
+    if pal[0][0] > 1:
+        pal = np.array(pal) / 255
+    ax.imshow(np.arange(n).reshape(1, n),
+              cmap=mpl.colors.ListedColormap(list(pal)),
+              interpolation="nearest", aspect="auto")
+    ax.set_xticks(np.arange(n) - .5)
+    ax.set_yticks([-.5, .5])
+    # Ensure nice border between colors
+    ax.set_xticklabels(["" for _ in range(n)])
+    # The proper way to set no ticks
+    ax.yaxis.set_major_locator(ticker.NullLocator())
+
+targets = np.load("targets.npy", allow_pickle=True)
+preds = np.load("preds.npy", allow_pickle=True)
+top_k_preds = np.squeeze(np.load("top_k_preds.npy", allow_pickle=True))
+node_to_masks = np.load("node_to_mask.npy", allow_pickle=True)
+
+rows = num_of_plots//3 + 1
+cols = 3
+fig, ax_array = plt.subplots(rows, cols, figsize=(60, 60), dpi=80, squeeze=False)
+
+column_titles = [" Prediction | Target " for i in range(cols)]
+for ax, col in zip(ax_array[0], column_titles):
+    ax.set_title(col, fontdict={'fontsize': 30, 'fontweight': 'medium'})
+
+palettes = []
+for i in range(targets.shape[0]):
+    target = targets[i]
+    pred = preds[i]
+    top_k_pred = top_k_preds[i]
+
+    print(target.shape, pred.shape, top_k_pred.shape)
+
+    palette = np.concatenate((top_k_pred[1:], np.atleast_1d(pred), np.atleast_1d(target)))
+    palettes.append(palette)
+    ax = plt.subplot(rows, cols, i+1)
+
+    rgb_palette = []
+    for color in palette:
+        rgb_value = test_dataset.cube2rgb(color)
+        rgb_palette.append(rgb_value)
+
+    my_palplot(rgb_palette, ax=ax)
+
+    if i == num_of_plots - 1:
+        print("saving")
+        plt.savefig("../models/"+model_name+"/top_k.jpg")
+
+unique_values = np.unique(node_to_masks)
+palettes = np.array(palettes)
+each_node_pred = []
+each_node_target = []
+for value in unique_values:
+    indices = np.where(node_to_masks == value)[0]
+    each_node_pred.append(palettes[indices])
+    each_node_target.append(targets[indices])
+
+targets_repeated = np.repeat(np.expand_dims(targets, axis=1), top_k_preds.shape[1], axis=1)
+total = np.sum(palettes[:, :-1] == targets_repeated)
+print("Total accuracy: ", total/600)
+
+for j, value in enumerate(unique_values):
+
+    targets_repeated = np.repeat(np.expand_dims(each_node_target[j], axis=1), top_k_preds.shape[1], axis=1)
+    total = np.sum(each_node_pred[j][:, :-1] == targets_repeated)
+    print("Accuracy of node " + str(value) + ": ", total)
+
+plt.close()
+plt.hist(preds, bins=64)
+plt.savefig("hist.jpg")
+
+plt.close()
+plt.hist(targets, bins=64)
+plt.savefig("hist_gt.jpg")
color_palette/denemeler.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
color_palette/dist.png
ADDED
color_palette/evaluate.py
ADDED
@@ -0,0 +1,188 @@
+import torch
+import torch.nn as nn
+from torch_geometric.loader import DataLoader
+
+from dataset import GraphDestijlDataset
+import yaml
+import argparse
+
+import seaborn as sns
+import matplotlib.pyplot as plt
+from utils import *
+from config import *
+
+######################## Set Parameters ########################
+
+os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+
+# you can specify the config file you want to provide
+# parser = argparse.ArgumentParser()
+# parser.add_argument("--config_file", type=str, default="config/conf.yaml", help="Path to the config file.")
+# args = parser.parse_args()
+# config_file = args.config_file
+
+# with open(config_file, 'r') as f:
+#     config = yaml.load(f, Loader=yaml.FullLoader)
+
+config = DataConfig()
+
+data_type = config.data_type
+model_name = config.model_name
+device = config.device
+feature_size = config.feature_size
+loss_function = config.loss_function
+dataset_root = config.dataset
+our_node_to_mask = config.node_to_mask
+
+model_weight_path = "../models/" + model_name + "/weights/best.pth"
+
+######################## Model ########################
+
+# Prepare dataset
+# Set test=True for testing on the test set. Otherwise it tests on the train set.
+test_dataset = GraphDestijlDataset(root=dataset_root, test=True)
+test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
+num_of_plots = len(test_loader)
+
+# Model is selected according to the name you provide
+model = model_switch(model_name, feature_size).to(device)
+model.load_state_dict(torch.load(model_weight_path)["state_dict"])
+criterion = nn.MSELoss()
+######################## Helper Functions ########################
+
+# Skip black flag was used to eliminate examples that include black colors.
+# Always set to False for default experiments
+
+skip_black_flag=False
+def test(data, target_color, node_to_mask):
+
+    model.eval()
+    out = model(data.x, data.edge_index.long(), data.edge_attr)
+
+    if skip_black_flag:
+        for color in data.y:
+            a, b, c = color
+            if 0 <= a <= 30 and 0 <= b <= 30 and 0 <= c <= 30:
+                return None, None
+            else:
+                loss = colormath_CIE2000(out[node_to_mask, :][0], target_color[0])
+    else:
+        # Loss is only in MSE now.
+        loss = criterion(out[node_to_mask, :][0], target_color[0]/255)
+    return loss, out
+
+def my_palplot(pal, size=1, ax=None):
+    """Plot the values in a color palette as a horizontal array.
+    Parameters
+    ----------
+    pal : sequence of matplotlib colors
+        colors, i.e. as returned by seaborn.color_palette()
+    size :
+        scaling factor for size of plot
+    ax :
+        an existing axes to use
+    """
+
+    import numpy as np
+    import matplotlib as mpl
+    import matplotlib.pyplot as plt
+    import matplotlib.ticker as ticker
+
+    n = len(pal)
+    if ax is None:
+        f, ax = plt.subplots(1, 1, figsize=(n * size, size))
+    ax.imshow(np.arange(n).reshape(1, n),
+              cmap=mpl.colors.ListedColormap(list(pal)),
+              interpolation="nearest", aspect="auto")
+    ax.set_xticks(np.arange(n) - .5)
+    ax.set_yticks([-.5, .5])
+    # Ensure nice border between colors
+    ax.set_xticklabels(["" for _ in range(n)])
+    # The proper way to set no ticks
+    ax.yaxis.set_major_locator(ticker.NullLocator())
+
+# Config for plot
+rows = num_of_plots//3 + 1
+cols = 3
+fig, ax_array = plt.subplots(rows, cols, figsize=(60, 60), dpi=80, squeeze=False)
+
+column_titles = [" Prediction | Target " for i in range(cols)]
+for ax, col in zip(ax_array[0], column_titles):
+    ax.set_title(col, fontdict={'fontsize': 30, 'fontweight': 'medium'})
+
+fig.suptitle(model_name+" Test Palettes", fontsize=100)
+
+# Code for evaluation loop
+plot_count = 0
+val_losses = []
+palettes = []
+preds = []
+targets = []
+count = 0
+for i, (input_data, target_color, node_to_mask) in enumerate(test_loader):
+    loss, out = test(input_data.to(device), target_color.to(device), node_to_mask)
+    if loss != None:
+        val_losses.append(loss.item())
+
+        # Get prediction and other colors in the palette
+        ax = plt.subplot(rows, cols, plot_count+1)
+
+        # Get prediction for a masked node
+        prediction = out[node_to_mask, :]
+        print("which node: ", node_to_mask, "count: ", count)
+        preds.append(prediction.detach().cpu()[0])
+        targets.append(target_color.detach().cpu()[0])
+        # Concat unmasked colors with prediction and ground truth
+        other_colors = input_data.y.clone()
+        other_colors = torch.cat([other_colors[0:node_to_mask, :], other_colors[node_to_mask+1:, :]])
+        #print(other_colors[0:node_to_mask, :].shape, other_colors[node_to_mask+1:, :].shape, node_to_mask)
+        other_colors = other_colors.type(torch.float32).detach().cpu().numpy()/255
+        # Normalize since they are in (0, 255) range.
+        # other_colors /= 255
+
+        if loss_function == "CIELab":
+            palette = np.clip(np.concatenate([other_colors, CIELab2RGB(prediction), CIELab2RGB(target_color[0])]), a_min=0, a_max=1)
+        else:
+            # Concat palettes. All of them are between (0, 1)
+            palette = np.clip(np.concatenate([other_colors, prediction.detach().cpu().numpy(), target_color.detach().cpu().numpy()]), a_min=0, a_max=1)
+        # I commented out the code related to calculating results in CIELab
+
+        # if "embedding" in model_name.lower():
+        #     other_colors = other_colors.type(torch.float32).detach().cpu().numpy()
+        #     other_colors /= 255
+        #     palette = np.clip(np.concatenate([other_colors, CIELab2RGB(prediction), CIELab2RGB(target_color[0])]), a_min=0, a_max=1)
+        # else:
+        #     current_palette = torch.cat([other_colors, prediction, target_color.to(device)]).type(torch.float32).detach().cpu().numpy()
+        #     palette = CIELab2RGB(current_palette)
+
+        # Save all the palettes to use them for distribution histograms.
+        palettes.append(prediction.detach().tolist()[0])
+        my_palplot(palette, ax=ax)
+    else:
+        print("none")
+
+    plot_count+=1
+    print(plot_count)
+    if i == num_of_plots-1:
+        path = "../models/"+model_name
+        if not os.path.exists(path):
+            os.mkdir(path)
+
+        if our_node_to_mask == -1:
+            print("hello")
+            plt.savefig(path+"/palettes.jpg")
+        else:
+            print("why red")
+            plt.savefig(path+"/palettes_only_blue.jpg")
+        plt.close()
+
+# This is for checking prediction distribution
+# It is saved as a histogram.
+#check_distributions(palettes)
+
+criterion = nn.MSELoss()
+stacked_pred = torch.stack(preds)
+stacked_target = torch.stack(targets)
+random_results = np.random.random(size=(stacked_pred.shape[0], stacked_pred.shape[1]))
+print(criterion(stacked_pred, stacked_target))
+print(criterion(torch.Tensor(random_results), stacked_target))
color_palette/evaluate_CNN.py
ADDED
@@ -0,0 +1,180 @@
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+
+import yaml
+import argparse
+from torchvision import transforms
+
+import seaborn as sns
+import matplotlib.pyplot as plt
+from utils import *
+from model.CNN import *
+from cnn_dataset import *
+
+import pandas as pd
+
+######################## Set Parameters ########################
+
+os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+
+device = "cuda:1"
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--model_name", type=str, default="ColorCNN", help="Give the name of the model.")
+args = parser.parse_args()
+model_name = args.model_name
+
+config_file = "../CNN_models/"+model_name+"/conf.yaml"
+
+with open(config_file, 'r') as f:
+    config = yaml.load(f, Loader=yaml.FullLoader)
+
+## Basic Training Parameters ##
+model_name = config["model_name"]
+device = config["device"]
+
+## Neural Network Parameters ##
+loss_function = config["loss_function"]
+out_features = config["out_features"]
+color_space = config["color_space"]
+input_color_space = config["input_color_space"]
+is_classification = config["is_classification"]
+input_size = config["input_size"]
+normalize_rgb = config["normalize_rgb"]
+normalize_cielab = config["normalize_cielab"]
+
+model_weight_path = "../CNN_models/" + model_name + "/weights/best.pth"
+
+if out_features == 1:
+    out_type = "Lightness"
+else:
+    out_type = "Color"
+
+print("Evaluating for the model: ", model_name, "\n",
+      "Loss function: ", loss_function, "\n",
+      "Output Color Space: ", color_space, "\n",
+      "Color or Lightness?: ", out_type, "\n",
+      "Device: ", device, "\n")
+######################## Model ########################
+
+transform = transforms.Compose([
+    transforms.Resize((input_size, input_size)),
+    #transforms.Normalize((255,), (255,))
+])
+
+test_dataset = PreviewDataset(transform=transform,
+                              color_space=color_space,
+                              input_color_space=input_color_space,
+                              normalize_rgb=normalize_rgb,
+                              normalize_cielab=normalize_cielab,
+                              test=True)
+
+test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
+num_of_plots = len(test_loader)
+
+model = model_switch_CNN(model_name, out_features).to(device)
+model.load_state_dict(torch.load(model_weight_path)["state_dict"])
+
+if loss_function == "MSE":
+    criterion = nn.MSELoss()
+elif loss_function == "Cross-Entropy":
+    criterion = nn.CrossEntropyLoss()
+elif loss_function == "MAE":
+    criterion = nn.L1Loss()
+
+######################## Helper Functions ########################
+def test(data, color):
+    model.eval()
+    out = model(data)
+
+    if out_features == 1:
+        if loss_function != "CIELab":
+            loss = criterion(out[0][0], color[0][0])
+        else:
+            loss = colormath_CIE2000(out[0][0], color[0][0])
+    else:
+        if loss_function != "CIELab":
+            loss = criterion(out, color)
+        else:
+            loss = colormath_CIE2000(out, color)
+    return loss, out
+
+
+def my_palplot(pal, size=1, ax=None):
+    """Plot the values in a color palette as a horizontal array.
+    Parameters
+    ----------
+    pal : sequence of matplotlib colors
+        colors, i.e. as returned by seaborn.color_palette()
+    size :
+        scaling factor for size of plot
+    ax :
+        an existing axes to use
+    """
+
+    import numpy as np
+    import matplotlib as mpl
+    import matplotlib.pyplot as plt
+    import matplotlib.ticker as ticker
+
+    n = len(pal)
+    if ax is None:
+        f, ax = plt.subplots(1, 1, figsize=(n * size, size))
+    ax.imshow(np.arange(n).reshape(1, n),
+              cmap=mpl.colors.ListedColormap(list(pal)),
+              interpolation="nearest", aspect="auto")
+    ax.set_xticks(np.arange(n) - .5)
+    ax.set_yticks([-.5, .5])
+    # Ensure nice border between colors
+    ax.set_xticklabels(["" for _ in range(n)])
+    # The proper way to set no ticks
+    ax.yaxis.set_major_locator(ticker.NullLocator())
+
+rows = num_of_plots//3 + 1
+cols = 3
+
+fig, ax_array = plt.subplots(rows, cols, figsize=(60, 60), dpi=80, squeeze=False)
+column_titles = ["Prediction Target" for i in range(cols)]
+for ax, col in zip(ax_array[0], column_titles):
+    ax.set_title(col, fontdict={'fontsize': 45, 'fontweight': 'medium'})
+
+fig.suptitle(model_name+" Test Palettes", fontsize=100)
+
+plot_count = 0
+val_losses = []
+
+outputs = []
+target_colors = []
+
+for i, (input_data, target_color) in enumerate(test_loader):
+    loss, out = test(input_data.to(device), target_color.to(device))
+    val_losses.append(loss.item())
+    # Get prediction and other colors in the palette
+    ax = plt.subplot(rows, cols, plot_count+1)
+
+    if color_space == "CIELab":
+        out = out.detach().cpu().numpy()
+        out = np.append(out, [[30.0, 30.0]], axis=1)
+        target_color = np.array([[target_color.detach().cpu().numpy()[0][0], 30.0, 30.0]])
+        palette = np.clip(np.concatenate([CIELab2RGB(out), CIELab2RGB(target_color)]), a_min=0, a_max=1)
+    else:
+        palette = np.clip(np.concatenate([out.detach().cpu().numpy(), target_color/255]), a_min=0, a_max=1)
+    outputs.append(out)
+    target_colors.append(target_color)
+
+    my_palplot(palette, ax=ax)
+
+    plot_count+=1
+
+    if i == num_of_plots-1:
+        path = "../CNN_models/"+model_name
+        if not os.path.exists(path):
+            os.mkdir(path)
+        plt.savefig(path+"/palettes.jpg")
+        plt.close()
+
+cielab_dict = {'Output': outputs, 'Targets': target_colors}
+df = pd.DataFrame(data=cielab_dict)
+
+#df.to_csv("trainset_predictions.csv")
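Note (not part of the commit): a hypothetical stand-alone use of the my_palplot helper defined in evaluate_CNN.py above, drawing a prediction/target strip; the color values are made up and assumed to already lie in the (0, 1) range.

    import numpy as np
    import matplotlib.pyplot as plt

    palette = np.array([[0.9, 0.2, 0.1],   # stand-in predicted color
                        [0.8, 0.1, 0.1]])  # stand-in target color
    fig, ax = plt.subplots(figsize=(4, 1))
    my_palplot(palette, ax=ax)             # assumes my_palplot from the file above is in scope
    plt.savefig("palette_preview.jpg")
    plt.close(fig)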
color_palette/evaluate_classification.py
ADDED
@@ -0,0 +1,217 @@
+import torch
+import torch.nn as nn
+from torch_geometric.loader import DataLoader
+
+from dataset import GraphDestijlDataset
+import yaml
+import argparse
+
+import seaborn as sns
+import matplotlib.pyplot as plt
+from utils import *
+from config import *
+
+from model.GNN import *
+
+######################## Set Parameters ########################
+
+os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+
+# you can specify the config file you want to provide
+# parser = argparse.ArgumentParser()
+# parser.add_argument("--config_file", type=str, default="config/conf.yaml", help="Path to the config file.")
+# args = parser.parse_args()
+# config_file = args.config_file
+
+# with open(config_file, 'r') as f:
+#     config = yaml.load(f, Loader=yaml.FullLoader)
+
+config = DataConfig()
+
+data_type = config.data_type
+model_name = config.model_name
+device = config.device
+feature_size = config.feature_size
+loss_function = config.loss_function
+dataset_root = config.dataset
+our_node_to_mask = config.node_to_mask
+
+model_weight_path = "../models/" + model_name + "/weights/best.pth"
+
+######################## Model ########################
+
+# Prepare dataset
+# Set test=True for testing on the test set. Otherwise it tests on the train set.
+test_dataset = GraphDestijlDataset(root=dataset_root, test=True, cube_mapping=True)
+test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
+num_of_plots = len(test_loader)
+
+# Model is selected according to the name you provide
+model = ColorAttentionClassification(feature_size).to(device)
+model.load_state_dict(torch.load(model_weight_path)["state_dict"])
+criterion = nn.CrossEntropyLoss()
+######################## Helper Functions ########################
+
+# Skip black flag was used to eliminate examples that include black colors.
+# Always set to False for default experiments
+
+skip_black_flag=False
+def test(data, target_color, node_to_mask):
+
+    model.eval()
+    out = model(data.x, data.edge_index.long(), data.edge_attr)
+    loss = criterion(out[node_to_mask, :], target_color)
+    return loss, out
+
+def my_palplot(pal, pred, tar, size=1, ax=None):
+    """Plot the values in a color palette as a horizontal array.
+    Parameters
+    ----------
+    pal : sequence of matplotlib colors
+        colors, i.e. as returned by seaborn.color_palette()
+    size :
+        scaling factor for size of plot
+    ax :
+        an existing axes to use
+    """
+
+    import numpy as np
+    import matplotlib as mpl
+    import matplotlib.pyplot as plt
+    import matplotlib.ticker as ticker
+
+    n = len(pal)
+    if pal[0][0] > 1:
+        pal = np.array(pal) / 255
+    if ax is None:
+        f, ax = plt.subplots(1, 1, figsize=(n * size, size))
+    ax.imshow(np.arange(n).reshape(1, n),
+              cmap=mpl.colors.ListedColormap(list(pal)),
+              interpolation="nearest", aspect="auto")
+    ax.set_xticks(np.arange(n) - .5)
+    ax.set_yticks([-.5, .5])
+    rgb_pred = test_dataset.cube2rgb(pred)
+    rgb_pred_str = " ".join(str(x) for x in rgb_pred)
+    rgb_tar = test_dataset.cube2rgb(tar)
+    rgb_tar_str = " ".join(str(x) for x in rgb_tar)
+    ax.set_ylabel("Pred: " + rgb_pred_str + " Tar: " + rgb_tar_str, rotation=0, labelpad=100)
+    # Ensure nice border between colors
+    ax.set_xticklabels(["" for _ in range(n)])
+    # The proper way to set no ticks
+    ax.yaxis.set_major_locator(ticker.NullLocator())
+
+# Config for plot
+rows = num_of_plots//3 + 1
+cols = 3
+fig, ax_array = plt.subplots(rows, cols, figsize=(60, 60), dpi=80, squeeze=False)
+
+
+column_titles = [" Prediction | Target " for i in range(cols)]
+for ax, col in zip(ax_array[0], column_titles):
+    ax.set_title(col, fontdict={'fontsize': 30, 'fontweight': 'medium'})
+
+fig.suptitle(model_name+" Test Palettes", fontsize=100)
+
+# Code for evaluation loop
+plot_count = 0
+val_losses = []
+palettes = []
+preds = []
+targets = []
+count = 0
+top_k_all_preds = []
+node_names = []
+
+for i, (input_data, target_color, node_to_mask) in enumerate(test_loader):
+    loss, out = test(input_data.to(device), target_color.to(device), node_to_mask)
+    if loss is not None:
+        val_losses.append(loss.item())
+
+        # Get prediction and other colors in the palette
+        ax = plt.subplot(rows, cols, plot_count+1)
+
+        # Get prediction for a masked node
+
+        prediction = out[node_to_mask, :]
+        #preds.append(prediction.detach().cpu()[0])
+        #targets.append(target_color.detach().cpu()[0])
+        # Concat unmasked colors with prediction and ground truth
+        other_colors = torch.tensor(input_data.y)
+        # FIX THIS
+        other_colors = torch.cat((other_colors[0][0:node_to_mask], other_colors[0][node_to_mask+1:]))
+        #print(other_colors[0:node_to_mask, :].shape, other_colors[node_to_mask+1:, :].shape, node_to_mask)
+        other_colors = other_colors.type(torch.float32).detach().cpu().numpy()
+        # Normalize since they are in (0, 255) range.
+        # other_colors /= 255
+
+        # Concat palettes. All of them are between (0, 1)
+        a, top_k_preds = F.softmax(torch.tensor(prediction)).topk(5)
+        top_k_all_preds.append(top_k_preds.detach().numpy())
+        node_names.append(node_to_mask.detach().numpy())
+        prediction = torch.argmax(F.softmax(torch.tensor(prediction)), dim=1).numpy()
+        print("Pred: ", prediction, " Target: ", target_color)
+        preds.append(prediction.item())
+        print("pred cube to color: ", test_dataset.cube2rgb(prediction.item()))
+        print("target cube to color: ", test_dataset.cube2rgb(target_color.item()))
+        targets.append(target_color.item())
+        # print(other_colors, prediction, target_color)
+        # print(other_colors[0])
+        # print(np.atleast_1d(prediction), np.atleast_1d(target_color.detach().cpu().numpy()))
+        palette = np.concatenate((other_colors, np.atleast_1d(prediction.item()), np.atleast_1d(target_color.item())))
+        pred_palette = np.concatenate((top_k_preds[0], np.atleast_1d(target_color.item())))
+        # I commented out code related to calculating results in CIELab
+        # if "embedding" in model_name.lower():
+        #     other_colors = other_colors.type(torch.float32).detach().cpu().numpy()
+        #     other_colors /= 255
+        #     palette = np.clip(np.concatenate([other_colors, CIELab2RGB(prediction), CIELab2RGB(target_color[0])]), a_min=0, a_max=1)
+        # else:
+        #     current_palette = torch.cat([other_colors, prediction, target_color.to(device)]).type(torch.float32).detach().cpu().numpy()
+        #     palette = CIELab2RGB(current_palette)
+
+        # Save all the palettes to use for distribution histograms.
+
+        all_colors = []
+        for e, num in enumerate(palette):
+            color = test_dataset.cube2rgb(num)
+            all_colors.append(color)
+
+        #palettes.append(prediction.detach().tolist()[0])
+        my_palplot(all_colors, prediction.item(), target_color.item(), ax=ax)
+    else:
+        print("none")
+
+    plot_count+=1
+    print(plot_count)
+    if i == num_of_plots-1:
+        path = "../models/"+model_name
+        if not os.path.exists(path):
+            os.mkdir(path)
+
+        if our_node_to_mask == -1:
+            print("hello")
+            plt.savefig(path+"/palettes.jpg")
+
+        else:
+            print("why red")
+            plt.savefig(path+"/palettes_only_red.jpg")
+
+N = len(test_dataset)
+print("Evaluation Accuracy: ", np.sum(np.array(preds) == np.array(targets))/N)
+random_results = (np.random.random(size=(N,))*64).astype(np.uint16)
+print("Evaluation Accuracy Random: ", np.sum(random_results == np.array(targets))/N)
+
+np.save("preds.npy", np.array(preds))
+np.save("targets.npy", np.array(targets))
+np.save("node_to_mask.npy", np.array(node_names))
+np.save("top_k_preds.npy", np.array(top_k_all_preds))
+
+# This is for checking the prediction distribution
+# It is saved as a histogram.
+#check_distributions(palettes)
+
+# criterion = nn.MSELoss()
+# stacked_pred = torch.stack(preds)
+# stacked_target = torch.stack(targets)
+# random_results = np.random.random(size=(stacked_pred.shape[0], stacked_pred.shape[1]))
+# print(criterion(stacked_pred, stacked_target))
+# print(criterion(torch.Tensor(random_results), stacked_target))
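Note (not part of the commit): a toy illustration of the accuracy comparison at the end of evaluate_classification.py, with fabricated predictions over the 64 color-cube classes in place of the model's argmax outputs.

    import numpy as np

    N = 20
    preds = np.random.randint(0, 64, size=N)     # stand-in for argmax predictions
    targets = np.random.randint(0, 64, size=N)   # stand-in for ground-truth cube indices
    random_results = (np.random.random(size=(N,)) * 64).astype(np.uint16)

    print("Evaluation Accuracy:        ", np.sum(preds == targets) / N)
    print("Evaluation Accuracy Random: ", np.sum(random_results == targets) / N)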
color_palette/evaluate_recommend.py
ADDED
@@ -0,0 +1,72 @@
+import numpy as np
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import matplotlib.ticker as ticker
+from regressor.config import config_to_use
+
+def my_palplot(pal, size=1, ax=None):
+
+    n = len(pal)
+    if ax is None:
+        f, ax = plt.subplots(1, 1, figsize=(n * size, size))
+    ax.imshow(np.arange(n).reshape(1, n),
+              cmap=mpl.colors.ListedColormap(list(pal)),
+              interpolation="nearest", aspect="auto")
+    ax.set_xticks(np.arange(n) - .5)
+    ax.set_yticks([-.5, .5])
+    # Ensure nice border between colors
+    ax.set_xticklabels(["" for _ in range(n)])
+    # The proper way to set no ticks
+    ax.yaxis.set_major_locator(ticker.NullLocator())
+
+# test_preds = np.load('all_one_hot_LR/test_preds_graph.npy')
+palettes = np.load(config_to_use.save_folder+'/new_palettes_purple.npy')
+original_palettes = np.load(config_to_use.save_folder+'/original_palettes_purple.npy')
+
+print("testing out stuff")
+
+# test_preds = np.expand_dims(test_preds, axis=1)
+
+# all_colors = np.concatenate((palettes, test_preds), axis=1)
+
+colors = np.clip(palettes, a_min=0, a_max=1)
+colors_org = np.clip(original_palettes, a_min=0, a_max=1)
+
+rows = 50
+cols = 2
+fig, ax_array = plt.subplots(rows, cols, figsize=(60, 60), dpi=80, squeeze=False)
+
+column_titles = ["Updated palettes" for i in range(cols)]
+for ax, col in zip(ax_array[0], column_titles):
+    ax.set_title(col, fontdict={'fontsize': 50, 'fontweight': 'medium'})
+
+fig.suptitle("Test Palettes", fontsize=100)
+plot_count = 0
+for i in range(len(colors)):
+    ax = plt.subplot(rows, cols, plot_count+1)
+    my_palplot(colors[i], ax=ax)
+    plot_count += 1
+    if plot_count == 100:
+        break
+
+plt.savefig(config_to_use.save_folder+'/recommended_purple.jpg')
+plt.clf()
+
+rows = 50
+cols = 2
+fig, ax_array = plt.subplots(rows, cols, figsize=(60, 60), dpi=80, squeeze=False)
+
+column_titles = ["Original Palettes" for i in range(cols)]
+for ax, col in zip(ax_array[0], column_titles):
+    ax.set_title(col, fontdict={'fontsize': 50, 'fontweight': 'medium'})
+
+fig.suptitle("Test Palettes", fontsize=100)
+plot_count = 0
+for i in range(len(colors)):
+    ax = plt.subplot(rows, cols, plot_count+1)
+    my_palplot(colors_org[i], ax=ax)
+    plot_count += 1
+    if plot_count == 100:
+        break
+
+plt.savefig(config_to_use.save_folder+'/original_purple.jpg')
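Note (not part of the commit): a sketch that writes placeholder .npy palette files under the names this script loads, so it can be dry-run. The palette shape (50 palettes of 5 RGB colors in (0, 1)) is an assumption for illustration, and the current directory stands in for config_to_use.save_folder.

    import numpy as np

    save_folder = "."                            # assumption: stand-in for config_to_use.save_folder
    new_palettes = np.random.random((50, 5, 3))  # assumed shape: 50 palettes x 5 colors x RGB
    original_palettes = np.random.random((50, 5, 3))
    np.save(save_folder + "/new_palettes_purple.npy", new_palettes)
    np.save(save_folder + "/original_palettes_purple.npy", original_palettes)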
color_palette/model/CNN.py
ADDED
@@ -0,0 +1,209 @@
+import torch
+import torch.nn as nn
+from torchvision.models import resnet50, ResNet50_Weights, resnet18, ResNet18_Weights
+
+class Autoencoder(torch.nn.Module):
+    def __init__(self, num_channels=3, c_hid=16, latent_dim=256):
+        super().__init__()
+
+        # Building a linear encoder with Linear
+        # layer followed by Relu activation function
+        # 784 ==> 9
+        self.encoder = torch.nn.Sequential(
+            nn.Conv2d(num_channels, c_hid, kernel_size=3, padding=1, stride=2), # 256x256 => 128x128
+            nn.ReLU(),
+            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(c_hid, 2 * c_hid, kernel_size=3, padding=1, stride=2), # 128x128 => 64x64
+            nn.ReLU(),
+            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1, stride=2), # 64x64 => 32x32
+            nn.ReLU(),
+            nn.Conv2d(2 * c_hid, c_hid, kernel_size=3, padding=1, stride=2), # 32x32 => 16x16
+        )
+
+        # Building a linear decoder with Linear
+        # layer followed by Relu activation function
+        # The Sigmoid activation function
+        # outputs the value between 0 and 1
+        # 9 ==> 784
+        self.linear = nn.Sequential(
+            nn.Flatten(), # Image grid to single feature vector
+            nn.Linear(16 * 16 * c_hid, num_channels*latent_dim),
+        )
+        self.decoder = nn.Sequential(
+            nn.ConvTranspose2d(
+                num_channels, 2 * c_hid, kernel_size=3, output_padding=1, padding=1, stride=2
+            ), # 4x4 => 8x8
+            nn.ReLU(),
+            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.ConvTranspose2d(2 * c_hid, c_hid, kernel_size=3, output_padding=1, padding=1, stride=2), # 8x8 => 16x16
+            nn.ReLU(),
+            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.ConvTranspose2d(
+                c_hid, c_hid, kernel_size=3, output_padding=1, padding=1, stride=2
+            ), # 16x16 => 32x32
+            nn.ReLU(),
+            nn.ConvTranspose2d(
+                c_hid, num_channels, kernel_size=3, output_padding=1, padding=1, stride=2
+            ), # 32x32 => 64x64
+            nn.Tanh(), # The input images are scaled between -1 and 1, hence the output has to be bounded as well
+        )
+
+    def forward(self, x):
+        encoded = self.encoder(x)
+        linear = self.linear(encoded)
+        x = linear.reshape(linear.shape[0], -1, 16, 16)
+        decoded = self.decoder(x)
+        return decoded
+
+class FinetuneResNet18_classify(nn.Module):
+    def __init__(self, freeze_resnet=True):
+        super().__init__()
+
+        "Classify the color"
+
+        self.pretrained_model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
+
+        for param in self.pretrained_model.parameters():
+            param.requires_grad = False
+
+        self.pretrained_model.fc = nn.Linear(in_features=512, out_features=1024)
+
+        self.color_head = nn.Sequential(
+            nn.Linear(in_features=1024, out_features=768),
+        )
+        self.softmax = nn.Softmax(dim=1)
+
+    def forward(self, x):
+        x = self.pretrained_model(x)
+        x = self.color_head(x)
+        r = self.softmax(x[:, :256])
+        g = self.softmax(x[:, 256:512])
+        b = self.softmax(x[:, 512:])
+        return r, g, b
+
+class ResNet18(nn.Module):
+    def __init__(self, freeze_resnet=True, map_outputs="RGB"):
+        super().__init__()
+
+        """
+        Just map to interval
+        """
+        self.pretrained_model = resnet18(weights=None)
+        self.map_outputs = map_outputs
+
+        for param in self.pretrained_model.parameters():
+            param.requires_grad = True
+
+        self.pretrained_model.fc = nn.Linear(in_features=512, out_features=256)
+
+        self.color_head = nn.Sequential(
+            nn.Linear(in_features=256, out_features=128),
+            nn.Linear(in_features=128, out_features=64),
+            nn.Linear(in_features=64, out_features=3),
+        )
+
+        if self.map_outputs == "CIELab":
+            self.l_activation = nn.Sigmoid()
+            self.a_activation = nn.Tanh()
+            self.b_activation = nn.Tanh()
+
+        elif self.map_outputs == "RGB":
+            self.activation = nn.Sigmoid()
+
+    def forward(self, x):
+        x = self.pretrained_model(x)
+        x = self.color_head(x)
+        if self.map_outputs == "CIELab":
+            x[:, 0] = self.l_activation(x[:, 0]) * 100
+            x[:, 1] = self.a_activation(x[:, 1]) * 127
+            x[:, 2] = self.b_activation(x[:, 2]) * 127
+
+        elif self.map_outputs == "RGB":
+            x = x
+        return x
+
+
+class ColorCNN(nn.Module):
+    def __init__(self, num_channels=3, c_hid=16, out_feature=3, reverse_normalize_output=True):
+        super().__init__()
+        self.encoder = torch.nn.Sequential(
+            nn.Conv2d(num_channels, c_hid, kernel_size=3, padding=1, stride=2), # 512x512 -> 256x256
+            nn.BatchNorm2d(num_features=c_hid),
+            nn.ReLU(),
+            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(c_hid, 2 * c_hid, kernel_size=3, padding=1, stride=2), # 256x256 -> 128x128
+            nn.BatchNorm2d(num_features=2*c_hid),
+            nn.ReLU(),
+            nn.Conv2d(2 * c_hid, 4 * c_hid, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(4 * c_hid, 2 * c_hid, kernel_size=3, padding=1, stride=2), # 128x128 => 64x64
+            nn.BatchNorm2d(num_features=2*c_hid),
+            nn.ReLU(),
+            nn.Conv2d(2 * c_hid, c_hid, kernel_size=3, padding=1, stride=2), # 64x64 => 32x32
+            nn.BatchNorm2d(num_features=c_hid),
+            nn.ReLU(),
+            nn.Conv2d(c_hid, 8, kernel_size=3, padding=1, stride=2), # 32x32 => 16x16
+            nn.BatchNorm2d(num_features=8),
+            nn.ReLU(),
+        )
+
+        self.color_head = nn.Sequential(
+            nn.Flatten(),
+            nn.Linear(in_features=16*16*8, out_features=out_feature)
+        )
+
+        self.reverse_normalize_output = reverse_normalize_output
+        #self.addition = Addition()
+        #self.multiplication = torch.Tensor([100, 255, 255]).to("cuda:1")
+
+    def forward(self, x):
+        x = self.encoder(x)
+        x = self.color_head(x)
+        return x
+
+class Addition(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        input += torch.Tensor([0, 127, 127]).to("cuda:1")
+        return input
+
+class ColorCNNBigger(nn.Module):
+    def __init__(self, num_channels=3, c_hid=16):
+        super().__init__()
+        self.encoder = torch.nn.Sequential(
+            nn.Conv2d(num_channels, c_hid, kernel_size=3, padding=1, stride=2), # 512x512 -> 256x256
+            nn.BatchNorm2d(num_features=c_hid),
+            nn.ReLU(),
+            nn.Conv2d(c_hid, 2 * c_hid, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(2 * c_hid, 4 * c_hid, kernel_size=3, padding=1, stride=2), # 256x256 -> 128x128
+            nn.BatchNorm2d(num_features=4*c_hid),
+            nn.ReLU(),
+            nn.Conv2d(4 * c_hid, 2 * c_hid, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(2 * c_hid, c_hid, kernel_size=3, padding=1, stride=2), # 128x128 => 64x64
+            nn.BatchNorm2d(num_features=c_hid),
+            nn.ReLU(),
+            nn.Conv2d(c_hid, c_hid//2, kernel_size=3, padding=1, stride=2), # 64x64 => 32x32
+            nn.BatchNorm2d(num_features=c_hid//2),
+        )
+
+        self.color_head = nn.Sequential(
+            nn.Flatten(),
+            nn.Linear(in_features=32*32*8, out_features=16*16*4),
+            nn.Linear(in_features=16*16*4, out_features=3)
+        )
+
+    def forward(self, x):
+        x = self.encoder(x)
+        x = self.color_head(x)
+        return x
+
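Note (not part of the commit): a quick shape check for the ColorCNN class above. With the default c_hid=16 the encoder reduces a 3x512x512 input to an 8x16x16 feature map, so the color head should return a 3-dimensional color per image.

    import torch

    model = ColorCNN(num_channels=3, c_hid=16, out_feature=3)  # class defined in model/CNN.py above
    model.eval()                                               # use running BatchNorm stats for the dummy pass
    dummy = torch.randn(1, 3, 512, 512)
    with torch.no_grad():
        out = model(dummy)
    print(out.shape)                                           # expected: torch.Size([1, 3])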