busraasan committed
Commit 7625832 · Parent(s): 690bdb7
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full set.
Files changed (50)
  1. .gitattributes +1 -0
  2. .gitignore +3 -0
  3. app.py +397 -0
  4. color_palette/._dataset_processing.py +0 -0
  5. color_palette/__pycache__/cnn_dataset.cpython-39.pyc +0 -0
  6. color_palette/__pycache__/config.cpython-38.pyc +0 -0
  7. color_palette/__pycache__/config.cpython-39.pyc +0 -0
  8. color_palette/__pycache__/dataset.cpython-38.pyc +0 -0
  9. color_palette/__pycache__/dataset.cpython-39.pyc +0 -0
  10. color_palette/__pycache__/dataset_processing.cpython-39.pyc +0 -0
  11. color_palette/__pycache__/train_CNN.cpython-39.pyc +0 -0
  12. color_palette/__pycache__/utils.cpython-38.pyc +0 -0
  13. color_palette/__pycache__/utils.cpython-39.pyc +0 -0
  14. color_palette/all_one_hot_LR/test_gt.npy +3 -0
  15. color_palette/all_one_hot_LR/test_preds.npy +3 -0
  16. color_palette/all_one_hot_LR/test_preds_graph.npy +3 -0
  17. color_palette/all_one_hot_LR/test_rgb_colors.npy +3 -0
  18. color_palette/all_one_hot_LR_sequential/new_palettes.npy +3 -0
  19. color_palette/all_one_hot_LR_sequential/original_palettes.npy +3 -0
  20. color_palette/all_one_hot_LR_sequential/test_gt.npy +3 -0
  21. color_palette/all_one_hot_LR_sequential/test_preds.npy +3 -0
  22. color_palette/app copy.py +326 -0
  23. color_palette/bash_scripts/training.sh +24 -0
  24. color_palette/cnn_dataset.py +104 -0
  25. color_palette/colorCNN.py +238 -0
  26. color_palette/config.py +29 -0
  27. color_palette/config/conf.yaml +15 -0
  28. color_palette/config/confCNN.yaml +16 -0
  29. color_palette/config/grid_search_conf_generator.py +24 -0
  30. color_palette/cube_num_one_hot_LR/test_gt.npy +3 -0
  31. color_palette/cube_num_one_hot_LR/test_preds.npy +3 -0
  32. color_palette/cube_num_one_hot_LR/test_preds_graph.npy +3 -0
  33. color_palette/cube_num_one_hot_LR/test_rgb_colors.npy +3 -0
  34. color_palette/cube_num_one_hot_LR_sequential/new_palettes.npy +3 -0
  35. color_palette/cube_num_one_hot_LR_sequential/new_palettes_purple.npy +3 -0
  36. color_palette/cube_num_one_hot_LR_sequential/original_palettes.npy +3 -0
  37. color_palette/cube_num_one_hot_LR_sequential/original_palettes_purple.npy +3 -0
  38. color_palette/cube_num_one_hot_LR_sequential/test_gt.npy +3 -0
  39. color_palette/cube_num_one_hot_LR_sequential/test_preds.npy +3 -0
  40. color_palette/dataset.py +215 -0
  41. color_palette/dataset_processing.py +505 -0
  42. color_palette/deneme.png +0 -0
  43. color_palette/deneme.py +107 -0
  44. color_palette/denemeler.ipynb +0 -0
  45. color_palette/dist.png +0 -0
  46. color_palette/evaluate.py +188 -0
  47. color_palette/evaluate_CNN.py +180 -0
  48. color_palette/evaluate_classification.py +217 -0
  49. color_palette/evaluate_recommend.py +72 -0
  50. color_palette/model/CNN.py +209 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+color_palette/regressor/colorLoversData.mat filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
+/destijl_dataset
+*.jpg
+*.log
app.py ADDED
@@ -0,0 +1,397 @@
+from color_palette.regressor.config import config_to_use
+import numpy as np
+from PIL import Image, ImageFont, ImageDraw
+from sklearn.linear_model import LinearRegression
+from color_palette.model.GNN import ColorAttentionClassification
+from color_palette.regressor.model import Color2CubeDataset
+from color_palette.regressor.config import *
+from torch.utils.data import Dataset, DataLoader
+from color_palette.dataset import GraphDestijlDataset
+from color_palette.config import DataConfig
+import random
+import os
+import torch.nn.functional as F
+import torch
+import gradio as gr
+
+config = DataConfig()
+model_name = config.model_name
+dataset_root = config.dataset
+feature_size = config.feature_size
+device = config.device
+image_folder = "img_folder"
+
+if not os.path.exists(image_folder):
+    os.mkdir(image_folder)
+
+def train_regressor(train_loader):
+    X = []
+    y = []
+    for i, (input_data, target) in enumerate(train_loader):
+        input_data = np.squeeze(input_data)
+        target = np.squeeze(target)
+        X.append(input_data)
+        y.append(target)
+
+    X = np.stack(X, axis=0)
+    y = np.squeeze(np.stack(y, axis=0))
+
+    print("Before regressor train!\n")
+
+    reg = LinearRegression().fit(X, y)
+
+    return reg
+
+model_weight_path = "models/" + model_name + "/weights/best.pth"
+
+# palettes = np.load(config_to_use.save_folder+'/new_palettes_purple.npy')
+# original_palettes = np.load(config_to_use.save_folder+'/original_palettes_purple.npy')
+
+graph_test_dataset = GraphDestijlDataset(root=dataset_root, test=True, cube_mapping=True)
+model = ColorAttentionClassification(feature_size).to(device)
+model.load_state_dict(torch.load(model_weight_path)["state_dict"])
+
+dataset = Color2CubeDataset(config=config_to_use)
+train_loader = DataLoader(dataset, batch_size=1, shuffle=False)
+regressor = train_regressor(train_loader=train_loader)
+
+palette_of_the_design = [[0, 0, 0] for i in range(5)]
+all_node_colors = None
+
+class Demo:
+    def __init__(self, graph_dataset):
+        self.dataset = graph_dataset
+
+        first_sample_idx = random.randint(0, len(self.dataset)-1)
+        self.input_data, self.target_color, node_to_mask, also_normal_values = self.dataset.get(first_sample_idx)
+        global all_node_colors
+        all_node_colors = also_normal_values
+        self.same_indices = None
+        self.generate_img_from_palette([color.detach().numpy()*255 for color in also_normal_values], is_first=True)
+
+    def demo_reset(self):
+        first_sample_idx = random.randint(0, len(self.dataset)-1)
+        self.input_data, self.target_color, node_to_mask, also_normal_values = self.dataset.get(first_sample_idx)
+        global all_node_colors
+        all_node_colors = also_normal_values
+        self.generate_img_from_palette([color.detach().numpy()*255 for color in also_normal_values], is_first=True)
+
+    def generate_img_from_palette(self, palette, canvas_size=512, is_first=False):
+        palette = np.array(palette).astype('int')
+        rgb_bg, rgb_text, rgb_text, rgb_circle, rgb_main_img, rgb_img1, rgb_img2, rgb_img3 = [tuple(color) for color in palette]
+        if is_first:
+            self.same_indices, unique_colors, _ = self.return_all_same_colors(palette=palette)
+        else:
+            _, unique_colors, _ = self.return_all_same_colors(palette=palette)
+
+        # assign the current palette using the global keyword
+        global palette_of_the_design
+        palette_of_the_design = unique_colors
+
+        # Set the background color and create an empty PIL Image to fill with shapes and text
+        image = Image.new("RGB", (canvas_size, canvas_size), color=rgb_bg)
+
+        title = "Lorem Ipsum Dolor"
+        undertitle = "Neque porro quisquam est qui dolorem ipsum quia dolor sit amet, \n consectetur, adipisci velit..."
+
+        draw = ImageDraw.Draw(image)
+        # Set settings for the fonts
+        font_title = ImageFont.truetype("Arial.ttf", 32)
+        title_width, title_height = draw.textsize(title, font=font_title)
+        title_x = (canvas_size - title_width) // 2
+        title_y = (canvas_size - title_height) // 2 - 100
+
+        font_undertitle = ImageFont.truetype("Arial.ttf", 15)
+        text_width, text_height = draw.textsize(undertitle, font=font_undertitle)
+        undertitle_x = (canvas_size - text_width) // 2
+        undertitle_y = (canvas_size - text_height) // 2 - 50
+        # Draw titles
+        draw.text((title_x, title_y), title, fill=rgb_text, font=font_title)
+        draw.text((undertitle_x, undertitle_y), undertitle, fill=rgb_text, font=font_undertitle)
+
+        # Draw the circle
+        rad = random.randint(30, 70)
+        x = random.randint(400, 512-(rad+10))
+        y = random.randint(10, title_y-(rad+10))
+
+        draw.ellipse((x, y, x+rad, y+rad), fill=rgb_circle)
+
+        # Draw the image placeholders
+        for j, color in enumerate([rgb_main_img, rgb_img1, rgb_img2, rgb_img3]):
+            x = 512-((j+1)*60)
+            y = 512-((j+1)*60)
+
+            rad = 80 if j == 0 else 40
+            draw.rectangle((x, y, x+rad, y+rad), fill=color)
+
+        image.save("deneme.png")
+
+    def run_model(self, input_data, target_color, node_to_mask, updated_color):
+
+        global all_node_colors
+        palette = np.array([color.detach().numpy()*255 for color in all_node_colors]).astype('int')
+        same_indices_list, unique_colors, first_indices = self.return_all_same_colors(palette)
+
+        unique_colors = unique_colors/255
+
+        selected_color = torch.Tensor(updated_color)/255
+        map_node_to_mask = -1
+        print("same indices list")
+        print(self.same_indices)
+        print("node to mask")
+        print(node_to_mask)
+        for i, idxs in enumerate(self.same_indices):
+            if node_to_mask in idxs:
+                map_node_to_mask = i
+                print("map node to mask: ", i)
+
+        for i, indices in enumerate(self.same_indices):
+
+            if i == 0:
+                # update the color [0.15, 0.4908, 0.73]
+                cube_num_of_selected = self.rgb2cube(selected_color*255)
+                one_hot = np.zeros((64,))
+                one_hot[int(cube_num_of_selected)] = 1.0
+                node_to_recommend = map_node_to_mask
+                input_data.x[same_indices_list[map_node_to_mask], 4:] = torch.Tensor(one_hot)
+                unique_colors[map_node_to_mask] = selected_color
+            else:
+                if i == map_node_to_mask:
+                    zeroth_bin = 0
+                    indices = same_indices_list[zeroth_bin]
+                    node_to_recommend = 0
+                    input_data.x[indices[0], 4:] = torch.zeros((input_data.x.shape[1]-4))
+                    node_to_mask = indices[0]
+                else:
+                    node_to_recommend = i
+                    input_data.x[indices[0], 4:] = torch.zeros((input_data.x.shape[1]-4))
+                    node_to_mask = indices[0]
+
+            out = self.forward_pass(model, input_data)  # input data has one-hot color features
+            if torch.is_tensor(node_to_mask):
+                node_to_mask = node_to_mask.item()
+            values, values_indices = torch.topk(F.softmax(out[node_to_mask, :], dim=0), k=3, dim=0)  # predict the color cube of the recommendation
+            prediction = values_indices.detach().numpy()[2]
+            # construct a palette using the unique RGB palette and the one-hot representation of the predicted cube.
+            feature_vector = self.create_rgb_and_one_hot_cube_vector(unique_colors, prediction, node_to_recommend)
+            # map the cube to RGB color space using the regressor
+            recommendation = regressor.predict(feature_vector)[0]
+            # we now have the first set of recommendations; update the colors and input_data to propagate the information:
+            # update the color in the palette and run the algorithm for the rest of the palette.
+            # for that, first map the color to a cube and convert it to one-hot.
+            input_data, unique_colors = self.update_palette(input_data, unique_colors, recommendation, same_indices_list, node_to_recommend)
+            # recursively do this here.
+            # save the results.
+        return np.array(unique_colors*255).astype(int)
+
+    def rgb2cube(self, color):
+        intervals = np.arange(0, 256, 256//4)
+        cube_coordinates = []
+        for channel in color:
+            i = 0
+            for j, value in enumerate(intervals):
+                if value < channel:
+                    i = j
+            cube_coordinates.append(i)
+
+        cube_num = cube_coordinates[0]*1 + cube_coordinates[1]*4 + cube_coordinates[2]*4*4
+        return cube_num
+
+    def cube2rgb(self, cube_num):
+        """
+        Return the start of the ranges.
+        """
+        cube_num = int(cube_num)
+        intervals = np.arange(0, 256, 256//4)
+        coor2 = cube_num // 16
+        coor1 = (cube_num - coor2*4*4) // 4
+        coor0 = cube_num - coor2*4*4 - coor1*4
+        return [intervals[coor0], intervals[coor1], intervals[coor2]]
+
+    def return_all_same_colors(self, palette):
+
+        indices_list = [[], [], [], [], []]
+        unique_colors, first_indices = np.unique(palette, axis=0, return_index=True)
+
+        unique_colors = np.array(unique_colors)
+        all_colors = np.array(palette)
+
+        for idx, color in enumerate(unique_colors):
+            for node_num, element in enumerate(all_colors):
+                if np.equal(color, element).all():
+                    indices_list[idx].append(node_num)
+
+        # these palettes and indices also include the masked color
+        return indices_list, unique_colors, first_indices
+
+    def update_palette(self, input_data, unique_rgb_palette, recommendation, indices_list, idx_to_idxs):
+        # convert the prediction to a one-hot vector
+        cube_num_of_the_changed_color = self.rgb2cube(recommendation*255)
+        one_hot = np.zeros((64,))
+        one_hot[int(cube_num_of_the_changed_color)] = 1.0
+
+        # update the feature vector accordingly for all the same colors
+        for idx in indices_list[idx_to_idxs]:
+            input_data.x[idx, 4:] = torch.Tensor(one_hot)
+
+        # update the unique color vector
+        unique_rgb_palette[idx_to_idxs] = recommendation
+        return input_data, unique_rgb_palette
+
+    def create_rgb_and_one_hot_cube_vector(self, rgb_palette, cube_num, node_to_mask):
+        one_hot = np.zeros((64,))
+        one_hot[int(cube_num)] = 1.0
+        removed_palette = np.delete(rgb_palette, node_to_mask, axis=0)
+        feature_vector = np.concatenate((removed_palette.flatten(), one_hot), axis=0)
+        return feature_vector.reshape(1, -1)
+
+    def create_all_one_hot_vector(self, rgb_palette, cube_num, node_to_mask):
+        one_hot = np.zeros((64,))
+        one_hot[int(cube_num)] = 1.0
+        removed_palette = np.delete(rgb_palette, node_to_mask, axis=0)
+        new_input_data = []
+        for color in removed_palette:
+            color_cube_num = self.rgb2cube(color*255)
+            empty_arr = np.zeros((64,))
+            empty_arr[int(color_cube_num)] = 1.0
+            new_input_data.append(empty_arr)
+
+        feature_vector = np.concatenate((np.array(new_input_data).flatten(), one_hot), axis=0)
+        return feature_vector.reshape(1, -1)
+
+    def forward_pass(self, model, data):
+        model.eval()
+        out = model(data.x, data.edge_index.long(), data.edge_weight)
+        return out
+
+    def rearrange_indices_list(self, indices_list, node_to_mask, unique_rgb_palette):
+        # move the node_to_mask indices to the beginning of the list
+        for i in range(len(indices_list)):
+            if node_to_mask in indices_list[i]:
+                index_to_pop = i
+
+        idxs = indices_list.pop(index_to_pop)
+        palette = unique_rgb_palette[index_to_pop]
+        temp_palette = np.delete(unique_rgb_palette, index_to_pop, axis=0)
+        unique_rgb_palette = np.concatenate(([palette], temp_palette), axis=0)
+        return [idxs] + indices_list, unique_rgb_palette
+
+    def update_color(self, updated_color, idx):
+        """
+        Takes a color and assigns it to the palette and the image.
+        """
+        idx = int(idx)
+        color = updated_color[1:-1].split(",")
+        color = [int(num) for num in color]
+
+        index_list = self.same_indices[idx]
+        which_one = random.randint(0, len(index_list)-1)
+        idx_to_update = index_list[which_one]
+
+        unique_colors = self.run_model(self.input_data, self.target_color, idx_to_update, color)
+
+        global palette_of_the_design
+        palette_of_the_design = unique_colors
+
+        global all_node_colors
+        if torch.is_tensor(all_node_colors):
+            all_node_colors = all_node_colors.detach().numpy()
+        for i, index_list in enumerate(self.same_indices):
+            for index in index_list:
+                all_node_colors[index] = unique_colors[i]
+
+        self.generate_img_from_palette(palette=[color for color in all_node_colors])
+        main_image = Image.open("deneme.png")
+        gradio_elements = []
+        gradio_elements.append(gr.Image(main_image, height=256, width=256))
+        for i in range(len(self.same_indices)):
+            color = unique_colors[i]
+            image = Image.new("RGB", (512, 512), color=tuple(color))
+            gradio_elements.append(gr.Image(image, height=64, width=64))
+            string_version = "["+str(color[0])+", "+str(color[1])+", "+str(color[2])+"]"
+            gradio_elements.append(gr.Textbox(value=string_version, min_width=64))
+
+        all_node_colors = torch.Tensor(all_node_colors) / 255
+        return tuple(gradio_elements)
+
+def perform_reset(button_input):
+    global demo
+    global all_node_colors
+
+    demo.demo_reset()
+    main_image = Image.open("deneme.png")
+    gradio_elements = []
+    gradio_elements.append(gr.Image(main_image, height=256, width=256))
+
+    for color in palette_of_the_design:
+        image = Image.new("RGB", (512, 512), color=tuple(color))
+        gradio_elements.append(gr.Image(image, height=64, width=64))
+        string_version = "["+str(color[0])+", "+str(color[1])+", "+str(color[2])+"]"
+        gradio_elements.append(gr.Textbox(value=string_version, min_width=64))
+
+    return tuple(gradio_elements)
+
+demo = Demo(graph_dataset=graph_test_dataset)
+
+# Form a gradio template to display the images and update the colors.
+
+with gr.Blocks() as project_demo:
+    with gr.Row():
+        image = Image.open("deneme.png")
+        design = gr.Image(image, height=256, width=256)
+
+    with gr.Row():
+        with gr.Column(min_width=100):
+            image1 = Image.new("RGB", (512, 512), color=tuple(palette_of_the_design[0]))
+            image1_gr = gr.Image(image1, height=64, width=64)
+            string1 = "["+str(palette_of_the_design[0][0])+", "+str(palette_of_the_design[0][1])+", "+str(palette_of_the_design[0][2])+"]"
+            color1_update = gr.Textbox(value=string1, min_width=64)
+            color1_button = gr.Button(value="Update Color 1", min_width=64)
+        with gr.Column(min_width=100):
+            image2 = Image.new("RGB", (512, 512), color=tuple(palette_of_the_design[1]))
+            image2_gr = gr.Image(image2, height=64, width=64)
+            string2 = "["+str(palette_of_the_design[1][0])+", "+str(palette_of_the_design[1][1])+", "+str(palette_of_the_design[1][2])+"]"
+            color2_update = gr.Textbox(value=string2, min_width=64)
+            color2_button = gr.Button(value="Update Color 2", min_width=64)
+        with gr.Column(min_width=100):
+            image3 = Image.new("RGB", (512, 512), color=tuple(palette_of_the_design[2]))
+            image3_gr = gr.Image(image3, height=64, width=64)
+            string3 = "["+str(palette_of_the_design[2][0])+", "+str(palette_of_the_design[2][1])+", "+str(palette_of_the_design[2][2])+"]"
+            color3_update = gr.Textbox(value=string3, min_width=64)
+            color3_button = gr.Button(value="Update Color 3", min_width=64)
+        with gr.Column(min_width=100):
+            image4 = Image.new("RGB", (512, 512), color=tuple(palette_of_the_design[3]))
+            image4_gr = gr.Image(image4, height=64, width=64)
+            string4 = "["+str(palette_of_the_design[3][0])+", "+str(palette_of_the_design[3][1])+", "+str(palette_of_the_design[3][2])+"]"
+            color4_update = gr.Textbox(value=string4, min_width=64)
+            color4_button = gr.Button(value="Update Color 4", min_width=64)
+        with gr.Column(min_width=100):
+            image5 = Image.new("RGB", (512, 512), color=tuple(palette_of_the_design[4]))
+            image5_gr = gr.Image(image5, height=64, width=64)
+            string5 = "["+str(palette_of_the_design[4][0])+", "+str(palette_of_the_design[4][1])+", "+str(palette_of_the_design[4][2])+"]"
+            color5_update = gr.Textbox(value=string5, min_width=64)
+            color5_button = gr.Button(value="Update Color 5", min_width=64)
+
+    with gr.Row():
+        reset_button = gr.Button(value="Reset the palette", min_width=64)
+
+    zero = gr.Number(value=0, visible=False)
+    one = gr.Number(value=1, visible=False)
+    two = gr.Number(value=2, visible=False)
+    three = gr.Number(value=3, visible=False)
+    four = gr.Number(value=4, visible=False)
+    color1_button.click(fn=demo.update_color, inputs=[color1_update, zero], outputs=[design, image1_gr, color1_update, image2_gr, color2_update, image3_gr, color3_update, image4_gr, color4_update, image5_gr, color5_update])
+    color2_button.click(fn=demo.update_color, inputs=[color2_update, one], outputs=[design, image1_gr, color1_update, image2_gr, color2_update, image3_gr, color3_update, image4_gr, color4_update, image5_gr, color5_update])
+    color3_button.click(fn=demo.update_color, inputs=[color3_update, two], outputs=[design, image1_gr, color1_update, image2_gr, color2_update, image3_gr, color3_update, image4_gr, color4_update, image5_gr, color5_update])
+    color4_button.click(fn=demo.update_color, inputs=[color4_update, three], outputs=[design, image1_gr, color1_update, image2_gr, color2_update, image3_gr, color3_update, image4_gr, color4_update, image5_gr, color5_update])
+    color5_button.click(fn=demo.update_color, inputs=[color5_update, four], outputs=[design, image1_gr, color1_update, image2_gr, color2_update, image3_gr, color3_update, image4_gr, color4_update, image5_gr, color5_update])
+    reset_button.click(fn=perform_reset, inputs=[reset_button], outputs=[design, image1_gr, color1_update, image2_gr, color2_update, image3_gr, color3_update, image4_gr, color4_update, image5_gr, color5_update])
+
+project_demo.launch()
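
The recommendation loop above leans on a 64-bin color quantization: rgb2cube buckets each RGB channel into 4 intervals of width 64 and encodes the three bucket indices in base 4, giving 4*4*4 = 64 cubes, while cube2rgb returns the lower corner of a cube. A minimal standalone sketch of that round trip, restating the two methods from the diff as free functions for illustration only:

import numpy as np

def rgb2cube(color):
    # Map an RGB triple in [0, 255] to one of 4*4*4 = 64 cube indices.
    intervals = np.arange(0, 256, 256 // 4)   # [0, 64, 128, 192]
    coords = []
    for channel in color:
        i = 0
        for j, value in enumerate(intervals):
            if value < channel:                # largest interval start strictly below the value
                i = j
        coords.append(i)
    return coords[0] + coords[1] * 4 + coords[2] * 16   # base-4 encoding, R least significant

def cube2rgb(cube_num):
    # Invert rgb2cube; returns the start of each channel's interval (the cube's lower corner).
    intervals = np.arange(0, 256, 256 // 4)
    coor2 = cube_num // 16
    coor1 = (cube_num - coor2 * 16) // 4
    coor0 = cube_num - coor2 * 16 - coor1 * 4
    return [intervals[coor0], intervals[coor1], intervals[coor2]]

print(rgb2cube((254, 254, 224)))   # 63: every channel lands in the top interval [192, 255]
print(cube2rgb(63))                # [192, 192, 192]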
color_palette/._dataset_processing.py ADDED
Binary file (4.1 kB)
color_palette/__pycache__/cnn_dataset.cpython-39.pyc ADDED
Binary file (2.76 kB)
color_palette/__pycache__/config.cpython-38.pyc ADDED
Binary file (813 Bytes)
color_palette/__pycache__/config.cpython-39.pyc ADDED
Binary file (734 Bytes)
color_palette/__pycache__/dataset.cpython-38.pyc ADDED
Binary file (5.84 kB)
color_palette/__pycache__/dataset.cpython-39.pyc ADDED
Binary file (5.24 kB)
color_palette/__pycache__/dataset_processing.cpython-39.pyc ADDED
Binary file (12.9 kB)
color_palette/__pycache__/train_CNN.cpython-39.pyc ADDED
Binary file (4.42 kB)
color_palette/__pycache__/utils.cpython-38.pyc ADDED
Binary file (9.73 kB)
color_palette/__pycache__/utils.cpython-39.pyc ADDED
Binary file (9.68 kB)
color_palette/all_one_hot_LR/test_gt.npy ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6e5d5394bf0dcce227956a9a66b46263a0102492f132f6c731ee020e703ec3c8
+size 48128
color_palette/all_one_hot_LR/test_preds.npy ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b00a13c1ee009c195b1ffd4de3c502b4d34a72d54634e99a4e73a8d861a01292
+size 48128
color_palette/all_one_hot_LR/test_preds_graph.npy ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f9b938717cae80d2fc604e38be8ffaec15bd95fd30c8294aec0753e79b51ffce
+size 2312
color_palette/all_one_hot_LR/test_rgb_colors.npy ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:56dff61afc05392ad11c82c3e3d4809a1fd9c413196b4384bfade653d075423e
+size 7772
color_palette/all_one_hot_LR_sequential/new_palettes.npy ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b6fcb136714cb725feccb5b6f1c66d219bdbd39ad4abcd31570665571edf5de3
+size 12128
color_palette/all_one_hot_LR_sequential/original_palettes.npy ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9342282a3717f08e759cce63bf37852392a5996b879a878aab12de1ad089081b
+size 12128
color_palette/all_one_hot_LR_sequential/test_gt.npy ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6e5d5394bf0dcce227956a9a66b46263a0102492f132f6c731ee020e703ec3c8
+size 48128
color_palette/all_one_hot_LR_sequential/test_preds.npy ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b00a13c1ee009c195b1ffd4de3c502b4d34a72d54634e99a4e73a8d861a01292
+size 48128
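
The .npy entries above are Git LFS pointer files (version/oid/size), not the arrays themselves; the binary payloads are fetched by Git LFS on checkout. A sketch of inspecting them once the LFS objects are present locally, under the assumption (not recorded in the pointers) that test_gt.npy and test_preds.npy hold same-shaped arrays, which their identical 48128-byte sizes suggest:

import numpy as np

gt = np.load("color_palette/all_one_hot_LR/test_gt.npy")
preds = np.load("color_palette/all_one_hot_LR/test_preds.npy")

print(gt.shape, gt.dtype)   # shapes and dtypes are not stored in the LFS pointers
if gt.shape == preds.shape:
    mse = float(np.mean((gt - preds) ** 2))
    print("MSE between ground truth and predictions:", mse)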
color_palette/app copy.py ADDED
@@ -0,0 +1,326 @@
+from regressor.config import config_to_use
+import numpy as np
+from PIL import Image, ImageFont, ImageDraw
+from sklearn.linear_model import LinearRegression
+from dataset import GraphDestijlDataset
+from config import DataConfig
+import random
+import os
+import torch
+import torch.nn.functional as F
+import gradio as gr
+
+config = DataConfig()
+model_name = config.model_name
+dataset_root = config.dataset
+image_folder = "img_folder"
+
+if not os.path.exists(image_folder):
+    os.mkdir(image_folder)
+
+model_weight_path = "../models/" + model_name + "/weights/best.pth"
+
+palettes = np.load(config_to_use.save_folder+'/new_palettes_purple.npy')
+original_palettes = np.load(config_to_use.save_folder+'/original_palettes_purple.npy')
+
+graph_test_dataset = GraphDestijlDataset(root=dataset_root, test=True, cube_mapping=True)
+
+palette_of_the_design = [[0, 0, 0] for i in range(5)]
+
+class Demo:
+    def __init__(self, graph_dataset):
+        self.dataset = graph_dataset
+
+        first_sample_idx = random.randint(0, len(self.dataset))
+        new_data, target_color, node_to_mask, also_normal_values = self.dataset.get(first_sample_idx)
+        self.all_node_colors = also_normal_values
+        self.same_indices = None
+        self.generate_img_from_palette([color.detach().numpy()*255 for color in also_normal_values], is_first=True)
+
+    def generate_img_from_palette(self, palette, canvas_size=512, is_first=False):
+        palette = np.array(palette).astype('int')
+        rgb_bg, rgb_text, rgb_text, rgb_circle, rgb_main_img, rgb_img1, rgb_img2, rgb_img3 = [tuple(color) for color in palette]
+
+        if is_first:
+            self.same_indices, unique_colors, _ = self.return_all_same_colors(palette=palette)
+        else:
+            _, unique_colors, _ = self.return_all_same_colors(palette=palette)
+
+        # assign the current palette using the global keyword
+        global palette_of_the_design
+        palette_of_the_design = unique_colors
+
+        # Set the background color and create an empty PIL Image to fill with shapes and text
+        image = Image.new("RGB", (canvas_size, canvas_size), color=rgb_bg)
+
+        title = "Lorem Ipsum Dolor"
+        undertitle = "Neque porro quisquam est qui dolorem ipsum quia dolor sit amet, \n consectetur, adipisci velit..."
+
+        draw = ImageDraw.Draw(image)
+        # Set settings for the fonts
+        font_title = ImageFont.truetype("Arial.ttf", 32)
+        title_width, title_height = draw.textsize(title, font=font_title)
+        title_x = (canvas_size - title_width) // 2
+        title_y = (canvas_size - title_height) // 2 - 100
+
+        font_undertitle = ImageFont.truetype("Arial.ttf", 15)
+        text_width, text_height = draw.textsize(undertitle, font=font_undertitle)
+        undertitle_x = (canvas_size - text_width) // 2
+        undertitle_y = (canvas_size - text_height) // 2 - 50
+        # Draw titles
+        draw.text((title_x, title_y), title, fill=rgb_text, font=font_title)
+        draw.text((undertitle_x, undertitle_y), undertitle, fill=rgb_text, font=font_undertitle)
+
+        # Draw the circle
+        rad = random.randint(30, 70)
+        x = random.randint(400, 512-(rad+10))
+        y = random.randint(10, title_y-(rad+10))
+
+        draw.ellipse((x, y, x+rad, y+rad), fill=rgb_circle)
+
+        # Draw the image placeholders
+        for j, color in enumerate([rgb_main_img, rgb_img1, rgb_img2, rgb_img3]):
+            x = 512-((j+1)*60)
+            y = 512-((j+1)*60)
+
+            rad = 80 if j == 0 else 40
+            draw.rectangle((x, y, x+rad, y+rad), fill=color)
+
+        image.save("deneme.png")
+
+    def run_model(self):
+        # NOTE: stale draft superseded by app.py above; also_normal_values, node_to_mask,
+        # input_data, model and regressor are not defined in this scope.
+        also_normal_values = np.squeeze(np.stack(also_normal_values, axis=0))  # rgb
+        same_indices_list, unique_colors, first_indices = self.return_all_same_colors(also_normal_values)
+
+        # move node_to_mask to the first place in input_data and unique colors
+        # same_indices_list, unique_colors = self.rearrange_indices_list(same_indices_list, node_to_mask, unique_colors)
+        map_node_to_mask = -1
+        for i, idxs in enumerate(same_indices_list):
+            if node_to_mask in idxs:
+                map_node_to_mask = i
+
+        original_palettes.append(unique_colors.copy())
+        for i, indices in enumerate(same_indices_list):
+
+            if i == 0:
+                # update the color [0.15, 0.4908, 0.73]
+                selected_color = torch.Tensor([254/255, 254/255, 224/255])
+                cube_num_of_selected = self.rgb2cube(selected_color*255)
+                one_hot = np.zeros((64,))
+                one_hot[int(cube_num_of_selected)] = 1.0
+                node_to_recommend = map_node_to_mask
+                input_data.x[same_indices_list[map_node_to_mask], 4:] = torch.Tensor(one_hot)
+                unique_colors[map_node_to_mask] = selected_color
+            else:
+                if i == map_node_to_mask:
+                    zeroth_bin = 0
+                    indices = same_indices_list[zeroth_bin]
+                    node_to_recommend = 0
+                    input_data.x[indices[0], 4:] = torch.zeros((input_data.x.shape[1]-4))
+                    node_to_mask = indices[0]
+                else:
+                    node_to_recommend = i
+                    input_data.x[indices[0], 4:] = torch.zeros((input_data.x.shape[1]-4))
+                    node_to_mask = indices[0]
+
+            out = self.forward_pass(model, input_data)  # input data has one-hot color features
+            if torch.is_tensor(node_to_mask):
+                node_to_mask = node_to_mask.item()
+            values, values_indices = torch.topk(F.softmax(out[node_to_mask, :], dim=0), k=3, dim=0)  # predict the color cube of the recommendation
+            prediction = values_indices.detach().numpy()[2]
+            # construct a palette using the unique RGB palette and the one-hot representation of the predicted cube.
+            feature_vector = self.create_rgb_and_one_hot_cube_vector(unique_colors, prediction, node_to_recommend)
+            # map the cube to RGB color space using the regressor
+            recommendation = regressor.predict(feature_vector)[0]
+            # we now have the first set of recommendations; update the colors and input_data to propagate the information:
+            # update the color in the palette and run the algorithm for the rest of the palette.
+            # for that, first map the color to a cube and convert it to one-hot.
+            input_data, unique_colors = self.update_palette(input_data, unique_colors, recommendation, same_indices_list, node_to_recommend)
+
+    def rgb2cube(self, color):
+        intervals = np.arange(0, 256, 256//4)
+        cube_coordinates = []
+        for channel in color:
+            i = 0
+            for j, value in enumerate(intervals):
+                if value < channel:
+                    i = j
+            cube_coordinates.append(i)
+
+        cube_num = cube_coordinates[0]*1 + cube_coordinates[1]*4 + cube_coordinates[2]*4*4
+        return cube_num
+
+    def cube2rgb(self, cube_num):
+        """
+        Return the start of the ranges.
+        """
+        cube_num = int(cube_num)
+        intervals = np.arange(0, 256, 256//4)
+        coor2 = cube_num // 16
+        coor1 = (cube_num - coor2*4*4) // 4
+        coor0 = cube_num - coor2*4*4 - coor1*4
+        return [intervals[coor0], intervals[coor1], intervals[coor2]]
+
+    def return_all_same_colors(self, palette):
+
+        indices_list = [[], [], [], [], []]
+        unique_colors, first_indices = np.unique(palette, axis=0, return_index=True)
+
+        unique_colors = np.array(unique_colors)
+        all_colors = np.array(palette)
+
+        for idx, color in enumerate(unique_colors):
+            for node_num, element in enumerate(all_colors):
+                if np.equal(color, element).all():
+                    indices_list[idx].append(node_num)
+
+        # these palettes and indices also include the masked color
+        return indices_list, unique_colors, first_indices
+
+    def update_palette(self, input_data, unique_rgb_palette, recommendation, indices_list, idx_to_idxs):
+        # convert the prediction to a one-hot vector
+        cube_num_of_the_changed_color = self.rgb2cube(recommendation*255)
+        one_hot = np.zeros((64,))
+        one_hot[int(cube_num_of_the_changed_color)] = 1.0
+
+        # update the feature vector accordingly for all the same colors
+        for idx in indices_list[idx_to_idxs]:
+            input_data.x[idx, 4:] = torch.Tensor(one_hot)
+
+        # update the unique color vector
+        unique_rgb_palette[idx_to_idxs] = recommendation
+        return input_data, unique_rgb_palette
+
+    def create_rgb_and_one_hot_cube_vector(self, rgb_palette, cube_num, node_to_mask):
+        one_hot = np.zeros((64,))
+        one_hot[int(cube_num)] = 1.0
+        removed_palette = np.delete(rgb_palette, node_to_mask, axis=0)
+        feature_vector = np.concatenate((removed_palette.flatten(), one_hot), axis=0)
+        return feature_vector.reshape(1, -1)
+
+    def create_all_one_hot_vector(self, rgb_palette, cube_num, node_to_mask):
+        one_hot = np.zeros((64,))
+        one_hot[int(cube_num)] = 1.0
+        removed_palette = np.delete(rgb_palette, node_to_mask, axis=0)
+        new_input_data = []
+        for color in removed_palette:
+            color_cube_num = self.rgb2cube(color*255)
+            empty_arr = np.zeros((64,))
+            empty_arr[int(color_cube_num)] = 1.0
+            new_input_data.append(empty_arr)
+
+        feature_vector = np.concatenate((np.array(new_input_data).flatten(), one_hot), axis=0)
+        return feature_vector.reshape(1, -1)
+
+    def forward_pass(self, model, data):
+        model.eval()
+        out = model(data.x, data.edge_index.long(), data.edge_weight)
+        return out
+
+    def train_regressor(self, train_loader):
+        X = []
+        y = []
+        for i, (input_data, target) in enumerate(train_loader):
+            input_data = np.squeeze(input_data)
+            target = np.squeeze(target)
+            X.append(input_data)
+            y.append(target)
+
+        X = np.stack(X, axis=0)
+        y = np.squeeze(np.stack(y, axis=0))
+
+        print("Before regressor train!\n")
+
+        reg = LinearRegression().fit(X, y)
+
+        return reg
+
+    def rearrange_indices_list(self, indices_list, node_to_mask, unique_rgb_palette):
+        # move the node_to_mask indices to the beginning of the list
+        for i in range(len(indices_list)):
+            if node_to_mask in indices_list[i]:
+                index_to_pop = i
+
+        idxs = indices_list.pop(index_to_pop)
+        palette = unique_rgb_palette[index_to_pop]
+        temp_palette = np.delete(unique_rgb_palette, index_to_pop, axis=0)
+        unique_rgb_palette = np.concatenate(([palette], temp_palette), axis=0)
+        return [idxs] + indices_list, unique_rgb_palette
+
+    def update_color(self, updated_color, idx):
+        """
+        Takes a color and assigns it to the palette and the image.
+        """
+        idx = int(idx)
+        color = updated_color[1:-1].split(",")
+        color = [int(num) for num in color]
+
+        global palette_of_the_design
+        palette_of_the_design[idx] = color
+
+        idxs_to_change = self.same_indices[idx]
+        for index in idxs_to_change:
+            self.all_node_colors[index] = torch.Tensor(color)/255
+
+        self.generate_img_from_palette(palette=[color.detach().numpy()*255 for color in self.all_node_colors])
+
+        image = Image.new("RGB", (512, 512), color=tuple(color))
+        main_image = Image.open("deneme.png")
+        return gr.Image(main_image, height=256, width=256), gr.Image(image, height=64, width=64), gr.Textbox(value=str(color), min_width=64)
+
+if __name__ == "__main__":
+    demo = Demo(graph_dataset=graph_test_dataset)
+
+    # Form a gradio template to display the images and update the colors.
+
+    with gr.Blocks() as project_demo:
+        with gr.Row():
+            image = Image.open("deneme.png")
+            design = gr.Image(image, height=256, width=256)
+
+        with gr.Row():
+            with gr.Column(min_width=100):
+                image1 = Image.new("RGB", (512, 512), color=tuple(palette_of_the_design[0]))
+                image1_gr = gr.Image(image1, height=64, width=64)
+                color1_update = gr.Textbox(value=str(palette_of_the_design[0]).replace(" ", ","), min_width=64)
+                color1_button = gr.Button(value="Update Color 1", min_width=64)
+            with gr.Column(min_width=100):
+                image2 = Image.new("RGB", (512, 512), color=tuple(palette_of_the_design[1]))
+                image2_gr = gr.Image(image2, height=64, width=64)
+                color2_update = gr.Textbox(value=str(palette_of_the_design[1]).replace(" ", ","), min_width=64)
+                color2_button = gr.Button(value="Update Color 2", min_width=64)
+            with gr.Column(min_width=100):
+                image3 = Image.new("RGB", (512, 512), color=tuple(palette_of_the_design[2]))
+                image3_gr = gr.Image(image3, height=64, width=64)
+                color3_update = gr.Textbox(value=str(palette_of_the_design[2]).replace(" ", ","), min_width=64)
+                color3_button = gr.Button(value="Update Color 3", min_width=64)
+            with gr.Column(min_width=100):
+                image4 = Image.new("RGB", (512, 512), color=tuple(palette_of_the_design[3]))
+                image4_gr = gr.Image(image4, height=64, width=64)
+                color4_update = gr.Textbox(value=str(palette_of_the_design[3]).replace(" ", ","), min_width=64)
+                color4_button = gr.Button(value="Update Color 4", min_width=64)
+            with gr.Column(min_width=100):
+                image5 = Image.new("RGB", (512, 512), color=tuple(palette_of_the_design[4]))
+                image5_gr = gr.Image(image5, height=64, width=64)
+                color5_update = gr.Textbox(value=str(palette_of_the_design[4]).replace(" ", ","), min_width=64)
+                color5_button = gr.Button(value="Update Color 5", min_width=64)
+
+        zero = gr.Number(value=0, visible=False)
+        one = gr.Number(value=1, visible=False)
+        two = gr.Number(value=2, visible=False)
+        three = gr.Number(value=3, visible=False)
+        four = gr.Number(value=4, visible=False)
+        color1_button.click(fn=demo.update_color, inputs=[color1_update, zero], outputs=[design, image1_gr, color1_update])
+        color2_button.click(fn=demo.update_color, inputs=[color2_update, one], outputs=[design, image2_gr, color2_update])
+        color3_button.click(fn=demo.update_color, inputs=[color3_update, two], outputs=[design, image3_gr, color3_update])
+        color4_button.click(fn=demo.update_color, inputs=[color4_update, three], outputs=[design, image4_gr, color4_update])
+        color5_button.click(fn=demo.update_color, inputs=[color5_update, four], outputs=[design, image5_gr, color5_update])
+
+    project_demo.launch()
color_palette/bash_scripts/training.sh ADDED
@@ -0,0 +1,24 @@
+FROM_FILE=0
+TO_FILE=11
+
+file=$FROM_FILE
+device_idx=2
+
+free_mem=$(nvidia-smi --query-gpu=memory.free --format=csv -i $device_idx | grep -Eo [0-9]+)
+
+while [ $file -le $TO_FILE ]
+do
+    # if [ $free_mem -lt 13000 ]; then
+    #     while [ $free_mem -lt 13000 ]; do
+    #         sleep 10
+    #         free_mem=$(nvidia-smi --query-gpu=memory.free --format=csv -i $device_idx | grep -Eo [0-9]+)
+    #     done
+    # fi
+
+    echo "Running experiment for conf$file.yaml"
+    #nohup python train.py --config_file config/conf$file.yaml > "config/out_$file.txt" &
+    python evaluate.py --config_file config/conf$file.yaml
+    file=$(($file+1))
+    #sleep 10
+
+done
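
For reference, the same sweep can be expressed in Python; a minimal sketch under the assumption (as in the script above) that evaluate.py accepts a --config_file flag and that conf0.yaml through conf11.yaml exist under config/:

import subprocess

# Mirror of training.sh: run evaluate.py for config/conf0.yaml .. config/conf11.yaml in order.
for file_idx in range(12):
    config_path = f"config/conf{file_idx}.yaml"
    print(f"Running experiment for {config_path}")
    subprocess.run(["python", "evaluate.py", "--config_file", config_path], check=True)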
color_palette/cnn_dataset.py ADDED
@@ -0,0 +1,104 @@
+from torch.utils.data import Dataset, DataLoader
+from sklearn.model_selection import train_test_split
+import skimage.color as scicolor
+from utils import *
+
+class PreviewDataset(Dataset):
+    def __init__(self, root="../destijl_dataset/rgba_dataset/",
+                 transform=None, test=False, color_space="RGB",
+                 input_color_space="RGB",
+                 is_classification=False,
+                 normalize_cielab=True,
+                 normalize_rgb=True):
+
+        self.test = test
+        self.sample_filenames = os.listdir(root+"00_preview_cropped")
+        self.transform = transform
+        self.img_dir = root
+        self.color_space = color_space
+        self.is_classification = is_classification
+        self.input_color_space = input_color_space
+        self.normalize_cielab = normalize_cielab
+        self.normalize_rgb = normalize_rgb
+
+        self.train_filenames, self.test_filenames = train_test_split(self.sample_filenames,
+                                                                     test_size=0.2,
+                                                                     random_state=42)
+
+    def __len__(self):
+        if self.test:
+            return len(self.test_filenames)
+        else:
+            return len(self.train_filenames)
+
+    def __getitem__(self, idx):
+
+        # Index into the split chosen at construction time, matching __len__.
+        filename = self.test_filenames[idx] if self.test else self.train_filenames[idx]
+        img_path = os.path.join(self.img_dir, "00_preview_cropped/" + filename)
+
+        image = np.array(Image.open(img_path))
+        # Convert the image to Lab if the input space is CIELab.
+        # The image is always a numpy array here; convert to a tensor afterwards.
+        if self.input_color_space == "CIELab":
+            image = scicolor.rgb2lab(image)
+            image = torch.from_numpy(image)
+            # if self.normalize_cielab:
+            #     image = normalize_CIELab(image)
+        else:
+            image = torch.from_numpy(image)
+
+        # Apply k-means on the RGB image always.
+        bg_path = os.path.join("../destijl_dataset/01_background/" + filename)
+        # Most dominant color in RGB.
+        color = self.kmeans_for_bg(bg_path)[0]
+
+        # If the output is in CIELab space but the input is in RGB, convert the target to CIELab as well.
+        if self.color_space == "CIELab":
+            target_color = torch.squeeze(torch.tensor(RGB2CIELab(color.astype(np.int32))))
+            # if self.normalize_cielab:
+            #     target_color = normalize_CIELab(target_color)
+        # Input and output are both in RGB space, or both in CIELab space.
+        # If the input is in CIELab and the output is in RGB, this branch is also valid since the dataset is in RGB.
+        else:
+            target_color = torch.squeeze(torch.tensor(color))
+
+        if self.is_classification:
+            target_color = [torch.zeros(256), torch.zeros(256), torch.zeros(256)]
+            target_color[0][color[0]] = 1
+            target_color[1][color[1]] = 1
+            target_color[2][color[2]] = 1
+
+        if self.transform:
+            # Reshape the image if it is not in (C, H, W) form.
+            if image.shape[0] != 3:
+                image = image.reshape(-1, image.shape[0], image.shape[1]).type("torch.FloatTensor")
+            # Apply the transformation
+            image = self.transform(image)
+
+        if self.normalize_rgb:
+            image /= 255
+
+        if self.color_space == "RGB" and self.normalize_rgb:
+            target_color /= 255
+        elif self.color_space == "CIELab" and self.normalize_cielab:
+            # we will only use lightness
+            target_color /= 100
+
+        return image, target_color
+
+    def kmeans_for_bg(self, bg_path):
+        image = cv2.imread(bg_path)
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        n_colors = 1
+
+        # Apply KMeans to the text area
+        pixels = np.float32(image.reshape(-1, 3))
+        criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 200, .1)
+        flags = cv2.KMEANS_RANDOM_CENTERS
+
+        _, labels, palette = cv2.kmeans(pixels, n_colors, None, criteria, 10, flags)
+        palette = np.asarray(palette, dtype=np.int64)  # RGB
+
+        return palette
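
A minimal sketch of consuming PreviewDataset through a DataLoader, assuming it is run from inside color_palette/ and that the dataset directories referenced above exist locally; the Resize transform is a hypothetical choice, since the class accepts any callable:

from torch.utils.data import DataLoader
from torchvision import transforms
from cnn_dataset import PreviewDataset

dataset = PreviewDataset(
    root="../destijl_dataset/rgba_dataset/",
    transform=transforms.Resize((512, 512)),
    test=False,
    color_space="CIELab",        # targets converted to CIELab; lightness scaled by /100
    input_color_space="RGB",
)
loader = DataLoader(dataset, batch_size=16, shuffle=True)

images, target_colors = next(iter(loader))   # one training batch
print(images.shape, target_colors.shape)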
color_palette/colorCNN.py ADDED
@@ -0,0 +1,238 @@
+"""
+TODO:
+* Make the white backgrounds transparent.
+* Locate the images from the bounding boxes of the preview images on a white background.
+  Cut the images and paste them to their locations on the decoration layer.
+* Pasting order: the decoration layer starts with a white background; paste the image, then the text.
+* When the white-background images are done, feed them to the CNN.
+* The output will be the missing color, which is the background color.
+* Use CIELab distances to train.
+"""
+
+import cv2
+import numpy as np
+from utils import *
+from dataset_processing import ProcessedDeStijl
+from PIL import Image
+
+class DestijlProcessorCNN():
+
+    def __init__(self, data_path):
+        self.path_dict = {
+            'preview': data_path + '/00_preview/',
+            'image': data_path + '/02_image/',
+            'decoration': data_path + '/03_decoration/',
+            'text': data_path + '/04_text/',
+        }
+
+        self.rgba_path_dict = {
+            'preview': data_path + '/rgba_dataset/00_preview/',
+            'image': data_path + '/rgba_dataset/02_image/',
+            'decoration': data_path + '/rgba_dataset/03_decoration/',
+            'text': data_path + '/rgba_dataset/04_text/',
+            'temporary': data_path + '/rgba_dataset/05_temporary/',
+            'cropped_preview': data_path + '/rgba_dataset/00_preview_cropped/',
+        }
+
+        self.xml_path_dict = {
+            'preview': data_path + '/xmls/00_preview/',
+            'image': data_path + '/xmls/02_image/',
+            'decoration': data_path + '/xmls/03_decoration/',
+            'text': data_path + '/xmls/04_text/',
+        }
+
+        self.processed_dataset = ProcessedDeStijl("../destijl_dataset")
+
+    def whitebg_to_transparent(self, img_path, layer):
+        """
+        WORKING
+        """
+        image_bgr = cv2.imread(img_path)
+        image_num = img_path[-8:]
+        # get the image dimensions (height, width and channels)
+        h, w, c = image_bgr.shape
+        # append an alpha channel -- required for BGRA (Blue, Green, Red, Alpha)
+        image_bgra = np.concatenate([image_bgr, np.full((h, w, 1), 255, dtype=np.uint8)], axis=-1)
+        # create a mask where white pixels ([255, 255, 255]) are True
+        white = np.all(image_bgr == [255, 255, 255], axis=-1)
+        # set alpha to 0 for all the white pixels
+        image_bgra[white, -1] = 0
+        # save the image
+        cv2.imwrite(self.rgba_path_dict[layer]+image_num, image_bgra)
+
+    def locate_images_in_image_layer(self, idx):
+
+        method = cv2.TM_SQDIFF_NORMED
+        path_idx = "{:04d}".format(idx)
+        preview_img = cv2.imread(self.path_dict["preview"]+path_idx+".png")
+        preview_bboxes = VOC2bbox(self.xml_path_dict["image"]+path_idx+".xml")[1]
+        image_img = cv2.imread(self.path_dict["image"]+path_idx+".png")
+
+        boxes = []
+        design_boxes = []
+        for box in preview_bboxes:
+            xmin = box[0][0]
+            xmax = box[1][0]
+            ymin = box[0][1]
+            ymax = box[2][1]
+            cropped_img = preview_img[ymin:ymax, xmin:xmax]
+
+            if cropped_img.shape[0] > image_img.shape[0]:
+                diff_x = abs(cropped_img.shape[0] - image_img.shape[0])
+                image_img = cv2.copyMakeBorder(image_img, diff_x//2+5, diff_x//2+5, 0, 0, cv2.BORDER_CONSTANT, value=[255, 255, 255])
+
+            if cropped_img.shape[1] > image_img.shape[1]:
+                diff_y = abs(cropped_img.shape[1] - image_img.shape[1])
+                image_img = cv2.copyMakeBorder(image_img, 0, 0, diff_y//2+5, diff_y//2+5, cv2.BORDER_CONSTANT, value=[255, 255, 255])
+
+            result = cv2.matchTemplate(cropped_img, image_img, method)
+            mn, _, mnLoc, _ = cv2.minMaxLoc(result)
+            MPx, MPy = mnLoc
+            trows, tcols = cropped_img.shape[:2]
+            boxes.append([MPx, MPx+tcols, MPy, MPy+trows])
+            design_boxes.append([xmin, xmax, ymin, ymax])
+
+        self.check_boxes(design_boxes, idx)
+        return boxes, design_boxes
+
+    def check_boxes(self, bboxes, idx):
+        path_idx = "{:04d}".format(idx)
+        im = cv2.imread("../destijl_dataset/02_image/" + path_idx + ".png")
+        for box in bboxes:
+            # [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]]
+            xmin = box[0]
+            xmax = box[1]
+            ymin = box[2]
+            ymax = box[3]
+            cv2.rectangle(im, (xmin, ymin), (xmax, ymax), (255, 0, 0), 2)
+        cv2.imwrite("check_boxes.jpg", im)
+
+    def map_image_coordinates(self, text_coordinate, design_text_coordinate, design_img_coordinates, design_size, text_size):
+        prev_x, prev_y = design_size
+        text_x, text_y = text_size
+
+        design_x, design_y = design_text_coordinate[0]
+        text_x, text_y = text_coordinate[0]
+
+        diff_x = text_x - design_x
+        diff_y = text_y - design_y
+
+        new_coordinates = []
+        for coordinate in design_img_coordinates:
+            for i in range(len(coordinate)):
+                if i < 2:
+                    coordinate[i] = int(coordinate[i] + diff_x)
+                else:
+                    coordinate[i] = int(coordinate[i] + diff_y)
+
+                if coordinate[i] < 0:
+                    coordinate[i] *= -1
+            new_coordinates.append(coordinate)
+
+        return new_coordinates
+
+    def convert_to_min_max_coordinate(self, box):
+        xmin, ymin = np.min(box, axis=0)
+        xmax, ymax = np.max(box, axis=0)
+        return [int(xmin), int(xmax), int(ymin), int(ymax)]
+
+    def paste_onto_decoration_layer(self, idx):
+        path_idx = "{:04d}".format(idx)
+        preview_path = self.path_dict["preview"] + path_idx + ".png"
+        img_path = self.path_dict["image"] + path_idx + ".png"
+        decoration_path = self.path_dict["decoration"] + path_idx + ".png"
+        text_path = self.path_dict["text"] + path_idx + ".png"
+        white_bg_text_path = self.rgba_path_dict["text"] + path_idx + ".png"
+        white_bg_img_path = self.rgba_path_dict["image"] + path_idx + ".png"
+
+        img = cv2.imread(white_bg_img_path)
+        decoration_img = cv2.imread(decoration_path)
+        prev = cv2.imread(preview_path)
+
+        design_size = (prev.shape[0], prev.shape[1])
+        text_size = (decoration_img.shape[0], decoration_img.shape[1])
+
+        image_boxes, design_image_boxes = self.locate_images_in_image_layer(idx)
+
+        text_bboxes, white_bg_text_boxes, texts = self.processed_dataset.extract_text_bbox(text_path, preview_path)
+        text_bboxes_from_design, composed_text_palettes = self.processed_dataset.extract_text_directly(preview_path, texts)
+
+        if not text_bboxes_from_design or not white_bg_text_boxes:
+            pass
+        else:
+            design_text_coordinate = text_bboxes_from_design[0]
+            text_coordinate = white_bg_text_boxes[0]
+            new_image_boxes = self.map_image_coordinates(text_coordinate, design_text_coordinate, design_image_boxes, design_size, text_size)
+
+            white_bg = np.zeros([decoration_img.shape[0], decoration_img.shape[1], 3], dtype=np.uint8)
+            white_bg.fill(255)
+
+            cv2.imwrite('bg.jpg', white_bg)
+
+            white_bg = Image.open('bg.jpg')
+
+            decoration_overlay = Image.open(self.rgba_path_dict["decoration"] + path_idx + ".png")
+            text_overlay = Image.open(white_bg_text_path)
+            white_bg.paste(decoration_overlay, mask=decoration_overlay)
+
+            for j, box in enumerate(new_image_boxes):
+                xmin1, xmax1, ymin1, ymax1 = box  # box position on the decoration layer
+                xmin2, xmax2, ymin2, ymax2 = image_boxes[j]  # box position on the image layer
+                cropped_img = img[ymin2:ymax2, xmin2:xmax2]
+
+                cv2.imwrite(self.rgba_path_dict["temporary"] + path_idx + ".png", cropped_img)
+                self.whitebg_to_transparent(self.rgba_path_dict["temporary"] + path_idx + ".png", "temporary")
+                cropped_img = Image.open(self.rgba_path_dict["temporary"] + path_idx + ".png")
+
+                offset = (xmin1, ymin1)
+                white_bg.paste(cropped_img, offset, mask=cropped_img)
+
+            white_bg.paste(text_overlay, mask=text_overlay)
+            white_bg.save(self.rgba_path_dict["preview"] + path_idx + ".png")
+
+    def pipeline(self):
+        for idx in range(550, 706):
+            print("Sample: ", idx)
+            path_idx = "{:04d}".format(idx)
+
+            img_path = self.path_dict["image"]+path_idx+".png"
+            text_path = self.path_dict["text"]+path_idx+".png"
+            decoration_path = self.path_dict["decoration"]+path_idx+".png"
+
+            self.whitebg_to_transparent(img_path, "image")
+            self.whitebg_to_transparent(text_path, "text")
+            self.whitebg_to_transparent(decoration_path, "decoration")
+
+            self.paste_onto_decoration_layer(idx)
+
+    def resize_images(self):
+        for idx in range(84, 705):
+            path_idx = "{:04d}".format(idx)
+
+            if os.path.exists(self.rgba_path_dict["preview"]+path_idx+".png"):
+                print(idx)
+
+                method = cv2.TM_SQDIFF_NORMED
+                big_image = cv2.imread(self.rgba_path_dict["preview"]+path_idx+".png")
+                small_image = cv2.imread(self.path_dict["preview"]+path_idx+".png")
+                if small_image.shape[0] > big_image.shape[0]:
+                    diff_x = abs(big_image.shape[0] - small_image.shape[0])
+                    big_image = cv2.copyMakeBorder(big_image, diff_x//2+5, diff_x//2+5, 0, 0, cv2.BORDER_CONSTANT, value=[255, 255, 255])
+
+                if small_image.shape[1] > big_image.shape[1]:
+                    diff_y = abs(big_image.shape[1] - small_image.shape[1])
+                    big_image = cv2.copyMakeBorder(big_image, 0, 0, diff_y//2+5, diff_y//2+5, cv2.BORDER_CONSTANT, value=[255, 255, 255])
+
+                result = cv2.matchTemplate(small_image, big_image, method)
+                mn, _, mnLoc, _ = cv2.minMaxLoc(result)
+                MPx, MPy = mnLoc
+                trows, tcols = small_image.shape[:2]
+                box = [MPx, MPx+tcols, MPy, MPy+trows]
+                new_img = big_image[box[2]:box[3], box[0]:box[1]]
+                cv2.imwrite(self.rgba_path_dict["cropped_preview"]+path_idx+".png", new_img)
+
+processor = DestijlProcessorCNN("../destijl_dataset")
+processor.resize_images()
color_palette/config.py ADDED
@@ -0,0 +1,29 @@
+from dataclasses import dataclass
+
+@dataclass
+class DataConfig:
+
+    dataset = "shape_dataset_circle_image/"
+    model_name = "ColorAttentionCircleImageLayer_random_mask"
+    data_type = "processed_rgb_toy_dataset_circle_image_color_and_layer"
+
+    # dataset = "../shape_dataset_lightness_circle/"
+    # model_name = "ColorAttentionLightnessCircle_random_mask_class"
+    # data_type = "processed_rgb_toy_dataset_lightness_circle_color_and_layer"
+
+    # dataset = "../destijl_dataset/"
+    # model_name = "ColorGNN_random_mask_new"
+    # data_type = "processed_rgb_color_and_layer"
+
+    feature_size = 68
+    loss_function = "CrossEntropy"
+
+    device = "cpu"
+    lr = 0.005
+    batch_size = 1
+    weight_decay = 0
+    num_epoch = 300
+
+    node_to_mask = -1
+    is_classification = True
+    layers_to_consider = ["background", "text", "image"]
color_palette/config/conf.yaml ADDED
@@ -0,0 +1,15 @@
+model_name: "ColorAttentionColor_Layer"
+data_type: "processed_rgb_toy_dataset_color_and_layer"
+feature_size: 4
+
+threshold_for_neighbours: 1
+
+loss_function: "MSE"
+
+device: "cuda:2"
+lr: 0.01
+batch_size: 16
+weight_decay: 0
+num_epoch: 300
+
+node_to_mask: -1
color_palette/config/confCNN.yaml ADDED
@@ -0,0 +1,16 @@
+batch_size: 16
+color_space: CIELab
+device: cuda:1
+input_color_space: RGB
+input_size: 512
+is_classification: false
+loss_function: MSE
+out_features: 1
+lr: 0.05
+map_outputs: true
+model_name: ColorCNN_lightness_RGB_normalized
+num_epoch: 150
+step_size: 20
+weight_decay: 0
+normalize_rgb: true
+normalize_cielab: true
color_palette/config/grid_search_conf_generator.py ADDED
@@ -0,0 +1,24 @@
+ import yaml
+ 
+ lr_list = [0.01, 0.005, 0.001]
+ weight_decay = [0, 0.01, 0.001, 0.0001]
+ 
+ count = 0
+ for i, lr in enumerate(lr_list):
+     for j, wd in enumerate(weight_decay):
+         with open("conf"+str(count)+".yaml", "w") as file:
+             data = {
+                 "model_name": "ColorGNNEmbedding_lr"+str(lr)+"_wd"+str(wd),
+                 "data_type": "processed_rgb",
+                 "feature_size": 1005,
+ 
+                 "device": "cuda:2",
+                 "lr": lr,
+                 "batch_size": 1,
+                 "weight_decay": wd,
+                 "num_epoch": 150,
+ 
+                 "dataset_root": "../destijl_dataset",
+             }
+             documents = yaml.dump(data, file)
+         count += 1
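The generator writes twelve files, conf0.yaml through conf11.yaml, one per (lr, weight_decay) pair. A hypothetical runner that sweeps them could look like the sketch below; the script name `train.py` and the `--config_file` flag are assumptions borrowed from the commented-out argparse block in evaluate.py (the commit's actual launcher is bash_scripts/training.sh, whose contents are not shown here):

import subprocess

for count in range(12):  # 3 learning rates x 4 weight-decay values
    subprocess.run(["python", "train.py", "--config_file", f"conf{count}.yaml"])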
color_palette/cube_num_one_hot_LR/test_gt.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6e5d5394bf0dcce227956a9a66b46263a0102492f132f6c731ee020e703ec3c8
+ size 48128
color_palette/cube_num_one_hot_LR/test_preds.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:afaa3745f1579f54e69178c3d46aba4a9c23625107f65d0fb175e91e0bc6c243
+ size 48128
color_palette/cube_num_one_hot_LR/test_preds_graph.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:73bc6274b537420ae855cc9c5c0feeeeafebbbe52f0ad2db96be9304760e41d8
+ size 4184
color_palette/cube_num_one_hot_LR/test_rgb_colors.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:eb3b9f43eca49006a8d6a6c9b3943a545673ce5bf0b1e9c7dcf8ea3c8526fd4d
+ size 14324
color_palette/cube_num_one_hot_LR_sequential/new_palettes.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:aca40e932ed33182fc3cbf5248dc03c08fc4081781529d462cdede5c1b21b106
+ size 12128
color_palette/cube_num_one_hot_LR_sequential/new_palettes_purple.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7540033b68848a3a910a55cb1896055f087171ebbddea84cd7024dca6b8f3a6e
+ size 12128
color_palette/cube_num_one_hot_LR_sequential/original_palettes.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9342282a3717f08e759cce63bf37852392a5996b879a878aab12de1ad089081b
+ size 12128
color_palette/cube_num_one_hot_LR_sequential/original_palettes_purple.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9342282a3717f08e759cce63bf37852392a5996b879a878aab12de1ad089081b
+ size 12128
color_palette/cube_num_one_hot_LR_sequential/test_gt.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6e5d5394bf0dcce227956a9a66b46263a0102492f132f6c731ee020e703ec3c8
+ size 48128
color_palette/cube_num_one_hot_LR_sequential/test_preds.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:afaa3745f1579f54e69178c3d46aba4a9c23625107f65d0fb175e91e0bc6c243
+ size 48128
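The ten .npy entries above are Git LFS pointer files, not the arrays themselves: three "key value" lines giving the spec version, the sha256 digest of the real payload, and its size in bytes. A quick, illustrative way to inspect one:

def read_lfs_pointer(path):
    # Each non-empty line is "key value"; split on the first space only.
    with open(path) as f:
        return dict(line.strip().split(" ", 1) for line in f if line.strip())

meta = read_lfs_pointer("color_palette/cube_num_one_hot_LR/test_gt.npy")
print(meta["oid"], meta["size"])  # sha256 digest and payload size in bytes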
color_palette/dataset.py ADDED
@@ -0,0 +1,215 @@
+ import torch
+ import torch_geometric
+ from torch_geometric.data import Dataset, Data
+ from sklearn.model_selection import train_test_split
+ import os
+ import random
+ import numpy as np
+ import yaml
+ from color_palette.utils import *
+ from skimage.color import rgb2lab, lab2rgb
+ from color_palette.config import *
+ import math
+ 
+ # with open("config/conf.yaml", 'r') as f:
+ #     config = yaml.load(f, Loader=yaml.FullLoader)
+ 
+ config = DataConfig()
+ 
+ data_type = config.data_type
+ mask_type = -1
+ dataset_root = config.dataset
+ 
+ print(f"Torch version: {torch.__version__}")
+ print(f"Cuda available: {torch.cuda.is_available()}")
+ print(f"Torch geometric version: {torch_geometric.__version__}")
+ 
+ class GraphDestijlDataset(Dataset):
+     def __init__(self, root, test=False, transform=None, pre_transform=None, cube_mapping=False, square_label=False):
+         """
+         root = Where the dataset should be stored. This folder is split
+         into raw_dir (downloaded dataset) and processed_dir (processed data).
+         """
+         self.test = test
+         self.sample_filenames = os.listdir(root + data_type + '/')
+         self.processed_data_dir = root + data_type + '/'
+         self.square_label = square_label
+         self.map_to_cube = cube_mapping
+         self.cube_size = 4
+ 
+         # If you want to use less data than the whole dataset, you can specify the range here.
+         # Then it loads only samples up to that index.
+         self.sample_filenames = ["data_{:04d}.pt".format(idx) for idx in range(0, 1000)]
+ 
+         # Train/test filenames.
+         self.train_filenames, self.test_filenames = train_test_split(self.sample_filenames,
+                                                                      test_size=0.2,
+                                                                      random_state=42)
+         super().__init__(root, transform, pre_transform)
+ 
+     @property
+     def raw_file_names(self):
+         return "empty"
+ 
+     @property
+     def processed_file_names(self):
+         """ If these files are found in raw_dir, processing is skipped"""
+         if self.test:
+             return self.test_filenames
+         else:
+             return self.train_filenames
+ 
+     def download(self):
+         pass
+ 
+     def process(self):
+         pass
+ 
+     def return_all_same_colors(self, input_data):
+         indices_list = [[], [], [], [], []]
+         unique_colors = torch.from_numpy(np.unique(input_data.x[:, 4:], axis=0))
+         all_colors = input_data.x[:, 4:]
+         for idx, color in enumerate(unique_colors):
+             for node_num, element in enumerate(all_colors):
+                 if torch.equal(color, element):
+                     indices_list[idx].append(node_num)
+ 
+         return indices_list, unique_colors
+ 
+     def test_train_mask(self, data):
+         '''
+         Input: graph data
+ 
+         Mask the color of one node. The ground truth color is the last 3 dimensions of the feature vector.
+         Data is saved as RGB.
+         If you want, you can convert the unnormalized RGB ground truth color to Lab.
+         (Conversion is done using colormath.)
+         Put a mask on a random node's color information by setting that color to [0, 0, 0].
+ 
+         Return: new_data with masked RGB colors, color_to_hide in lab, node_to_mask scalar
+         '''
+         # Take the number of nodes
+         n_nodes = len(data.x)
+         if mask_type == -1:
+             # Choose the node to mask randomly
+             node_to_mask = random.randint(0, n_nodes-1)
+             #node_to_mask = n_nodes-1
+         else:
+             # Mask the red each time
+             node_to_mask = n_nodes-2
+         # If you chose a folder with processed_rgb in its name, all the colors are stored in (0, 255) RGB.
+         feature_vector = data.x
+         # This is our target.
+ 
+         also_normal_values = []
+         for color in feature_vector[:, -3:].clone():
+             also_normal_values.append(color)
+ 
+         color_to_hide = feature_vector[node_to_mask, -3:].clone()
+ 
+         if self.map_to_cube:
+             color_to_hide = self.cube_mapping(color_to_hide*255)
+ 
+         # Conversion to cielab. I do not use it anymore.
+         #color_to_hide = torch.tensor(RGB2CIELab(color_to_hide.numpy().astype(np.int32)))
+         #color_to_hide = torch.tensor(rgb2lab(color_to_hide.numpy().astype(np.int32)))
+         # Set the masked node's color in the feature vector to zero.
+         feature_vector[node_to_mask, -3:] = torch.Tensor([0, 0, 0]) #torch.Tensor([0.9, 0.1, 0.1])
+ 
+         # Separate square and circle labels.
+         if self.square_label:
+             add_label = torch.zeros([5, 1])
+             add_label[4, 0] = 1
+             labels = feature_vector[:, :3]
+             labels[4, 2] = 0
+             labels = torch.cat((labels, add_label), dim=1)
+             feature_vector = torch.cat((labels, feature_vector[:, -3:]), dim=1)
+ 
+         # Assign the new feature vector to the graph.
+         new_data = data.clone()
+         new_data.x = feature_vector
+ 
+         # The code below is used if we want to apply a threshold while adding edges.
+         # It just removes the edges with a distance higher than the threshold.
+ 
+         # new_edge_weight = []
+         # new_edge_index = []
+         # for k, edge in enumerate(new_data.edge_weight):
+         #     if edge.item() < threshold:
+         #         new_edge_weight.append(edge.item())
+         #         new_edge_index.append([new_data.edge_index[0][k], new_data.edge_index[1][k]])
+ 
+         # new_data.edge_index = torch.Tensor(new_edge_index).T
+         # new_data.edge_weight = torch.Tensor(new_edge_weight)
+         #print("New calculations")
+         #print(new_data.edge_index, new_data.edge_weight)
+ 
+         if self.map_to_cube:
+             cube_num_list = []
+             cube_colors = torch.zeros((feature_vector.shape[0], int(math.pow(self.cube_size, 3))))
+             for j, color in enumerate(feature_vector[:, -3:]):
+                 if j != node_to_mask:
+                     cube_num = self.cube_mapping(color*255)
+                     cube_colors[j][cube_num] = 1
+                     cube_num_list.append(cube_num)
+ 
+             new_data.x = torch.cat((feature_vector[:, :-3], cube_colors), dim=1)
+             new_data.y = cube_num_list
+ 
+         return new_data, color_to_hide, node_to_mask, also_normal_values
+ 
+     def len(self):
+         if self.test:
+             return len(self.test_filenames)
+         else:
+             return len(self.train_filenames)
+ 
+     def get(self, idx):
+         if self.test:
+             data = torch.load(self.processed_data_dir+self.test_filenames[idx])
+         else:
+             data = torch.load(self.processed_data_dir+self.train_filenames[idx])
+         new_data, target_color, node_to_mask, also_normal_values = self.test_train_mask(data)
+         return new_data, target_color, node_to_mask, also_normal_values
+ 
+     def cube_mapping(self, color):
+         intervals = np.arange(0, 256, 256//self.cube_size)
+         cube_coordinates = []
+         for channel in color:
+             i = 0
+             for j, value in enumerate(intervals):
+                 if value < channel:
+                     i = j
+             cube_coordinates.append(i)
+ 
+         cube_num = cube_coordinates[0]*1 + cube_coordinates[1]*self.cube_size + cube_coordinates[2]*self.cube_size*self.cube_size
+         return cube_num
+ 
+     def cube2rgb(self, cube_num):
+         """
+         Return the start of the ranges
+         """
+         cube_num = int(cube_num)
+         intervals = np.arange(0, 256, 256//self.cube_size)
+         coor2 = cube_num // (self.cube_size*self.cube_size)
+         coor1 = (cube_num - coor2*self.cube_size*self.cube_size) // self.cube_size
+         coor0 = cube_num - coor2*self.cube_size*self.cube_size - coor1*self.cube_size
+         return [intervals[coor0], intervals[coor1], intervals[coor2]]
+ 
+ if __name__ == '__main__':
+     dataset_obj = GraphDestijlDataset(root=dataset_root)
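Worked example of the 4x4x4 cube quantisation above: each RGB channel is binned into one of four 64-wide intervals, giving 64 cubes in total, and cube2rgb returns the lower corner of a cube. A minimal sketch, assuming the processed dataset directory exists so the constructor can run:

ds = GraphDestijlDataset(root=dataset_root, cube_mapping=True)

cube = ds.cube_mapping([200, 30, 100])  # R -> bin 3, G -> bin 0, B -> bin 1
print(cube)               # 3 + 0*4 + 1*16 = 19
print(ds.cube2rgb(cube))  # [192, 0, 64], the lower corner of cube 19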
color_palette/dataset_processing.py ADDED
@@ -0,0 +1,505 @@
+ import torch
+ import torch.nn as nn
+ from torch_geometric.data import Dataset
+ 
+ import os
+ from utils import *
+ 
+ from paddleocr import PaddleOCR, draw_ocr
+ 
+ from PIL import Image, ImageFont
+ import numpy as np
+ import cv2
+ from sklearn.cluster import KMeans
+ from sklearn.metrics import silhouette_score
+ from matplotlib.colors import hsv_to_rgb, rgb_to_hsv
+ 
+ from torchvision.models import resnet50, ResNet50_Weights
+ from model.CNN import Autoencoder
+ 
+ from model.graph import DesignGraph
+ from config import *
+ 
+ class ProcessedDeStijl(Dataset):
+     def __init__(self, data_path):
+         self.path = data_path
+         self.path_dict = {
+             'preview': data_path + '/00_preview/',
+             'background': data_path + '/01_background/',
+             'image': data_path + '/02_image/',
+             'decoration': data_path + '/03_decoration/',
+             'text': data_path + '/04_text/',
+             'theme': data_path + '/05_theme/'
+         }
+ 
+         self.data_path = data_path
+         self.dataset_size = len(next(os.walk(self.path_dict['preview']))[2])
+         self.ocr = PaddleOCR(use_angle_cls=True, lang='en', use_gpu=False)
+ 
+         self.criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 200, .1)
+         self.flags = cv2.KMEANS_RANDOM_CENTERS
+ 
+         #self.layers = ['background', 'text', "decoration"]  # Take this from the config file later
+         self.layers = ['background', 'text', "decoration"]
+ 
+         self.pretrained_model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
+         # self.pretrained_model = Autoencoder()
+         # self.pretrained_model.load_state_dict(torch.load("../CNN_models/CNNAutoencoder/weights/best.pth")["state_dict"])
+ 
+     def len(self):
+         return self.dataset_size
+ 
+     def get(self, idx):
+         '''
+         Return a graph object based on the information
+         '''
+         pass
+ 
+     ######### RUNTIME EXTRACTION ###########
+ 
+     ######### PROCESSING AND ANNOTATING THE DATASET ###########
+ 
+     def extract_text_bbox(self, img_path, preview_image_path):
+         '''
+         Input: path to the text image
+         Extract text using PaddleOCR.
+         Crop text from the bounding box.
+         Extract colors using KMeans inside the bbox.
+         Return the dominant color and the position.
+ 
+         DONE: Try to combine very close lines as paragraph bboxes.
+         If the distance between two bboxes is smaller than the bbox height and the color is the same,
+         we can group them as paragraphs.
+ 
+         Return: text color palettes, dominant colors for each text and position list (as bboxes).
+         '''
+         # Parameters for KMeans.
+         n_colors = 3
+ 
+         result = self.ocr.ocr(img_path, cls=True)[0]
+ 
+         image = Image.open(img_path).convert('RGB')
+         boxes = [line[0] for line in result]
+         texts = [line[1][0] for line in result]
+         image = cv2.imread(img_path)
+         preview_image = cv2.imread(preview_image_path)
+ 
+         palettes = []
+         dominants = []
+         new_bboxes = []
+ 
+         # Run KMeans for each text object
+         for bbox in boxes:
+             # Crop the text area
+             x, y = int(bbox[0][0]), int(bbox[0][1])
+             z, t = int(bbox[2][0]), int(bbox[2][1])
+             cropped_image = image[y:t, x:z]
+ 
+             # Do template matching to find the location in the actual image, because not every image has the same size.
+             method = cv2.TM_SQDIFF_NORMED
+             result = cv2.matchTemplate(cropped_image, preview_image, method)
+             mn, _, mnLoc, _ = cv2.minMaxLoc(result)
+             MPx, MPy = mnLoc
+             trows, tcols = cropped_image.shape[:2]
+             # --> left top, right top, right bottom, left bottom
+             bbox = [[MPx, MPy], [MPx+tcols, MPy], [MPx+tcols, MPy+trows], [MPx, MPy+trows]]
+             new_bboxes.append(bbox)
+ 
+         return new_bboxes, boxes, texts
+ 
+     def compose_paragraphs(self, text_bboxes, text_palettes):
+         '''
+         Compose text data into paragraphs.
+         Return: Grouped indices of detected text elements.
+         '''
+         num_text_boxes = len(text_bboxes)
+         if num_text_boxes == 0:
+             return False
+         composed_text_idxs = [[0]]
+         for i in range(num_text_boxes-1):
+             palette1 = text_palettes[i]
+             palette2 = text_palettes[i+1]
+             if np.array_equal(palette1, palette2):
+                 bbox1 = text_bboxes[i]
+                 bbox2 = text_bboxes[i+1]
+                 height1 = bbox1[0][1] - bbox1[3][1]
+                 height2 = bbox2[0][1] - bbox2[3][1]
+                 if abs(bbox1[0][1]-bbox2[0][1]) <= abs(height1)+30:
+                     if i != 0 and i not in composed_text_idxs[-1]:
+                         composed_text_idxs.append([i])
+                     composed_text_idxs[-1].append(i+1)
+                 else:
+                     if i != 0 and i not in composed_text_idxs[-1]:
+                         composed_text_idxs.append([i])
+                     if i == num_text_boxes-2:
+                         composed_text_idxs.append([i+1])
+             else:
+                 if i != 0 and i not in composed_text_idxs[-1]:
+                     composed_text_idxs.append([i])
+                 if i == (num_text_boxes-2):
+                     composed_text_idxs.append([i+1])
+ 
+         return composed_text_idxs
+ 
+     def merge_bounding_boxes(self, composed_text_idxs, bboxes):
+         '''
+         openCV --> x: left-to-right, y: top-to-bottom
+         bbox coordinates --> [[256.0, 1105.0], [1027.0, 1105.0], [1027.0, 1142.0], [256.0, 1142.0]]
+         --> left top, right top, right bottom, left bottom
+ 
+         TODO: Also return color palettes for each merged box.
+         '''
+         biggest_borders = []
+         if len(bboxes) == 0:
+             return biggest_borders
+         for idxs in composed_text_idxs:
+             smallest_x = smallest_y = 10000
+             biggest_y = biggest_x = 0
+             if len(idxs) > 1:
+                 for idx in idxs:
+                     bbox = bboxes[idx]
+                     bbox_smallest_x, bbox_smallest_y = np.min(bbox, axis=0)
+                     bbox_biggest_x, bbox_biggest_y = np.max(bbox, axis=0)
+ 
+                     if smallest_x > bbox_smallest_x:
+                         smallest_x = bbox_smallest_x
+                     if smallest_y > bbox_smallest_y:
+                         smallest_y = bbox_smallest_y
+                     if biggest_x < bbox_biggest_x:
+                         biggest_x = bbox_biggest_x
+                     if biggest_y < bbox_biggest_y:
+                         biggest_y = bbox_biggest_y
+ 
+                 biggest_border = [[smallest_x, smallest_y], [biggest_x, smallest_y], [biggest_x, biggest_y], [smallest_x, biggest_y]]
+                 biggest_borders.append(biggest_border)
+             else:
+                 biggest_borders.append(bboxes[idxs[0]])
+         return biggest_borders
+ 
+     def mini_kmeans(self, biggest_border, n_colors, image):
+         # for text
+         x, y = int(biggest_border[0][0]), int(biggest_border[0][1])
+         z, t = int(biggest_border[2][0]), int(biggest_border[2][1])
+         cropped_image = image[y:t, x:z]
+         pixels = np.float32(cropped_image.reshape(-1, 3))
+         _, labels, palette = cv2.kmeans(pixels, n_colors, None, self.criteria, 10, self.flags)
+         palette = np.asarray(palette, dtype=np.int64)
+ 
+         _, counts = np.unique(labels, return_counts=True)
+         # Text pixels are the minority against the background, so take the least frequent cluster.
+         color = palette[np.argmin(counts)]
+         return color
+ 
+     def extract_text_directly(self, img_path, white_bg_texts):
+         n_colors = 2
+ 
+         result = self.ocr.ocr(img_path, cls=True)[0]
+ 
+         image = Image.open(img_path).convert('RGB')
+         boxes = [line[0] for line in result]
+         texts = [line[1][0].replace(" ", "").lower() for line in result]
+         white_bg_texts = [elem.replace(" ", "").lower() for elem in white_bg_texts]
+         image = cv2.imread(img_path)
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+         same_idxs = []
+         new_boxes = []
+ 
+         composed_text_palettes = []
+         for j, elem in enumerate(white_bg_texts):
+             for i, text in enumerate(texts):
+                 if similar(elem, text) > 0.85:
+                     new_boxes.append(boxes[i])
+                     biggest_border = boxes[i]
+                     composed_text_palettes.append(self.mini_kmeans(biggest_border, n_colors, image))
+                 elif i+1 != len(texts):
+                     if similar(elem, text + texts[i+1]) > 0.85:
+                         # merge boxes
+                         bboxes = [boxes[i], boxes[i+1]]
+                         smallest_x = smallest_y = 10000
+                         biggest_y = biggest_x = 0
+                         for idx in [0, 1]:
+                             bbox = bboxes[idx]
+                             bbox_smallest_x, bbox_smallest_y = np.min(bbox, axis=0)
+                             bbox_biggest_x, bbox_biggest_y = np.max(bbox, axis=0)
+ 
+                             if smallest_x > bbox_smallest_x:
+                                 smallest_x = bbox_smallest_x
+                             if smallest_y > bbox_smallest_y:
+                                 smallest_y = bbox_smallest_y
+                             if biggest_x < bbox_biggest_x:
+                                 biggest_x = bbox_biggest_x
+                             if biggest_y < bbox_biggest_y:
+                                 biggest_y = bbox_biggest_y
+ 
+                         biggest_border = [[smallest_x, smallest_y], [biggest_x, smallest_y], [biggest_x, biggest_y], [smallest_x, biggest_y]]
+                         new_boxes.append(biggest_border)
+                         composed_text_palettes.append(self.mini_kmeans(biggest_border, n_colors, image))
+ 
+         return new_boxes, composed_text_palettes
+ 
+     def extract_decor_elements(self, decoration_path, preview_path):
+         # Determine the number of dominant colors
+         num_colors = 6
+ 
+         # Load the image
+         image = cv2.imread(decoration_path)
+ 
+         # Convert the image to the RGB color space
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+         image2 = image.copy()
+ 
+         # Reshape the image to a 2D array of pixels
+         pixels = image.reshape(-1, 3)
+ 
+         # Apply K-means clustering with the determined number of colors
+         kmeans = KMeans(n_clusters=num_colors)
+         kmeans.fit(pixels)
+ 
+         # Get the RGB values of the dominant colors
+         colors = kmeans.cluster_centers_.astype(int)
+ 
+         # Convert the colors to the HSV color space, skipping near-white clusters
+         hsv_colors = []
+         for i, color in enumerate(colors):
+             x, y, z = color
+             if not (252 < x < 256 and 252 < y < 256 and 252 < z < 256):
+                 x, y, z = rgb_to_hsv([x/255, y/255, z/255])
+                 hsv_colors.append([x*180, y*255, z*255])
+ 
+         # Convert the image to the HSV color space
+         hsv_image = cv2.cvtColor(image2, cv2.COLOR_RGB2HSV)
+ 
+         # Create masks for each dominant color
+         masks = []
+         hsv_colors = np.asarray(hsv_colors, dtype=np.int32)
+ 
+         colors = []
+         for i in range(len(hsv_colors)):
+             h, s, v = hsv_colors[i, :]
+             lower_color = hsv_colors[i, :] - np.array([10, 50, 50])
+             upper_color = hsv_colors[i, :] + np.array([10, 255, 255])
+             mask = cv2.inRange(hsv_image, lower_color, upper_color)
+             colors.append([h, s, v])
+             masks.append(mask)
+ 
+         # Find contours in each mask
+         contours = []
+         for mask in masks:
+             contours_color, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
+             contours.append(contours_color)
+ 
+         # Draw bounding boxes around the shapes
+         image_with_boxes = image.copy()
+         bboxes = []
+         for i, contour_color in enumerate(contours):
+             for contour in contour_color:
+                 x, y, w, h = cv2.boundingRect(contour)
+                 # left top, right top, right bottom, left bottom
+                 bboxes.append([[x, y], [x+w, y], [x+w, y+h], [x, y+h]])
+ 
+         new_bboxes = delete_too_small_bboxes(np.asarray(bboxes))
+         return colors, new_bboxes
+ 
+     def map_decoration_coordinates(self, design_text_coordinate, text_coordinate, decoration_coordinates, prev_size, text_size):
+         # --> [[256.0, 1105.0], [1027.0, 1105.0], [1027.0, 1142.0], [256.0, 1142.0]]
+         # --> left top, right top, right bottom, left bottom
+ 
+         prev_x, prev_y = prev_size
+         text_x, text_y = text_size
+ 
+         design_x, design_y = design_text_coordinate[0]
+         text_x, text_y = text_coordinate[0]
+ 
+         diff_x = text_x - design_x
+         diff_y = text_y - design_y
+ 
+         new_coordinates = []
+         for coordinate in decoration_coordinates:
+             new_coor = []
+             for elem in coordinate:
+                 new_coor.append([elem[0]-diff_x, elem[1]-diff_y])
+             new_coordinates.append(new_coor)
+ 
+         return new_coordinates
+ 
+     def extract_image(self, preview_path, image_path):
+         '''
+         Use template matching to put a bounding box around the main image and use it as the position.
+         Extract colors using KMeans.
+         Return: image color palettes and position list (as bboxes).
+         '''
+         preview_image = cv2.imread(preview_path)
+         preview_image = cv2.cvtColor(preview_image, cv2.COLOR_BGR2RGB)
+         cropped_image_path = trim_image(image_path, "02_image")
+         image = cv2.imread(cropped_image_path)
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+ 
+         if image.shape[0] > preview_image.shape[0]:
+             diff_x = image.shape[0] - preview_image.shape[0]
+             image = image[(diff_x//2+1):image.shape[0]-(diff_x//2+1), :]
+ 
+         if image.shape[1] > preview_image.shape[1]:
+             diff_y = image.shape[1] - preview_image.shape[1]
+             image = image[:, (diff_y//2+1):image.shape[1]-(diff_y//2+1)]
+ 
+         method = cv2.TM_SQDIFF_NORMED
+         result = cv2.matchTemplate(image, preview_image, method)
+         mn, _, mnLoc, _ = cv2.minMaxLoc(result)
+         MPx, MPy = mnLoc
+         trows, tcols = image.shape[:2]
+         bbox = [[MPx, MPy], [MPx+tcols, MPy+trows]]
+         cropped_image = preview_image[MPy:MPy+trows, MPx:MPx+tcols]
+ 
+         pixels = np.float32(image.reshape(-1, 3))
+ 
+         n_colors = 6
+ 
+         _, labels, palette = cv2.kmeans(pixels, n_colors, None, self.criteria, 10, self.flags)
+         palette = np.asarray(palette, dtype=np.int64)
+ 
+         return [bbox], palette
+ 
+     def annotate_dataset(self):
+         for idx in range(388, self.dataset_size):
+             print("CURRENTLY AT: ", idx)
+             path_idx = "{:04d}".format(idx)
+             preview = self.path_dict['preview'] + path_idx + '.png'
+             decoration = self.path_dict['decoration'] + path_idx + '.png'
+             image = self.path_dict['image'] + path_idx + '.png'
+             text = self.path_dict['text'] + path_idx + '.png'
+ 
+             text_bboxes, white_bg_text_boxes, texts = self.extract_text_bbox(text, preview)
+             text_bboxes_from_design, composed_text_palettes = self.extract_text_directly(preview, texts)
+             composed_text_idxs = self.compose_paragraphs(text_bboxes_from_design, composed_text_palettes)
+             merged_bboxes = []
+             if composed_text_idxs is not False:
+                 merged_bboxes = self.merge_bounding_boxes(composed_text_idxs, text_bboxes_from_design)
+             image_bboxes, image_palette = self.extract_image(preview, image)
+ 
+             #decoration_hsv_palettes, decoration_bboxes = self.extract_decor_elements(decoration, preview)
+             # image_prev = cv2.imread(preview)
+             # image_text = cv2.imread(text)
+             #mapped_decoration_bboxes = self.map_decoration_coordinates(text_bboxes_from_design[0], white_bg_text_boxes[0], decoration_bboxes, (image_prev.shape[0], image_prev.shape[1]), (image_text.shape[0], image_text.shape[1]))
+ 
+             #create_xml("../destijl_dataset/xmls/03_decoration", path_idx+".xml", mapped_decoration_bboxes)
+             if len(merged_bboxes) == 0:
+                 create_xml(self.data_path+"/xmls/04_text", path_idx+".xml", [[[0, 0], [0, 0], [0, 0], [0, 0]]])
+             else:
+                 create_xml(self.data_path+"/xmls/04_text", path_idx+".xml", merged_bboxes)
+             create_xml(self.data_path+"/xmls/02_image", path_idx+".xml", image_bboxes)
+ 
+     def process_dataset(self, idx):
+         '''
+         Process each node. Construct graph features and save the features as pt files.
+         This code should be used after we have an annotated dataset.
+         '''
+         path_idx = "{:04d}".format(idx)
+ 
+         img_path_dict = {
+             'preview': self.data_path + '/00_preview/' + path_idx + '.png',
+             'background': self.data_path + '/01_background/' + path_idx + '.png',
+             'image': self.data_path + '/02_image/' + path_idx + '.png',
+             'text': self.data_path + '/04_text/' + path_idx + '.png',
+             'decoration': self.data_path + '/03_decoration/' + path_idx + '.png',
+         }
+ 
+         annotation_path_dict = {
+             'preview': self.data_path + '/xmls' + '/00_preview/' + path_idx + '.xml',
+             'image': self.data_path + '/xmls' + '/02_image/' + path_idx + '.xml',
+             'text': self.data_path + '/xmls' + '/04_text/' + path_idx + '.xml',
+             'decoration': self.data_path + '/xmls' + '/03_decoration/' + path_idx + '.xml',
+         }
+ 
+         all_bboxes = {
+             'image': [],
+             'background': [],
+             'decoration': [],
+             'text': []
+         }
+         all_images = {
+             'image': [],
+             'background': [],
+             'decoration': [],
+             'text': []
+         }
+ 
+         # For each layer:
+         #   * save all bounding boxes to all_bboxes
+         #   * save the paths of the images we extract the objects from --> to all_images
+         #   * We generally extract all objects from the preview image, so that path is the preview image path
+         #   * We save this information to use in DesignGraph. It extracts colors from bounding boxes
+         #     using the image we want to extract them from.
+ 
+         for i, layer in enumerate(self.layers):
+             # Check which layer it is and save information accordingly to the dictionaries
+             if layer == "background":
+                 # load the image as a CV image
+                 self.preview_img = cv2.imread(img_path_dict[layer])
+                 # save the layer image path
+                 img = img_path_dict[layer]
+                 all_images[layer] = img
+                 # since it is the background, just add a trivial bbox. This is not used
+                 all_bboxes[layer] = [[[0, 0], [self.preview_img.shape[0], 0], [self.preview_img.shape[0], self.preview_img.shape[1]], [0, self.preview_img.shape[1]]]]
+             else:
+                 if layer == 'text':
+                     # get the preview path since we extract text directly from the preview image
+                     img_path = img_path_dict['preview']
+                     # get annotations from xml
+                     filename, bboxes = VOC2bbox(annotation_path_dict[layer])
+                     # assign bounding boxes and image path to extract colors later
+                     all_bboxes[layer] = bboxes
+                     all_images[layer] = img_path
+                 elif layer == 'image' or layer == "decoration":
+                     # the same logic is applied as for text
+                     img_path = img_path_dict['preview']
+                     self.img_img = cv2.imread(img_path)
+                     filename, bboxes = VOC2bbox(annotation_path_dict[layer])
+                     """
+                     The commented block below is for checking whether the annotation boxes work for
+                     each bounding box. It saves the annotated image.
+                     """
+                     # for k, box in enumerate(bboxes):
+                     #     im = cv2.imread("../destijl_dataset/00_preview/"+path_idx+".png")
+                     #     # [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]]
+                     #     xmin = box[0][0]
+                     #     xmax = box[1][0]
+                     #     ymin = box[0][1]
+                     #     ymax = box[2][1]
+                     #     print(xmin, xmax, ymin, ymax)
+                     #     cv2.rectangle(im, (xmin, ymin), (xmax, ymax), (255, 0, 0), 2)
+                     #     cv2.imwrite("check_bboxes"+str(k)+".jpg", im)
+                     all_bboxes[layer] = bboxes
+                     all_images[layer] = img_path
+ 
+         # DesignGraph constructs the graph object and saves it as a pt file.
+         design_graph = DesignGraph(self.pretrained_model, all_images, all_bboxes, self.layers, img_path_dict['preview'], idx)
+         return design_graph.get_all_colors_in_design()
+ 
+     def trial(self):
+         # Save the samples in the range (0, n)
+         for i in range(0, 3100):
+             print("Sample: ", i)
+             all_colors = self.process_dataset(i)
+             for nested_list in all_colors:
+                 color = nested_list[0]
+                 # Fix if it has an unnecessary extra dimension
+                 if len(color.shape) == 2:
+                     color = color[0]
+                 color = color.tolist()
+ 
+ if __name__ == "__main__":
+     config = DataConfig()
+     dataset_root = config.dataset
+     dataset = ProcessedDeStijl(data_path=dataset_root)
+     dataset.trial()
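The same cv2.kmeans dominant-colour pattern recurs throughout this file (mini_kmeans, extract_image). A minimal standalone sketch with an illustrative input path; note that mini_kmeans takes np.argmin(counts) instead of np.argmax, because text pixels are the minority cluster inside a text crop:

import cv2
import numpy as np

image = cv2.cvtColor(cv2.imread("design.png"), cv2.COLOR_BGR2RGB)  # illustrative path
criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 200, .1)

pixels = np.float32(image.reshape(-1, 3))
_, labels, palette = cv2.kmeans(pixels, 3, None, criteria, 10,
                                cv2.KMEANS_RANDOM_CENTERS)
_, counts = np.unique(labels, return_counts=True)
dominant = palette[np.argmax(counts)]  # most frequent cluster centre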
color_palette/deneme.png ADDED
color_palette/deneme.py ADDED
@@ -0,0 +1,107 @@
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from dataset import *
+ from torch_geometric.loader import DataLoader
+ from config import DataConfig
+ 
+ config = DataConfig()
+ model_name = config.model_name
+ 
+ test_dataset = GraphDestijlDataset(root=dataset_root, test=True, cube_mapping=True)
+ test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
+ num_of_plots = len(test_loader)
+ 
+ def my_palplot(pal, size=1, ax=None):
+     """Plot the values in a color palette as a horizontal array.
+     Parameters
+     ----------
+     pal : sequence of matplotlib colors
+         colors, i.e. as returned by seaborn.color_palette()
+     size :
+         scaling factor for size of plot
+     ax :
+         an existing axes to use
+     """
+     import numpy as np
+     import matplotlib as mpl
+     import matplotlib.pyplot as plt
+     import matplotlib.ticker as ticker
+ 
+     n = len(pal)
+     if pal[0][0] > 1:
+         pal = np.array(pal) / 255
+     ax.imshow(np.arange(n).reshape(1, n),
+               cmap=mpl.colors.ListedColormap(list(pal)),
+               interpolation="nearest", aspect="auto")
+     ax.set_xticks(np.arange(n) - .5)
+     ax.set_yticks([-.5, .5])
+     # Ensure nice border between colors
+     ax.set_xticklabels(["" for _ in range(n)])
+     # The proper way to set no ticks
+     ax.yaxis.set_major_locator(ticker.NullLocator())
+ 
+ targets = np.load("targets.npy", allow_pickle=True)
+ preds = np.load("preds.npy", allow_pickle=True)
+ top_k_preds = np.squeeze(np.load("top_k_preds.npy", allow_pickle=True))
+ node_to_masks = np.load("node_to_mask.npy", allow_pickle=True)
+ 
+ rows = num_of_plots//3 + 1
+ cols = 3
+ fig, ax_array = plt.subplots(rows, cols, figsize=(60, 60), dpi=80, squeeze=False)
+ 
+ column_titles = [" Prediction | Target " for i in range(cols)]
+ for ax, col in zip(ax_array[0], column_titles):
+     ax.set_title(col, fontdict={'fontsize': 30, 'fontweight': 'medium'})
+ 
+ palettes = []
+ for i in range(targets.shape[0]):
+     target = targets[i]
+     pred = preds[i]
+     top_k_pred = top_k_preds[i]
+ 
+     print(target.shape, pred.shape, top_k_pred.shape)
+ 
+     palette = np.concatenate((top_k_pred[1:], np.atleast_1d(pred), np.atleast_1d(target)))
+     palettes.append(palette)
+     ax = plt.subplot(rows, cols, i+1)
+ 
+     rgb_palette = []
+     for color in palette:
+         rgb_value = test_dataset.cube2rgb(color)
+         rgb_palette.append(rgb_value)
+ 
+     my_palplot(rgb_palette, ax=ax)
+ 
+     if i == num_of_plots - 1:
+         print("saving")
+         plt.savefig("../models/"+model_name+"/top_k.jpg")
+ 
+ unique_values = np.unique(node_to_masks)
+ palettes = np.array(palettes)
+ each_node_pred = []
+ each_node_target = []
+ for value in unique_values:
+     indices = np.where(node_to_masks == value)[0]
+     each_node_pred.append(palettes[indices])
+     each_node_target.append(targets[indices])
+ 
+ targets_repeated = np.repeat(np.expand_dims(targets, axis=1), top_k_preds.shape[1], axis=1)
+ total = np.sum(palettes[:, :-1] == targets_repeated)
+ print("Total accuracy: ", total/600)
+ 
+ for j, value in enumerate(unique_values):
+     targets_repeated = np.repeat(np.expand_dims(each_node_target[j], axis=1), top_k_preds.shape[1], axis=1)
+     total = np.sum(each_node_pred[j][:, :-1] == targets_repeated)
+     print("Accuracy of node " + str(value) + ": ", total)
+ 
+ plt.close()
+ plt.hist(preds, bins=64)
+ plt.savefig("hist.jpg")
+ 
+ plt.close()
+ plt.hist(targets, bins=64)
+ plt.savefig("hist_gt.jpg")
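The top-k counting above compares every column of the prediction matrix against the repeated target vector. Isolated with toy shapes, the rule is: a sample is a hit when its target cube index appears anywhere in its top-k list (a minimal sketch):

import numpy as np

top_k = np.array([[3, 19, 7],
                  [5, 2, 8]])   # (n_samples, k) predicted cube indices
targets = np.array([19, 8])     # (n_samples,)

hits = (top_k == targets[:, None]).any(axis=1)
print(hits.mean())              # 1.0: both targets appear in their top-k lists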
color_palette/denemeler.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
color_palette/dist.png ADDED
color_palette/evaluate.py ADDED
@@ -0,0 +1,188 @@
+ import os
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ from torch_geometric.loader import DataLoader
+ 
+ from dataset import GraphDestijlDataset
+ import yaml
+ import argparse
+ 
+ import seaborn as sns
+ import matplotlib.pyplot as plt
+ from utils import *
+ from config import *
+ 
+ ######################## Set Parameters ########################
+ 
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+ 
+ # you can specify the config file you want to provide
+ # parser = argparse.ArgumentParser()
+ # parser.add_argument("--config_file", type=str, default="config/conf.yaml", help="Path to the config file.")
+ # args = parser.parse_args()
+ # config_file = args.config_file
+ 
+ # with open(config_file, 'r') as f:
+ #     config = yaml.load(f, Loader=yaml.FullLoader)
+ 
+ config = DataConfig()
+ 
+ data_type = config.data_type
+ model_name = config.model_name
+ device = config.device
+ feature_size = config.feature_size
+ loss_function = config.loss_function
+ dataset_root = config.dataset
+ our_node_to_mask = config.node_to_mask
+ 
+ model_weight_path = "../models/" + model_name + "/weights/best.pth"
+ 
+ ######################## Model ########################
+ 
+ # Prepare the dataset.
+ # Set test=True for testing on the test set. Otherwise it tests on the train set.
+ test_dataset = GraphDestijlDataset(root=dataset_root, test=True)
+ test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
+ num_of_plots = len(test_loader)
+ 
+ # The model is selected according to the name you provide.
+ model = model_switch(model_name, feature_size).to(device)
+ model.load_state_dict(torch.load(model_weight_path)["state_dict"])
+ criterion = nn.MSELoss()
+ 
+ ######################## Helper Functions ########################
+ 
+ # The skip-black flag was used to eliminate examples that include black colors.
+ # Always set to False for default experiments.
+ skip_black_flag = False
+ 
+ def test(data, target_color, node_to_mask):
+     model.eval()
+     out = model(data.x, data.edge_index.long(), data.edge_attr)
+ 
+     if skip_black_flag:
+         for color in data.y:
+             a, b, c = color
+             if 0 <= a <= 30 and 0 <= b <= 30 and 0 <= c <= 30:
+                 return None, None
+             else:
+                 loss = colormath_CIE2000(out[node_to_mask, :][0], target_color[0])
+     else:
+         # Loss is only in MSE now.
+         loss = criterion(out[node_to_mask, :][0], target_color[0]/255)
+     return loss, out
+ 
+ def my_palplot(pal, size=1, ax=None):
+     """Plot the values in a color palette as a horizontal array.
+     Parameters
+     ----------
+     pal : sequence of matplotlib colors
+         colors, i.e. as returned by seaborn.color_palette()
+     size :
+         scaling factor for size of plot
+     ax :
+         an existing axes to use
+     """
+     import numpy as np
+     import matplotlib as mpl
+     import matplotlib.pyplot as plt
+     import matplotlib.ticker as ticker
+ 
+     n = len(pal)
+     if ax is None:
+         f, ax = plt.subplots(1, 1, figsize=(n * size, size))
+     ax.imshow(np.arange(n).reshape(1, n),
+               cmap=mpl.colors.ListedColormap(list(pal)),
+               interpolation="nearest", aspect="auto")
+     ax.set_xticks(np.arange(n) - .5)
+     ax.set_yticks([-.5, .5])
+     # Ensure nice border between colors
+     ax.set_xticklabels(["" for _ in range(n)])
+     # The proper way to set no ticks
+     ax.yaxis.set_major_locator(ticker.NullLocator())
+ 
+ # Config for the plot
+ rows = num_of_plots//3 + 1
+ cols = 3
+ fig, ax_array = plt.subplots(rows, cols, figsize=(60, 60), dpi=80, squeeze=False)
+ 
+ column_titles = [" Prediction | Target " for i in range(cols)]
+ for ax, col in zip(ax_array[0], column_titles):
+     ax.set_title(col, fontdict={'fontsize': 30, 'fontweight': 'medium'})
+ 
+ fig.suptitle(model_name+" Test Palettes", fontsize=100)
+ 
+ # Evaluation loop
+ plot_count = 0
+ val_losses = []
+ palettes = []
+ preds = []
+ targets = []
+ count = 0
+ for i, (input_data, target_color, node_to_mask) in enumerate(test_loader):
+     loss, out = test(input_data.to(device), target_color.to(device), node_to_mask)
+     if loss is not None:
+         val_losses.append(loss.item())
+ 
+         # Get the prediction and the other colors in the palette
+         ax = plt.subplot(rows, cols, plot_count+1)
+ 
+         # Get the prediction for the masked node
+         prediction = out[node_to_mask, :]
+         print("which node: ", node_to_mask, "count: ", count)
+         preds.append(prediction.detach().cpu()[0])
+         targets.append(target_color.detach().cpu()[0])
+         # Concat the unmasked colors with prediction and ground truth
+         other_colors = input_data.y.clone()
+         other_colors = torch.cat([other_colors[0:node_to_mask, :], other_colors[node_to_mask+1:, :]])
+         #print(other_colors[0:node_to_mask, :].shape, other_colors[node_to_mask+1:, :].shape, node_to_mask)
+         # Normalize since they are in the (0, 255) range.
+         other_colors = other_colors.type(torch.float32).detach().cpu().numpy()/255
+ 
+         if loss_function == "CIELab":
+             palette = np.clip(np.concatenate([other_colors, CIELab2RGB(prediction), CIELab2RGB(target_color[0])]), a_min=0, a_max=1)
+         else:
+             # Concat palettes. All of them are between (0, 1)
+             palette = np.clip(np.concatenate([other_colors, prediction.detach().cpu().numpy(), target_color.detach().cpu().numpy()]), a_min=0, a_max=1)
+ 
+         # I commented out the code related to calculating results in CIELab
+         # if "embedding" in model_name.lower():
+         #     other_colors = other_colors.type(torch.float32).detach().cpu().numpy()
+         #     other_colors /= 255
+         #     palette = np.clip(np.concatenate([other_colors, CIELab2RGB(prediction), CIELab2RGB(target_color[0])]), a_min=0, a_max=1)
+         # else:
+         #     current_palette = torch.cat([other_colors, prediction, target_color.to(device)]).type(torch.float32).detach().cpu().numpy()
+         #     palette = CIELab2RGB(current_palette)
+ 
+         # Save all the palettes to use for distribution histograms.
+         palettes.append(prediction.detach().tolist()[0])
+         my_palplot(palette, ax=ax)
+     else:
+         print("none")
+ 
+     plot_count += 1
+     print(plot_count)
+     if i == num_of_plots-1:
+         path = "../models/"+model_name
+         if not os.path.exists(path):
+             os.mkdir(path)
+ 
+         if our_node_to_mask == -1:
+             plt.savefig(path+"/palettes.jpg")
+         else:
+             plt.savefig(path+"/palettes_only_blue.jpg")
+         plt.close()
+ 
+ # This is for checking the prediction distribution.
+ # It is saved as a histogram.
+ #check_distributions(palettes)
+ 
+ # Compare the model's MSE against a uniform-random baseline.
+ criterion = nn.MSELoss()
+ stacked_pred = torch.stack(preds)
+ stacked_target = torch.stack(targets)
+ random_results = np.random.random(size=(stacked_pred.shape[0], stacked_pred.shape[1]))
+ print(criterion(stacked_pred, stacked_target))
+ print(criterion(torch.Tensor(random_results), stacked_target))
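my_palplot can also be used on its own: when ax is None it builds its own figure. A minimal sketch, given the my_palplot defined above (colors must be RGB triples in [0, 1]):

import matplotlib.pyplot as plt

palette = [(0.9, 0.1, 0.1), (0.1, 0.9, 0.1), (0.1, 0.1, 0.9)]
my_palplot(palette, size=1)       # creates its own axes since ax is None
plt.savefig("palette_preview.jpg")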
color_palette/evaluate_CNN.py ADDED
@@ -0,0 +1,180 @@
+ import os
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ from torch.utils.data import DataLoader
+ 
+ import yaml
+ import argparse
+ from torchvision import transforms
+ 
+ import seaborn as sns
+ import matplotlib.pyplot as plt
+ from utils import *
+ from model.CNN import *
+ from cnn_dataset import *
+ 
+ import pandas as pd
+ 
+ ######################## Set Parameters ########################
+ 
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+ 
+ device = "cuda:1"
+ 
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model_name", type=str, default="ColorCNN", help="Give the name of the model.")
+ args = parser.parse_args()
+ model_name = args.model_name
+ 
+ config_file = "../CNN_models/"+model_name+"/conf.yaml"
+ 
+ with open(config_file, 'r') as f:
+     config = yaml.load(f, Loader=yaml.FullLoader)
+ 
+ ## Basic Training Parameters ##
+ model_name = config["model_name"]
+ device = config["device"]
+ 
+ ## Neural Network Parameters ##
+ loss_function = config["loss_function"]
+ out_features = config["out_features"]
+ color_space = config["color_space"]
+ input_color_space = config["input_color_space"]
+ is_classification = config["is_classification"]
+ input_size = config["input_size"]
+ normalize_rgb = config["normalize_rgb"]
+ normalize_cielab = config["normalize_cielab"]
+ 
+ model_weight_path = "../CNN_models/" + model_name + "/weights/best.pth"
+ 
+ if out_features == 1:
+     out_type = "Lightness"
+ else:
+     out_type = "Color"
+ 
+ print("Evaluating for the model: ", model_name, "\n",
+       "Loss function: ", loss_function, "\n",
+       "Output Color Space: ", color_space, "\n",
+       "Color or Lightness?: ", out_type, "\n",
+       "Device: ", device, "\n")
+ 
+ ######################## Model ########################
+ 
+ transform = transforms.Compose([
+     transforms.Resize((input_size, input_size)),
+     #transforms.Normalize((255,), (255,))
+ ])
+ 
+ test_dataset = PreviewDataset(transform=transform,
+                               color_space=color_space,
+                               input_color_space=input_color_space,
+                               normalize_rgb=normalize_rgb,
+                               normalize_cielab=normalize_cielab,
+                               test=True)
+ 
+ test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
+ num_of_plots = len(test_loader)
+ 
+ model = model_switch_CNN(model_name, out_features).to(device)
+ model.load_state_dict(torch.load(model_weight_path)["state_dict"])
+ 
+ if loss_function == "MSE":
+     criterion = nn.MSELoss()
+ elif loss_function == "Cross-Entropy":
+     criterion = nn.CrossEntropyLoss()
+ elif loss_function == "MAE":
+     criterion = nn.L1Loss()
+ 
+ ######################## Helper Functions ########################
+ 
+ def test(data, color):
+     model.eval()
+     out = model(data)
+ 
+     if out_features == 1:
+         if loss_function != "CIELab":
+             loss = criterion(out[0][0], color[0][0])
+         else:
+             loss = colormath_CIE2000(out[0][0], color[0][0])
+     else:
+         if loss_function != "CIELab":
+             loss = criterion(out, color)
+         else:
+             loss = colormath_CIE2000(out, color)
+     return loss, out
+ 
+ def my_palplot(pal, size=1, ax=None):
+     """Plot the values in a color palette as a horizontal array.
+     Parameters
+     ----------
+     pal : sequence of matplotlib colors
+         colors, i.e. as returned by seaborn.color_palette()
+     size :
+         scaling factor for size of plot
+     ax :
+         an existing axes to use
+     """
+     import numpy as np
+     import matplotlib as mpl
+     import matplotlib.pyplot as plt
+     import matplotlib.ticker as ticker
+ 
+     n = len(pal)
+     if ax is None:
+         f, ax = plt.subplots(1, 1, figsize=(n * size, size))
+     ax.imshow(np.arange(n).reshape(1, n),
+               cmap=mpl.colors.ListedColormap(list(pal)),
+               interpolation="nearest", aspect="auto")
+     ax.set_xticks(np.arange(n) - .5)
+     ax.set_yticks([-.5, .5])
+     # Ensure nice border between colors
+     ax.set_xticklabels(["" for _ in range(n)])
+     # The proper way to set no ticks
+     ax.yaxis.set_major_locator(ticker.NullLocator())
+ 
+ rows = num_of_plots//3 + 1
+ cols = 3
+ 
+ fig, ax_array = plt.subplots(rows, cols, figsize=(60, 60), dpi=80, squeeze=False)
+ column_titles = ["Prediction Target" for i in range(cols)]
+ for ax, col in zip(ax_array[0], column_titles):
+     ax.set_title(col, fontdict={'fontsize': 45, 'fontweight': 'medium'})
+ 
+ fig.suptitle(model_name+" Test Palettes", fontsize=100)
+ 
+ plot_count = 0
+ val_losses = []
+ 
+ outputs = []
+ target_colors = []
+ 
+ for i, (input_data, target_color) in enumerate(test_loader):
+     loss, out = test(input_data.to(device), target_color.to(device))
+     val_losses.append(loss.item())
+     # Get the prediction and the other colors in the palette
+     ax = plt.subplot(rows, cols, plot_count+1)
+ 
+     if color_space == "CIELab":
+         out = out.detach().cpu().numpy()
+         out = np.append(out, [[30.0, 30.0]], axis=1)
+         target_color = np.array([[target_color.detach().cpu().numpy()[0][0], 30.0, 30.0]])
+         palette = np.clip(np.concatenate([CIELab2RGB(out), CIELab2RGB(target_color)]), a_min=0, a_max=1)
+     else:
+         palette = np.clip(np.concatenate([out.detach().cpu().numpy(), target_color/255]), a_min=0, a_max=1)
+     outputs.append(out)
+     target_colors.append(target_color)
+ 
+     my_palplot(palette, ax=ax)
+ 
+     plot_count += 1
+ 
+     if i == num_of_plots-1:
+         path = "../CNN_models/"+model_name
+         if not os.path.exists(path):
+             os.mkdir(path)
+         plt.savefig(path+"/palettes.jpg")
+         plt.close()
+ 
+ cielab_dict = {'Output': outputs, 'Targets': target_colors}
+ df = pd.DataFrame(data=cielab_dict)
+ 
+ #df.to_csv("trainset_predictions.csv")
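The CIELab branch above visualises a lightness-only prediction by padding it with fixed a = b = 30 so it can be rendered as an RGB swatch. A sketch of just that step (CIELab2RGB comes from this repo's utils; its exact signature is assumed from the usage above):

import numpy as np

pred_L = 55.0                           # a predicted lightness value
lab = np.array([[pred_L, 30.0, 30.0]])  # pad with the fixed a, b used above
rgb = np.clip(CIELab2RGB(lab), a_min=0, a_max=1)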
color_palette/evaluate_classification.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch_geometric.loader import DataLoader
4
+
5
+ from dataset import GraphDestijlDataset
6
+ import yaml
7
+ import argparse
8
+
9
+ import seaborn as sns
10
+ import matplotlib.pyplot as plt
11
+ from utils import *
12
+ from config import *
13
+
14
+ from model.GNN import *
15
+
16
+ ######################## Set Parameters ########################
17
+
18
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
19
+
20
+ # you can specify the config file you want to provide
21
+ # parser = argparse.ArgumentParser()
22
+ # parser.add_argument("--config_file", type=str, default="config/conf.yaml", help="Path to the config file.")
23
+ # args = parser.parse_args()
24
+ # config_file = args.config_file
25
+
26
+ # with open(config_file, 'r') as f:
27
+ # config = yaml.load(f, Loader=yaml.FullLoader)
28
+
29
+ config = DataConfig()
30
+
31
+ data_type = config.data_type
32
+ model_name = config.model_name
33
+ device = config.device
34
+ feature_size = config.feature_size
35
+ loss_function = config.loss_function
36
+ dataset_root = config.dataset
37
+ our_node_to_mask = config.node_to_mask
38
+
39
+ model_weight_path = "../models/" + model_name + "/weights/best.pth"
40
+
41
+ ######################## Model ########################
42
+
43
+ # Prepare dataset
44
+ # Set test=True for testing on the test set. Otherwise it tests on the train set.
45
+ test_dataset = GraphDestijlDataset(root=dataset_root, test=True, cube_mapping=True)
46
+ test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
47
+ num_of_plots = len(test_loader)
48
+
49
+ # Model is selected according to name you provide
50
+ model = ColorAttentionClassification(feature_size).to(device)
51
+ model.load_state_dict(torch.load(model_weight_path)["state_dict"])
52
+ criterion = nn.CrossEntropyLoss()
53
+ ######################## Helper Functions ########################
54
+
55
+ # Skip black flag was used to eliminate examples that includes black colors.
56
+ # Always set to False for default experiments
57
+
58
+ skip_black_flag=False
59
+ def test(data, target_color, node_to_mask):
60
+
61
+ model.eval()
62
+ out = model(data.x, data.edge_index.long(), data.edge_attr)
63
+ loss = criterion(out[node_to_mask, :], target_color)
64
+ return loss, out
65
+
66
+ def my_palplot(pal, pred, tar, size=1, ax=None):
67
+ """Plot the values in a color palette as a horizontal array.
68
+ Parameters
69
+ ----------
70
+ pal : sequence of matplotlib colors
71
+ colors, i.e. as returned by seaborn.color_palette()
72
+ size :
73
+ scaling factor for size of plot
74
+ ax :
75
+ an existing axes to use
76
+ """
77
+
78
+ import numpy as np
79
+ import matplotlib as mpl
80
+ import matplotlib.pyplot as plt
81
+ import matplotlib.ticker as ticker
82
+
83
+ n = len(pal)
84
+ if pal[0][0] > 1:
85
+ pal = np.array(pal) / 255
86
+ if ax is None:
87
+ f, ax = plt.subplots(1, 1, figsize=(n * size, size))
88
+ ax.imshow(np.arange(n).reshape(1, n),
89
+ cmap=mpl.colors.ListedColormap(list(pal)),
90
+ interpolation="nearest", aspect="auto")
91
+ ax.set_xticks(np.arange(n) - .5)
92
+ ax.set_yticks([-.5, .5])
93
+ rgb_pred = test_dataset.cube2rgb(pred)
94
+ rgb_pred_str = " ".join(str(x) for x in rgb_pred)
95
+     rgb_tar = test_dataset.cube2rgb(tar)
+     rgb_tar_str = " ".join(str(x) for x in rgb_tar)
+     ax.set_ylabel("Pred: " + rgb_pred_str + " Tar: " + rgb_tar_str, rotation=0, labelpad=100)
+     # Ensure nice border between colors
+     ax.set_xticklabels(["" for _ in range(n)])
+     # The proper way to set no ticks
+     ax.yaxis.set_major_locator(ticker.NullLocator())
+
+ # Plot configuration
+ rows = num_of_plots//3 + 1
+ cols = 3
+ fig, ax_array = plt.subplots(rows, cols, figsize=(60, 60), dpi=80, squeeze=False)
+
+ column_titles = [" Prediction | Target " for i in range(cols)]
+ for ax, col in zip(ax_array[0], column_titles):
+     ax.set_title(col, fontdict={'fontsize': 30, 'fontweight': 'medium'})
+
+ fig.suptitle(model_name+" Test Palettes", fontsize=100)
+
+ # Evaluation loop
+ plot_count = 0
+ val_losses = []
+ palettes = []
+ preds = []
+ targets = []
+ count = 0
+ top_k_all_preds = []
+ node_names = []
+
+ for i, (input_data, target_color, node_to_mask) in enumerate(test_loader):
+     loss, out = test(input_data.to(device), target_color.to(device), node_to_mask)
+     if loss is not None:
+         val_losses.append(loss.item())
+
+         ax = plt.subplot(rows, cols, plot_count+1)
+
+         # Get the prediction for the masked node
+         prediction = out[node_to_mask, :]
+         #preds.append(prediction.detach().cpu()[0])
+         #targets.append(target_color.detach().cpu()[0])
+         # Concatenate the unmasked colors with prediction and ground truth
+         other_colors = torch.tensor(input_data.y)
+         # FIX THIS
+         other_colors = torch.cat((other_colors[0][0:node_to_mask], other_colors[0][node_to_mask+1:]))
+         #print(other_colors[0:node_to_mask, :].shape, other_colors[node_to_mask+1:, :].shape, node_to_mask)
+         other_colors = other_colors.type(torch.float32).detach().cpu().numpy()
+         # Normalize since they are in (0, 255) range.
+         # other_colors /= 255
+
+         # Class distribution over the 64 color cubes; keep the top-5 candidates
+         top_k_vals, top_k_preds = F.softmax(prediction, dim=1).topk(5)
+         top_k_all_preds.append(top_k_preds.detach().numpy())
+         node_names.append(node_to_mask.detach().numpy())
+         prediction = torch.argmax(F.softmax(prediction, dim=1), dim=1).numpy()
+         print("Pred: ", prediction, " Target: ", target_color)
+         preds.append(prediction.item())
+         print("pred cube to color: ", test_dataset.cube2rgb(prediction.item()))
+         print("target cube to color: ", test_dataset.cube2rgb(target_color.item()))
+         targets.append(target_color.item())
+         # print(other_colors, prediction, target_color)
+         # print(other_colors[0])
+         # print(np.atleast_1d(prediction), np.atleast_1d(target_color.detach().cpu().numpy()))
+         palette = np.concatenate((other_colors, np.atleast_1d(prediction.item()), np.atleast_1d(target_color.item())))
+         pred_palette = np.concatenate((top_k_preds[0], np.atleast_1d(target_color.item())))
+         # Code for computing the results in CIELab is commented out below
+         # if "embedding" in model_name.lower():
+         #     other_colors = other_colors.type(torch.float32).detach().cpu().numpy()
+         #     other_colors /= 255
+         #     palette = np.clip(np.concatenate([other_colors, CIELab2RGB(prediction), CIELab2RGB(target_color[0])]), a_min=0, a_max=1)
+         # else:
+         #     current_palette = torch.cat([other_colors, prediction, target_color.to(device)]).type(torch.float32).detach().cpu().numpy()
+         #     palette = CIELab2RGB(current_palette)
+
+         # Map cube indices back to RGB so the palette can be drawn
+         all_colors = []
+         for num in palette:
+             color = test_dataset.cube2rgb(num)
+             all_colors.append(color)
+
+         #palettes.append(prediction.detach().tolist()[0])
+         my_palplot(all_colors, prediction.item(), target_color.item(), ax=ax)
+     else:
+         print("none")
+
+     plot_count += 1
+     print(plot_count)
+     if i == num_of_plots-1:
+         path = "../models/"+model_name
+         if not os.path.exists(path):
+             os.mkdir(path)
+
+         if our_node_to_mask == -1:
+             plt.savefig(path+"/palettes.jpg")
+         else:
+             plt.savefig(path+"/palettes_only_red.jpg")
+
+ N = len(test_dataset)
+ print("Evaluation Accuracy: ", np.sum(np.array(preds) == np.array(targets))/N)
+ random_results = (np.random.random(size=(N,))*64).astype(np.uint16)
+ print("Evaluation Accuracy Random: ", np.sum(random_results == np.array(targets))/N)
+
+ np.save("preds.npy", np.array(preds))
+ np.save("targets.npy", np.array(targets))
+ np.save("node_to_mask.npy", np.array(node_names))
+ np.save("top_k_preds.npy", np.array(top_k_all_preds))
+
+ # Check the prediction distribution; saved as a histogram.
+ #check_distributions(palettes)
+
+ # criterion = nn.MSELoss()
+ # stacked_pred = torch.stack(preds)
+ # stacked_target = torch.stack(targets)
+ # random_results = np.random.random(size=(stacked_pred.shape[0], stacked_pred.shape[1]))
+ # print(criterion(stacked_pred, stacked_target))
+ # print(criterion(torch.Tensor(random_results), stacked_target))
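
The accuracy printed above only scores the arg-max class, but the loop also writes the top-5 candidates to disk, so top-k accuracy can be recomputed offline. A minimal sketch, assuming `preds.npy`, `targets.npy`, and `top_k_preds.npy` were saved by the loop above with one row of five cube indices per test sample:

    import numpy as np

    # Arrays written by the evaluation loop above.
    preds = np.load("preds.npy")                        # (N,) arg-max cube indices
    targets = np.load("targets.npy")                    # (N,) ground-truth cube indices
    top_k = np.load("top_k_preds.npy").reshape(-1, 5)   # (N, 5) top-5 cube indices

    top1 = np.mean(preds == targets)
    # A sample is a top-5 hit if the target appears among its five candidates.
    top5 = np.mean([t in row for t, row in zip(targets, top_k)])
    print(f"top-1: {top1:.3f}  top-5: {top5:.3f}")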
color_palette/evaluate_recommend.py ADDED
@@ -0,0 +1,72 @@
+ import numpy as np
+ import matplotlib as mpl
+ import matplotlib.pyplot as plt
+ import matplotlib.ticker as ticker
+ from regressor.config import config_to_use
+
+ def my_palplot(pal, size=1, ax=None):
+     n = len(pal)
+     if ax is None:
+         f, ax = plt.subplots(1, 1, figsize=(n * size, size))
+     ax.imshow(np.arange(n).reshape(1, n),
+               cmap=mpl.colors.ListedColormap(list(pal)),
+               interpolation="nearest", aspect="auto")
+     ax.set_xticks(np.arange(n) - .5)
+     ax.set_yticks([-.5, .5])
+     # Ensure nice border between colors
+     ax.set_xticklabels(["" for _ in range(n)])
+     # The proper way to set no ticks
+     ax.yaxis.set_major_locator(ticker.NullLocator())
+
+ # test_preds = np.load('all_one_hot_LR/test_preds_graph.npy')
+ palettes = np.load(config_to_use.save_folder+'/new_palettes_purple.npy')
+ original_palettes = np.load(config_to_use.save_folder+'/original_palettes_purple.npy')
+
+ # test_preds = np.expand_dims(test_preds, axis=1)
+ # all_colors = np.concatenate((palettes, test_preds), axis=1)
+
+ colors = np.clip(palettes, a_min=0, a_max=1)
+ colors_org = np.clip(original_palettes, a_min=0, a_max=1)
+
+ rows = 50
+ cols = 2
+ fig, ax_array = plt.subplots(rows, cols, figsize=(60, 60), dpi=80, squeeze=False)
+
+ column_titles = ["Updated Palettes" for i in range(cols)]
+ for ax, col in zip(ax_array[0], column_titles):
+     ax.set_title(col, fontdict={'fontsize': 50, 'fontweight': 'medium'})
+
+ fig.suptitle("Test Palettes", fontsize=100)
+ plot_count = 0
+ for i in range(len(colors)):
+     ax = plt.subplot(rows, cols, plot_count+1)
+     my_palplot(colors[i], ax=ax)
+     plot_count += 1
+     if plot_count == 100:
+         break
+
+ plt.savefig(config_to_use.save_folder+'/recommended_purple.jpg')
+ plt.clf()
+
+ rows = 50
+ cols = 2
+ fig, ax_array = plt.subplots(rows, cols, figsize=(60, 60), dpi=80, squeeze=False)
+
+ column_titles = ["Original Palettes" for i in range(cols)]
+ for ax, col in zip(ax_array[0], column_titles):
+     ax.set_title(col, fontdict={'fontsize': 50, 'fontweight': 'medium'})
+
+ fig.suptitle("Test Palettes", fontsize=100)
+ plot_count = 0
+ for i in range(len(colors)):
+     ax = plt.subplot(rows, cols, plot_count+1)
+     my_palplot(colors_org[i], ax=ax)
+     plot_count += 1
+     if plot_count == 100:
+         break
+
+ plt.savefig(config_to_use.save_folder+'/original_purple.jpg')
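
`my_palplot` also works stand-alone for inspecting a single palette. A minimal sketch, assuming it is imported from this script and given a hypothetical (n, 3) array of RGB values already in (0, 1):

    import numpy as np
    import matplotlib.pyplot as plt

    # Hypothetical 5-color palette; my_palplot expects RGB rows in (0, 1).
    palette = np.array([[0.55, 0.25, 0.70],
                        [0.90, 0.88, 0.95],
                        [0.25, 0.10, 0.35],
                        [0.75, 0.55, 0.85],
                        [0.40, 0.15, 0.50]])
    my_palplot(palette)  # creates its own axes when ax is None
    plt.show()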
color_palette/model/CNN.py ADDED
@@ -0,0 +1,209 @@
+ import torch
+ import torch.nn as nn
+ from torchvision.models import resnet50, ResNet50_Weights, resnet18, ResNet18_Weights
+
+ class Autoencoder(torch.nn.Module):
+     def __init__(self, num_channels=3, c_hid=16, latent_dim=256):
+         super().__init__()
+
+         # Convolutional encoder: strided convolutions halve the spatial
+         # resolution step by step (256x256 => 16x16).
+         self.encoder = torch.nn.Sequential(
+             nn.Conv2d(num_channels, c_hid, kernel_size=3, padding=1, stride=2),  # 256x256 => 128x128
+             nn.ReLU(),
+             nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
+             nn.ReLU(),
+             nn.Conv2d(c_hid, 2 * c_hid, kernel_size=3, padding=1, stride=2),  # 128x128 => 64x64
+             nn.ReLU(),
+             nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1),
+             nn.ReLU(),
+             nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1, stride=2),  # 64x64 => 32x32
+             nn.ReLU(),
+             nn.Conv2d(2 * c_hid, c_hid, kernel_size=3, padding=1, stride=2),  # 32x32 => 16x16
+         )
+
+         # Bottleneck: flatten the 16x16 feature map into a latent vector,
+         # then the transposed-conv decoder upsamples back to the input size.
+         self.linear = nn.Sequential(
+             nn.Flatten(),  # Image grid to single feature vector
+             nn.Linear(16 * 16 * c_hid, num_channels*latent_dim),
+         )
+         self.decoder = nn.Sequential(
+             nn.ConvTranspose2d(
+                 num_channels, 2 * c_hid, kernel_size=3, output_padding=1, padding=1, stride=2
+             ),  # 16x16 => 32x32
+             nn.ReLU(),
+             nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1),
+             nn.ReLU(),
+             nn.ConvTranspose2d(2 * c_hid, c_hid, kernel_size=3, output_padding=1, padding=1, stride=2),  # 32x32 => 64x64
+             nn.ReLU(),
+             nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
+             nn.ReLU(),
+             nn.ConvTranspose2d(
+                 c_hid, c_hid, kernel_size=3, output_padding=1, padding=1, stride=2
+             ),  # 64x64 => 128x128
+             nn.ReLU(),
+             nn.ConvTranspose2d(
+                 c_hid, num_channels, kernel_size=3, output_padding=1, padding=1, stride=2
+             ),  # 128x128 => 256x256
+             nn.Tanh(),  # The input images are scaled to (-1, 1), so the output has to be bounded as well
+         )
+
+     def forward(self, x):
+         encoded = self.encoder(x)
+         linear = self.linear(encoded)
+         x = linear.reshape(linear.shape[0], -1, 16, 16)
+         decoded = self.decoder(x)
+         return decoded
+
+ class FinetuneResNet18_classify(nn.Module):
+     """Classify the color: three independent 256-way heads for R, G and B."""
+
+     def __init__(self, freeze_resnet=True):
+         super().__init__()
+
+         self.pretrained_model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
+
+         if freeze_resnet:
+             for param in self.pretrained_model.parameters():
+                 param.requires_grad = False
+
+         self.pretrained_model.fc = nn.Linear(in_features=512, out_features=1024)
+
+         self.color_head = nn.Sequential(
+             nn.Linear(in_features=1024, out_features=768),
+         )
+         self.softmax = nn.Softmax(dim=1)
+
+     def forward(self, x):
+         x = self.pretrained_model(x)
+         x = self.color_head(x)
+         r = self.softmax(x[:, :256])
+         g = self.softmax(x[:, 256:512])
+         b = self.softmax(x[:, 512:])
+         return r, g, b
+
+ class ResNet18(nn.Module):
+     """Regress a single color; map_outputs selects the target interval
+     (RGB in (0, 1) via sigmoid, or CIELab via scaled sigmoid/tanh)."""
+
+     def __init__(self, freeze_resnet=True, map_outputs="RGB"):
+         super().__init__()
+
+         self.pretrained_model = resnet18(weights=None)
+         self.map_outputs = map_outputs
+
+         for param in self.pretrained_model.parameters():
+             param.requires_grad = True
+
+         self.pretrained_model.fc = nn.Linear(in_features=512, out_features=256)
+
+         self.color_head = nn.Sequential(
+             nn.Linear(in_features=256, out_features=128),
+             nn.Linear(in_features=128, out_features=64),
+             nn.Linear(in_features=64, out_features=3),
+         )
+
+         if self.map_outputs == "CIELab":
+             self.l_activation = nn.Sigmoid()
+             self.a_activation = nn.Tanh()
+             self.b_activation = nn.Tanh()
+         elif self.map_outputs == "RGB":
+             self.activation = nn.Sigmoid()
+
+     def forward(self, x):
+         x = self.pretrained_model(x)
+         x = self.color_head(x)
+         if self.map_outputs == "CIELab":
+             x[:, 0] = self.l_activation(x[:, 0]) * 100
+             x[:, 1] = self.a_activation(x[:, 1]) * 127
+             x[:, 2] = self.b_activation(x[:, 2]) * 127
+         elif self.map_outputs == "RGB":
+             x = self.activation(x)
+         return x
+
+
+ class ColorCNN(nn.Module):
+     def __init__(self, num_channels=3, c_hid=16, out_feature=3, reverse_normalize_output=True):
+         super().__init__()
+         self.encoder = torch.nn.Sequential(
+             nn.Conv2d(num_channels, c_hid, kernel_size=3, padding=1, stride=2),  # 512x512 => 256x256
+             nn.BatchNorm2d(num_features=c_hid),
+             nn.ReLU(),
+             nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
+             nn.ReLU(),
+             nn.Conv2d(c_hid, 2 * c_hid, kernel_size=3, padding=1, stride=2),  # 256x256 => 128x128
+             nn.BatchNorm2d(num_features=2*c_hid),
+             nn.ReLU(),
+             nn.Conv2d(2 * c_hid, 4 * c_hid, kernel_size=3, padding=1),
+             nn.ReLU(),
+             nn.Conv2d(4 * c_hid, 2 * c_hid, kernel_size=3, padding=1, stride=2),  # 128x128 => 64x64
+             nn.BatchNorm2d(num_features=2*c_hid),
+             nn.ReLU(),
+             nn.Conv2d(2 * c_hid, c_hid, kernel_size=3, padding=1, stride=2),  # 64x64 => 32x32
+             nn.BatchNorm2d(num_features=c_hid),
+             nn.ReLU(),
+             nn.Conv2d(c_hid, 8, kernel_size=3, padding=1, stride=2),  # 32x32 => 16x16
+             nn.BatchNorm2d(num_features=8),
+             nn.ReLU(),
+         )
+
+         self.color_head = nn.Sequential(
+             nn.Flatten(),
+             nn.Linear(in_features=16*16*8, out_features=out_feature)
+         )
+
+         self.reverse_normalize_output = reverse_normalize_output
+         #self.addition = Addition()
+         #self.multiplication = torch.Tensor([100, 255, 255]).to("cuda:1")
+
+     def forward(self, x):
+         x = self.encoder(x)
+         x = self.color_head(x)
+         return x
+
+ class Addition(nn.Module):
+     def __init__(self) -> None:
+         super().__init__()
+
+     def forward(self, input: torch.Tensor) -> torch.Tensor:
+         # Shift the (a, b) channels by 127 (CIELab offset); note the hard-coded device.
+         input += torch.Tensor([0, 127, 127]).to("cuda:1")
+         return input
+
+ class ColorCNNBigger(nn.Module):
+     def __init__(self, num_channels=3, c_hid=16):
+         super().__init__()
+         self.encoder = torch.nn.Sequential(
+             nn.Conv2d(num_channels, c_hid, kernel_size=3, padding=1, stride=2),  # 512x512 => 256x256
+             nn.BatchNorm2d(num_features=c_hid),
+             nn.ReLU(),
+             nn.Conv2d(c_hid, 2 * c_hid, kernel_size=3, padding=1),
+             nn.ReLU(),
+             nn.Conv2d(2 * c_hid, 4 * c_hid, kernel_size=3, padding=1, stride=2),  # 256x256 => 128x128
+             nn.BatchNorm2d(num_features=4*c_hid),
+             nn.ReLU(),
+             nn.Conv2d(4 * c_hid, 2 * c_hid, kernel_size=3, padding=1),
+             nn.ReLU(),
+             nn.Conv2d(2 * c_hid, c_hid, kernel_size=3, padding=1, stride=2),  # 128x128 => 64x64
+             nn.BatchNorm2d(num_features=c_hid),
+             nn.ReLU(),
+             nn.Conv2d(c_hid, c_hid//2, kernel_size=3, padding=1, stride=2),  # 64x64 => 32x32
+             nn.BatchNorm2d(num_features=c_hid//2),
+         )
+
+         self.color_head = nn.Sequential(
+             nn.Flatten(),
+             nn.Linear(in_features=32*32*8, out_features=16*16*4),
+             nn.Linear(in_features=16*16*4, out_features=3)
+         )
+
+     def forward(self, x):
+         x = self.encoder(x)
+         x = self.color_head(x)
+         return x
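
A minimal shape smoke test for the modules above, assuming the input resolutions implied by the layer comments (256x256 for Autoencoder, 512x512 for the ColorCNNs) and 224x224 crops for the ResNet variants, and assuming it is run from the color_palette directory; note that instantiating FinetuneResNet18_classify downloads the ImageNet weights:

    import torch
    from model.CNN import (Autoencoder, FinetuneResNet18_classify,
                           ResNet18, ColorCNN, ColorCNNBigger)

    x256 = torch.randn(2, 3, 256, 256)
    x512 = torch.randn(2, 3, 512, 512)
    x224 = torch.randn(2, 3, 224, 224)

    assert Autoencoder()(x256).shape == (2, 3, 256, 256)      # reconstruction
    r, g, b = FinetuneResNet18_classify()(x224)               # 256-way R/G/B distributions
    assert r.shape == g.shape == b.shape == (2, 256)
    assert ResNet18(map_outputs="RGB")(x224).shape == (2, 3)  # regressed color in (0, 1)
    assert ColorCNN()(x512).shape == (2, 3)
    assert ColorCNNBigger()(x512).shape == (2, 3)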