Spaces:
Runtime error
Runtime error
akhaliq3
commited on
Commit
·
035e10c
1
Parent(s):
80e980c
spaces demo
Browse files- LICENSE +21 -0
- inference/.DS_Store +0 -0
- inference/brush/brush_large_horizontal.png +0 -0
- inference/brush/brush_large_vertical.png +0 -0
- inference/brush/brush_small_horizontal.png +0 -0
- inference/brush/brush_small_vertical.png +0 -0
- inference/inference.py +496 -0
- inference/input/.DS_Store +0 -0
- inference/input/temp.txt +0 -0
- inference/morphology.py +51 -0
- inference/network.py +84 -0
- train/brush/brush_large_horizontal.png +0 -0
- train/brush/brush_large_vertical.png +0 -0
- train/brush/brush_small_horizontal.png +0 -0
- train/brush/brush_small_vertical.png +0 -0
- train/data/__init__.py +94 -0
- train/data/base_dataset.py +153 -0
- train/data/null_dataset.py +15 -0
- train/models/__init__.py +67 -0
- train/models/base_model.py +230 -0
- train/models/networks.py +143 -0
- train/models/painter_model.py +247 -0
- train/options/__init__.py +1 -0
- train/options/base_options.py +151 -0
- train/options/test_options.py +23 -0
- train/options/train_options.py +52 -0
- train/train.py +58 -0
- train/train.sh +14 -0
- train/util/__init__.py +1 -0
- train/util/html.py +86 -0
- train/util/morphology.py +43 -0
- train/util/util.py +103 -0
- train/util/visualizer.py +224 -0
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2021 Huage001
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
inference/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
inference/brush/brush_large_horizontal.png
ADDED
inference/brush/brush_large_vertical.png
ADDED
inference/brush/brush_small_horizontal.png
ADDED
inference/brush/brush_small_vertical.png
ADDED
inference/inference.py
ADDED
@@ -0,0 +1,496 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
import network
|
6 |
+
import morphology
|
7 |
+
import os
|
8 |
+
import math
|
9 |
+
|
10 |
+
idx = 0
|
11 |
+
|
12 |
+
|
13 |
+
def save_img(img, output_path):
|
14 |
+
result = Image.fromarray((img.data.cpu().numpy().transpose((1, 2, 0)) * 255).astype(np.uint8))
|
15 |
+
result.save(output_path)
|
16 |
+
|
17 |
+
|
18 |
+
def param2stroke(param, H, W, meta_brushes):
|
19 |
+
"""
|
20 |
+
Input a set of stroke parameters and output its corresponding foregrounds and alpha maps.
|
21 |
+
Args:
|
22 |
+
param: a tensor with shape n_strokes x n_param_per_stroke. Here, param_per_stroke is 8:
|
23 |
+
x_center, y_center, width, height, theta, R, G, and B.
|
24 |
+
H: output height.
|
25 |
+
W: output width.
|
26 |
+
meta_brushes: a tensor with shape 2 x 3 x meta_brush_height x meta_brush_width.
|
27 |
+
The first slice on the batch dimension denotes vertical brush and the second one denotes horizontal brush.
|
28 |
+
|
29 |
+
Returns:
|
30 |
+
foregrounds: a tensor with shape n_strokes x 3 x H x W, containing color information.
|
31 |
+
alphas: a tensor with shape n_strokes x 3 x H x W,
|
32 |
+
containing binary information of whether a pixel is belonging to the stroke (alpha mat), for painting process.
|
33 |
+
"""
|
34 |
+
# Firstly, resize the meta brushes to the required shape,
|
35 |
+
# in order to decrease GPU memory especially when the required shape is small.
|
36 |
+
meta_brushes_resize = F.interpolate(meta_brushes, (H, W))
|
37 |
+
b = param.shape[0]
|
38 |
+
# Extract shape parameters and color parameters.
|
39 |
+
param_list = torch.split(param, 1, dim=1)
|
40 |
+
x0, y0, w, h, theta = [item.squeeze(-1) for item in param_list[:5]]
|
41 |
+
R, G, B = param_list[5:]
|
42 |
+
# Pre-compute sin theta and cos theta
|
43 |
+
sin_theta = torch.sin(torch.acos(torch.tensor(-1., device=param.device)) * theta)
|
44 |
+
cos_theta = torch.cos(torch.acos(torch.tensor(-1., device=param.device)) * theta)
|
45 |
+
# index means each stroke should use which meta stroke? Vertical meta stroke or horizontal meta stroke.
|
46 |
+
# When h > w, vertical stroke should be used. When h <= w, horizontal stroke should be used.
|
47 |
+
index = torch.full((b,), -1, device=param.device, dtype=torch.long)
|
48 |
+
index[h > w] = 0
|
49 |
+
index[h <= w] = 1
|
50 |
+
brush = meta_brushes_resize[index.long()]
|
51 |
+
|
52 |
+
# Calculate warp matrix according to the rules defined by pytorch, in order for warping.
|
53 |
+
warp_00 = cos_theta / w
|
54 |
+
warp_01 = sin_theta * H / (W * w)
|
55 |
+
warp_02 = (1 - 2 * x0) * cos_theta / w + (1 - 2 * y0) * sin_theta * H / (W * w)
|
56 |
+
warp_10 = -sin_theta * W / (H * h)
|
57 |
+
warp_11 = cos_theta / h
|
58 |
+
warp_12 = (1 - 2 * y0) * cos_theta / h - (1 - 2 * x0) * sin_theta * W / (H * h)
|
59 |
+
warp_0 = torch.stack([warp_00, warp_01, warp_02], dim=1)
|
60 |
+
warp_1 = torch.stack([warp_10, warp_11, warp_12], dim=1)
|
61 |
+
warp = torch.stack([warp_0, warp_1], dim=1)
|
62 |
+
# Conduct warping.
|
63 |
+
grid = F.affine_grid(warp, [b, 3, H, W], align_corners=False)
|
64 |
+
brush = F.grid_sample(brush, grid, align_corners=False)
|
65 |
+
# alphas is the binary information suggesting whether a pixel is belonging to the stroke.
|
66 |
+
alphas = (brush > 0).float()
|
67 |
+
brush = brush.repeat(1, 3, 1, 1)
|
68 |
+
alphas = alphas.repeat(1, 3, 1, 1)
|
69 |
+
# Give color to foreground strokes.
|
70 |
+
color_map = torch.cat([R, G, B], dim=1)
|
71 |
+
color_map = color_map.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, H, W)
|
72 |
+
foreground = brush * color_map
|
73 |
+
# Dilation and erosion are used for foregrounds and alphas respectively to prevent artifacts on stroke borders.
|
74 |
+
foreground = morphology.dilation(foreground)
|
75 |
+
alphas = morphology.erosion(alphas)
|
76 |
+
return foreground, alphas
|
77 |
+
|
78 |
+
|
79 |
+
def param2img_serial(
|
80 |
+
param, decision, meta_brushes, cur_canvas, frame_dir, has_border=False, original_h=None, original_w=None):
|
81 |
+
"""
|
82 |
+
Input stroke parameters and decisions for each patch, meta brushes, current canvas, frame directory,
|
83 |
+
and whether there is a border (if intermediate painting results are required).
|
84 |
+
Output the painting results of adding the corresponding strokes on the current canvas.
|
85 |
+
Args:
|
86 |
+
param: a tensor with shape batch size x patch along height dimension x patch along width dimension
|
87 |
+
x n_stroke_per_patch x n_param_per_stroke
|
88 |
+
decision: a 01 tensor with shape batch size x patch along height dimension x patch along width dimension
|
89 |
+
x n_stroke_per_patch
|
90 |
+
meta_brushes: a tensor with shape 2 x 3 x meta_brush_height x meta_brush_width.
|
91 |
+
The first slice on the batch dimension denotes vertical brush and the second one denotes horizontal brush.
|
92 |
+
cur_canvas: a tensor with shape batch size x 3 x H x W,
|
93 |
+
where H and W denote height and width of padded results of original images.
|
94 |
+
frame_dir: directory to save intermediate painting results. None means intermediate results are not required.
|
95 |
+
has_border: on the last painting layer, in order to make sure that the painting results do not miss
|
96 |
+
any important detail, we choose to paint again on this layer but shift patch_size // 2 pixels when
|
97 |
+
cutting patches. In this case, if intermediate results are required, we need to cut the shifted length
|
98 |
+
on the border before saving, or there would be a black border.
|
99 |
+
original_h: to indicate the original height for cropping when saving intermediate results.
|
100 |
+
original_w: to indicate the original width for cropping when saving intermediate results.
|
101 |
+
|
102 |
+
Returns:
|
103 |
+
cur_canvas: a tensor with shape batch size x 3 x H x W, denoting painting results.
|
104 |
+
"""
|
105 |
+
# param: b, h, w, stroke_per_patch, param_per_stroke
|
106 |
+
# decision: b, h, w, stroke_per_patch
|
107 |
+
b, h, w, s, p = param.shape
|
108 |
+
H, W = cur_canvas.shape[-2:]
|
109 |
+
is_odd_y = h % 2 == 1
|
110 |
+
is_odd_x = w % 2 == 1
|
111 |
+
patch_size_y = 2 * H // h
|
112 |
+
patch_size_x = 2 * W // w
|
113 |
+
even_idx_y = torch.arange(0, h, 2, device=cur_canvas.device)
|
114 |
+
even_idx_x = torch.arange(0, w, 2, device=cur_canvas.device)
|
115 |
+
odd_idx_y = torch.arange(1, h, 2, device=cur_canvas.device)
|
116 |
+
odd_idx_x = torch.arange(1, w, 2, device=cur_canvas.device)
|
117 |
+
even_y_even_x_coord_y, even_y_even_x_coord_x = torch.meshgrid([even_idx_y, even_idx_x])
|
118 |
+
odd_y_odd_x_coord_y, odd_y_odd_x_coord_x = torch.meshgrid([odd_idx_y, odd_idx_x])
|
119 |
+
even_y_odd_x_coord_y, even_y_odd_x_coord_x = torch.meshgrid([even_idx_y, odd_idx_x])
|
120 |
+
odd_y_even_x_coord_y, odd_y_even_x_coord_x = torch.meshgrid([odd_idx_y, even_idx_x])
|
121 |
+
cur_canvas = F.pad(cur_canvas, [patch_size_x // 4, patch_size_x // 4,
|
122 |
+
patch_size_y // 4, patch_size_y // 4, 0, 0, 0, 0])
|
123 |
+
|
124 |
+
def partial_render(this_canvas, patch_coord_y, patch_coord_x, stroke_id):
|
125 |
+
canvas_patch = F.unfold(this_canvas, (patch_size_y, patch_size_x),
|
126 |
+
stride=(patch_size_y // 2, patch_size_x // 2))
|
127 |
+
# canvas_patch: b, 3 * py * px, h * w
|
128 |
+
canvas_patch = canvas_patch.view(b, 3, patch_size_y, patch_size_x, h, w).contiguous()
|
129 |
+
canvas_patch = canvas_patch.permute(0, 4, 5, 1, 2, 3).contiguous()
|
130 |
+
# canvas_patch: b, h, w, 3, py, px
|
131 |
+
selected_canvas_patch = canvas_patch[:, patch_coord_y, patch_coord_x, :, :, :]
|
132 |
+
selected_h, selected_w = selected_canvas_patch.shape[1:3]
|
133 |
+
selected_param = param[:, patch_coord_y, patch_coord_x, stroke_id, :].view(-1, p).contiguous()
|
134 |
+
selected_decision = decision[:, patch_coord_y, patch_coord_x, stroke_id].view(-1).contiguous()
|
135 |
+
selected_foregrounds = torch.zeros(selected_param.shape[0], 3, patch_size_y, patch_size_x,
|
136 |
+
device=this_canvas.device)
|
137 |
+
selected_alphas = torch.zeros(selected_param.shape[0], 3, patch_size_y, patch_size_x, device=this_canvas.device)
|
138 |
+
if selected_param[selected_decision, :].shape[0] > 0:
|
139 |
+
selected_foregrounds[selected_decision, :, :, :], selected_alphas[selected_decision, :, :, :] = \
|
140 |
+
param2stroke(selected_param[selected_decision, :], patch_size_y, patch_size_x, meta_brushes)
|
141 |
+
selected_foregrounds = selected_foregrounds.view(
|
142 |
+
b, selected_h, selected_w, 3, patch_size_y, patch_size_x).contiguous()
|
143 |
+
selected_alphas = selected_alphas.view(b, selected_h, selected_w, 3, patch_size_y, patch_size_x).contiguous()
|
144 |
+
selected_decision = selected_decision.view(b, selected_h, selected_w, 1, 1, 1).contiguous()
|
145 |
+
selected_canvas_patch = selected_foregrounds * selected_alphas * selected_decision + selected_canvas_patch * (
|
146 |
+
1 - selected_alphas * selected_decision)
|
147 |
+
this_canvas = selected_canvas_patch.permute(0, 3, 1, 4, 2, 5).contiguous()
|
148 |
+
# this_canvas: b, 3, selected_h, py, selected_w, px
|
149 |
+
this_canvas = this_canvas.view(b, 3, selected_h * patch_size_y, selected_w * patch_size_x).contiguous()
|
150 |
+
# this_canvas: b, 3, selected_h * py, selected_w * px
|
151 |
+
return this_canvas
|
152 |
+
|
153 |
+
global idx
|
154 |
+
if has_border:
|
155 |
+
factor = 2
|
156 |
+
else:
|
157 |
+
factor = 4
|
158 |
+
if even_idx_y.shape[0] > 0 and even_idx_x.shape[0] > 0:
|
159 |
+
for i in range(s):
|
160 |
+
canvas = partial_render(cur_canvas, even_y_even_x_coord_y, even_y_even_x_coord_x, i)
|
161 |
+
if not is_odd_y:
|
162 |
+
canvas = torch.cat([canvas, cur_canvas[:, :, -patch_size_y // 2:, :canvas.shape[3]]], dim=2)
|
163 |
+
if not is_odd_x:
|
164 |
+
canvas = torch.cat([canvas, cur_canvas[:, :, :canvas.shape[2], -patch_size_x // 2:]], dim=3)
|
165 |
+
cur_canvas = canvas
|
166 |
+
idx += 1
|
167 |
+
if frame_dir is not None:
|
168 |
+
frame = crop(cur_canvas[:, :, patch_size_y // factor:-patch_size_y // factor,
|
169 |
+
patch_size_x // factor:-patch_size_x // factor], original_h, original_w)
|
170 |
+
save_img(frame[0], os.path.join(frame_dir, '%03d.jpg' % idx))
|
171 |
+
|
172 |
+
if odd_idx_y.shape[0] > 0 and odd_idx_x.shape[0] > 0:
|
173 |
+
for i in range(s):
|
174 |
+
canvas = partial_render(cur_canvas, odd_y_odd_x_coord_y, odd_y_odd_x_coord_x, i)
|
175 |
+
canvas = torch.cat([cur_canvas[:, :, :patch_size_y // 2, -canvas.shape[3]:], canvas], dim=2)
|
176 |
+
canvas = torch.cat([cur_canvas[:, :, -canvas.shape[2]:, :patch_size_x // 2], canvas], dim=3)
|
177 |
+
if is_odd_y:
|
178 |
+
canvas = torch.cat([canvas, cur_canvas[:, :, -patch_size_y // 2:, :canvas.shape[3]]], dim=2)
|
179 |
+
if is_odd_x:
|
180 |
+
canvas = torch.cat([canvas, cur_canvas[:, :, :canvas.shape[2], -patch_size_x // 2:]], dim=3)
|
181 |
+
cur_canvas = canvas
|
182 |
+
idx += 1
|
183 |
+
if frame_dir is not None:
|
184 |
+
frame = crop(cur_canvas[:, :, patch_size_y // factor:-patch_size_y // factor,
|
185 |
+
patch_size_x // factor:-patch_size_x // factor], original_h, original_w)
|
186 |
+
save_img(frame[0], os.path.join(frame_dir, '%03d.jpg' % idx))
|
187 |
+
|
188 |
+
if odd_idx_y.shape[0] > 0 and even_idx_x.shape[0] > 0:
|
189 |
+
for i in range(s):
|
190 |
+
canvas = partial_render(cur_canvas, odd_y_even_x_coord_y, odd_y_even_x_coord_x, i)
|
191 |
+
canvas = torch.cat([cur_canvas[:, :, :patch_size_y // 2, :canvas.shape[3]], canvas], dim=2)
|
192 |
+
if is_odd_y:
|
193 |
+
canvas = torch.cat([canvas, cur_canvas[:, :, -patch_size_y // 2:, :canvas.shape[3]]], dim=2)
|
194 |
+
if not is_odd_x:
|
195 |
+
canvas = torch.cat([canvas, cur_canvas[:, :, :canvas.shape[2], -patch_size_x // 2:]], dim=3)
|
196 |
+
cur_canvas = canvas
|
197 |
+
idx += 1
|
198 |
+
if frame_dir is not None:
|
199 |
+
frame = crop(cur_canvas[:, :, patch_size_y // factor:-patch_size_y // factor,
|
200 |
+
patch_size_x // factor:-patch_size_x // factor], original_h, original_w)
|
201 |
+
save_img(frame[0], os.path.join(frame_dir, '%03d.jpg' % idx))
|
202 |
+
|
203 |
+
if even_idx_y.shape[0] > 0 and odd_idx_x.shape[0] > 0:
|
204 |
+
for i in range(s):
|
205 |
+
canvas = partial_render(cur_canvas, even_y_odd_x_coord_y, even_y_odd_x_coord_x, i)
|
206 |
+
canvas = torch.cat([cur_canvas[:, :, :canvas.shape[2], :patch_size_x // 2], canvas], dim=3)
|
207 |
+
if not is_odd_y:
|
208 |
+
canvas = torch.cat([canvas, cur_canvas[:, :, -patch_size_y // 2:, -canvas.shape[3]:]], dim=2)
|
209 |
+
if is_odd_x:
|
210 |
+
canvas = torch.cat([canvas, cur_canvas[:, :, :canvas.shape[2], -patch_size_x // 2:]], dim=3)
|
211 |
+
cur_canvas = canvas
|
212 |
+
idx += 1
|
213 |
+
if frame_dir is not None:
|
214 |
+
frame = crop(cur_canvas[:, :, patch_size_y // factor:-patch_size_y // factor,
|
215 |
+
patch_size_x // factor:-patch_size_x // factor], original_h, original_w)
|
216 |
+
save_img(frame[0], os.path.join(frame_dir, '%03d.jpg' % idx))
|
217 |
+
|
218 |
+
cur_canvas = cur_canvas[:, :, patch_size_y // 4:-patch_size_y // 4, patch_size_x // 4:-patch_size_x // 4]
|
219 |
+
|
220 |
+
return cur_canvas
|
221 |
+
|
222 |
+
|
223 |
+
def param2img_parallel(param, decision, meta_brushes, cur_canvas):
|
224 |
+
"""
|
225 |
+
Input stroke parameters and decisions for each patch, meta brushes, current canvas, frame directory,
|
226 |
+
and whether there is a border (if intermediate painting results are required).
|
227 |
+
Output the painting results of adding the corresponding strokes on the current canvas.
|
228 |
+
Args:
|
229 |
+
param: a tensor with shape batch size x patch along height dimension x patch along width dimension
|
230 |
+
x n_stroke_per_patch x n_param_per_stroke
|
231 |
+
decision: a 01 tensor with shape batch size x patch along height dimension x patch along width dimension
|
232 |
+
x n_stroke_per_patch
|
233 |
+
meta_brushes: a tensor with shape 2 x 3 x meta_brush_height x meta_brush_width.
|
234 |
+
The first slice on the batch dimension denotes vertical brush and the second one denotes horizontal brush.
|
235 |
+
cur_canvas: a tensor with shape batch size x 3 x H x W,
|
236 |
+
where H and W denote height and width of padded results of original images.
|
237 |
+
|
238 |
+
Returns:
|
239 |
+
cur_canvas: a tensor with shape batch size x 3 x H x W, denoting painting results.
|
240 |
+
"""
|
241 |
+
# param: b, h, w, stroke_per_patch, param_per_stroke
|
242 |
+
# decision: b, h, w, stroke_per_patch
|
243 |
+
b, h, w, s, p = param.shape
|
244 |
+
param = param.view(-1, 8).contiguous()
|
245 |
+
decision = decision.view(-1).contiguous().bool()
|
246 |
+
H, W = cur_canvas.shape[-2:]
|
247 |
+
is_odd_y = h % 2 == 1
|
248 |
+
is_odd_x = w % 2 == 1
|
249 |
+
patch_size_y = 2 * H // h
|
250 |
+
patch_size_x = 2 * W // w
|
251 |
+
even_idx_y = torch.arange(0, h, 2, device=cur_canvas.device)
|
252 |
+
even_idx_x = torch.arange(0, w, 2, device=cur_canvas.device)
|
253 |
+
odd_idx_y = torch.arange(1, h, 2, device=cur_canvas.device)
|
254 |
+
odd_idx_x = torch.arange(1, w, 2, device=cur_canvas.device)
|
255 |
+
even_y_even_x_coord_y, even_y_even_x_coord_x = torch.meshgrid([even_idx_y, even_idx_x])
|
256 |
+
odd_y_odd_x_coord_y, odd_y_odd_x_coord_x = torch.meshgrid([odd_idx_y, odd_idx_x])
|
257 |
+
even_y_odd_x_coord_y, even_y_odd_x_coord_x = torch.meshgrid([even_idx_y, odd_idx_x])
|
258 |
+
odd_y_even_x_coord_y, odd_y_even_x_coord_x = torch.meshgrid([odd_idx_y, even_idx_x])
|
259 |
+
cur_canvas = F.pad(cur_canvas, [patch_size_x // 4, patch_size_x // 4,
|
260 |
+
patch_size_y // 4, patch_size_y // 4, 0, 0, 0, 0])
|
261 |
+
foregrounds = torch.zeros(param.shape[0], 3, patch_size_y, patch_size_x, device=cur_canvas.device)
|
262 |
+
alphas = torch.zeros(param.shape[0], 3, patch_size_y, patch_size_x, device=cur_canvas.device)
|
263 |
+
valid_foregrounds, valid_alphas = param2stroke(param[decision, :], patch_size_y, patch_size_x, meta_brushes)
|
264 |
+
foregrounds[decision, :, :, :] = valid_foregrounds
|
265 |
+
alphas[decision, :, :, :] = valid_alphas
|
266 |
+
# foreground, alpha: b * h * w * stroke_per_patch, 3, patch_size_y, patch_size_x
|
267 |
+
foregrounds = foregrounds.view(-1, h, w, s, 3, patch_size_y, patch_size_x).contiguous()
|
268 |
+
alphas = alphas.view(-1, h, w, s, 3, patch_size_y, patch_size_x).contiguous()
|
269 |
+
# foreground, alpha: b, h, w, stroke_per_patch, 3, render_size_y, render_size_x
|
270 |
+
decision = decision.view(-1, h, w, s, 1, 1, 1).contiguous()
|
271 |
+
|
272 |
+
# decision: b, h, w, stroke_per_patch, 1, 1, 1
|
273 |
+
|
274 |
+
def partial_render(this_canvas, patch_coord_y, patch_coord_x):
|
275 |
+
|
276 |
+
canvas_patch = F.unfold(this_canvas, (patch_size_y, patch_size_x),
|
277 |
+
stride=(patch_size_y // 2, patch_size_x // 2))
|
278 |
+
# canvas_patch: b, 3 * py * px, h * w
|
279 |
+
canvas_patch = canvas_patch.view(b, 3, patch_size_y, patch_size_x, h, w).contiguous()
|
280 |
+
canvas_patch = canvas_patch.permute(0, 4, 5, 1, 2, 3).contiguous()
|
281 |
+
# canvas_patch: b, h, w, 3, py, px
|
282 |
+
selected_canvas_patch = canvas_patch[:, patch_coord_y, patch_coord_x, :, :, :]
|
283 |
+
selected_foregrounds = foregrounds[:, patch_coord_y, patch_coord_x, :, :, :, :]
|
284 |
+
selected_alphas = alphas[:, patch_coord_y, patch_coord_x, :, :, :, :]
|
285 |
+
selected_decisions = decision[:, patch_coord_y, patch_coord_x, :, :, :, :]
|
286 |
+
for i in range(s):
|
287 |
+
cur_foreground = selected_foregrounds[:, :, :, i, :, :, :]
|
288 |
+
cur_alpha = selected_alphas[:, :, :, i, :, :, :]
|
289 |
+
cur_decision = selected_decisions[:, :, :, i, :, :, :]
|
290 |
+
selected_canvas_patch = cur_foreground * cur_alpha * cur_decision + selected_canvas_patch * (
|
291 |
+
1 - cur_alpha * cur_decision)
|
292 |
+
this_canvas = selected_canvas_patch.permute(0, 3, 1, 4, 2, 5).contiguous()
|
293 |
+
# this_canvas: b, 3, h_half, py, w_half, px
|
294 |
+
h_half = this_canvas.shape[2]
|
295 |
+
w_half = this_canvas.shape[4]
|
296 |
+
this_canvas = this_canvas.view(b, 3, h_half * patch_size_y, w_half * patch_size_x).contiguous()
|
297 |
+
# this_canvas: b, 3, h_half * py, w_half * px
|
298 |
+
return this_canvas
|
299 |
+
|
300 |
+
if even_idx_y.shape[0] > 0 and even_idx_x.shape[0] > 0:
|
301 |
+
canvas = partial_render(cur_canvas, even_y_even_x_coord_y, even_y_even_x_coord_x)
|
302 |
+
if not is_odd_y:
|
303 |
+
canvas = torch.cat([canvas, cur_canvas[:, :, -patch_size_y // 2:, :canvas.shape[3]]], dim=2)
|
304 |
+
if not is_odd_x:
|
305 |
+
canvas = torch.cat([canvas, cur_canvas[:, :, :canvas.shape[2], -patch_size_x // 2:]], dim=3)
|
306 |
+
cur_canvas = canvas
|
307 |
+
|
308 |
+
if odd_idx_y.shape[0] > 0 and odd_idx_x.shape[0] > 0:
|
309 |
+
canvas = partial_render(cur_canvas, odd_y_odd_x_coord_y, odd_y_odd_x_coord_x)
|
310 |
+
canvas = torch.cat([cur_canvas[:, :, :patch_size_y // 2, -canvas.shape[3]:], canvas], dim=2)
|
311 |
+
canvas = torch.cat([cur_canvas[:, :, -canvas.shape[2]:, :patch_size_x // 2], canvas], dim=3)
|
312 |
+
if is_odd_y:
|
313 |
+
canvas = torch.cat([canvas, cur_canvas[:, :, -patch_size_y // 2:, :canvas.shape[3]]], dim=2)
|
314 |
+
if is_odd_x:
|
315 |
+
canvas = torch.cat([canvas, cur_canvas[:, :, :canvas.shape[2], -patch_size_x // 2:]], dim=3)
|
316 |
+
cur_canvas = canvas
|
317 |
+
|
318 |
+
if odd_idx_y.shape[0] > 0 and even_idx_x.shape[0] > 0:
|
319 |
+
canvas = partial_render(cur_canvas, odd_y_even_x_coord_y, odd_y_even_x_coord_x)
|
320 |
+
canvas = torch.cat([cur_canvas[:, :, :patch_size_y // 2, :canvas.shape[3]], canvas], dim=2)
|
321 |
+
if is_odd_y:
|
322 |
+
canvas = torch.cat([canvas, cur_canvas[:, :, -patch_size_y // 2:, :canvas.shape[3]]], dim=2)
|
323 |
+
if not is_odd_x:
|
324 |
+
canvas = torch.cat([canvas, cur_canvas[:, :, :canvas.shape[2], -patch_size_x // 2:]], dim=3)
|
325 |
+
cur_canvas = canvas
|
326 |
+
|
327 |
+
if even_idx_y.shape[0] > 0 and odd_idx_x.shape[0] > 0:
|
328 |
+
canvas = partial_render(cur_canvas, even_y_odd_x_coord_y, even_y_odd_x_coord_x)
|
329 |
+
canvas = torch.cat([cur_canvas[:, :, :canvas.shape[2], :patch_size_x // 2], canvas], dim=3)
|
330 |
+
if not is_odd_y:
|
331 |
+
canvas = torch.cat([canvas, cur_canvas[:, :, -patch_size_y // 2:, -canvas.shape[3]:]], dim=2)
|
332 |
+
if is_odd_x:
|
333 |
+
canvas = torch.cat([canvas, cur_canvas[:, :, :canvas.shape[2], -patch_size_x // 2:]], dim=3)
|
334 |
+
cur_canvas = canvas
|
335 |
+
|
336 |
+
cur_canvas = cur_canvas[:, :, patch_size_y // 4:-patch_size_y // 4, patch_size_x // 4:-patch_size_x // 4]
|
337 |
+
|
338 |
+
return cur_canvas
|
339 |
+
|
340 |
+
|
341 |
+
def read_img(img_path, img_type='RGB', h=None, w=None):
|
342 |
+
img = Image.open(img_path).convert(img_type)
|
343 |
+
if h is not None and w is not None:
|
344 |
+
img = img.resize((w, h), resample=Image.NEAREST)
|
345 |
+
img = np.array(img)
|
346 |
+
if img.ndim == 2:
|
347 |
+
img = np.expand_dims(img, axis=-1)
|
348 |
+
img = img.transpose((2, 0, 1))
|
349 |
+
img = torch.from_numpy(img).unsqueeze(0).float() / 255.
|
350 |
+
return img
|
351 |
+
|
352 |
+
|
353 |
+
def pad(img, H, W):
|
354 |
+
b, c, h, w = img.shape
|
355 |
+
pad_h = (H - h) // 2
|
356 |
+
pad_w = (W - w) // 2
|
357 |
+
remainder_h = (H - h) % 2
|
358 |
+
remainder_w = (W - w) % 2
|
359 |
+
img = torch.cat([torch.zeros((b, c, pad_h, w), device=img.device), img,
|
360 |
+
torch.zeros((b, c, pad_h + remainder_h, w), device=img.device)], dim=-2)
|
361 |
+
img = torch.cat([torch.zeros((b, c, H, pad_w), device=img.device), img,
|
362 |
+
torch.zeros((b, c, H, pad_w + remainder_w), device=img.device)], dim=-1)
|
363 |
+
return img
|
364 |
+
|
365 |
+
|
366 |
+
def crop(img, h, w):
|
367 |
+
H, W = img.shape[-2:]
|
368 |
+
pad_h = (H - h) // 2
|
369 |
+
pad_w = (W - w) // 2
|
370 |
+
remainder_h = (H - h) % 2
|
371 |
+
remainder_w = (W - w) % 2
|
372 |
+
img = img[:, :, pad_h:H - pad_h - remainder_h, pad_w:W - pad_w - remainder_w]
|
373 |
+
return img
|
374 |
+
|
375 |
+
|
376 |
+
def main(input_path, model_path, output_dir, need_animation=False, resize_h=None, resize_w=None, serial=False):
|
377 |
+
if not os.path.exists(output_dir):
|
378 |
+
os.mkdir(output_dir)
|
379 |
+
input_name = os.path.basename(input_path)
|
380 |
+
output_path = os.path.join(output_dir, input_name)
|
381 |
+
frame_dir = None
|
382 |
+
if need_animation:
|
383 |
+
if not serial:
|
384 |
+
print('It must be under serial mode if animation results are required, so serial flag is set to True!')
|
385 |
+
serial = True
|
386 |
+
frame_dir = os.path.join(output_dir, input_name[:input_name.find('.')])
|
387 |
+
if not os.path.exists(frame_dir):
|
388 |
+
os.mkdir(frame_dir)
|
389 |
+
patch_size = 32
|
390 |
+
stroke_num = 8
|
391 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
392 |
+
net_g = network.Painter(5, stroke_num, 256, 8, 3, 3).to(device)
|
393 |
+
net_g.load_state_dict(torch.load(model_path))
|
394 |
+
net_g.eval()
|
395 |
+
for param in net_g.parameters():
|
396 |
+
param.requires_grad = False
|
397 |
+
|
398 |
+
brush_large_vertical = read_img('brush/brush_large_vertical.png', 'L').to(device)
|
399 |
+
brush_large_horizontal = read_img('brush/brush_large_horizontal.png', 'L').to(device)
|
400 |
+
meta_brushes = torch.cat(
|
401 |
+
[brush_large_vertical, brush_large_horizontal], dim=0)
|
402 |
+
|
403 |
+
with torch.no_grad():
|
404 |
+
original_img = read_img(input_path, 'RGB', resize_h, resize_w).to(device)
|
405 |
+
original_h, original_w = original_img.shape[-2:]
|
406 |
+
K = max(math.ceil(math.log2(max(original_h, original_w) / patch_size)), 0)
|
407 |
+
original_img_pad_size = patch_size * (2 ** K)
|
408 |
+
original_img_pad = pad(original_img, original_img_pad_size, original_img_pad_size)
|
409 |
+
final_result = torch.zeros_like(original_img_pad).to(device)
|
410 |
+
for layer in range(0, K + 1):
|
411 |
+
layer_size = patch_size * (2 ** layer)
|
412 |
+
img = F.interpolate(original_img_pad, (layer_size, layer_size))
|
413 |
+
result = F.interpolate(final_result, (patch_size * (2 ** layer), patch_size * (2 ** layer)))
|
414 |
+
img_patch = F.unfold(img, (patch_size, patch_size), stride=(patch_size, patch_size))
|
415 |
+
result_patch = F.unfold(result, (patch_size, patch_size),
|
416 |
+
stride=(patch_size, patch_size))
|
417 |
+
# There are patch_num * patch_num patches in total
|
418 |
+
patch_num = (layer_size - patch_size) // patch_size + 1
|
419 |
+
|
420 |
+
# img_patch, result_patch: b, 3 * output_size * output_size, h * w
|
421 |
+
img_patch = img_patch.permute(0, 2, 1).contiguous().view(-1, 3, patch_size, patch_size).contiguous()
|
422 |
+
result_patch = result_patch.permute(0, 2, 1).contiguous().view(
|
423 |
+
-1, 3, patch_size, patch_size).contiguous()
|
424 |
+
shape_param, stroke_decision = net_g(img_patch, result_patch)
|
425 |
+
stroke_decision = network.SignWithSigmoidGrad.apply(stroke_decision)
|
426 |
+
|
427 |
+
grid = shape_param[:, :, :2].view(img_patch.shape[0] * stroke_num, 1, 1, 2).contiguous()
|
428 |
+
img_temp = img_patch.unsqueeze(1).contiguous().repeat(1, stroke_num, 1, 1, 1).view(
|
429 |
+
img_patch.shape[0] * stroke_num, 3, patch_size, patch_size).contiguous()
|
430 |
+
color = F.grid_sample(img_temp, 2 * grid - 1, align_corners=False).view(
|
431 |
+
img_patch.shape[0], stroke_num, 3).contiguous()
|
432 |
+
stroke_param = torch.cat([shape_param, color], dim=-1)
|
433 |
+
# stroke_param: b * h * w, stroke_per_patch, param_per_stroke
|
434 |
+
# stroke_decision: b * h * w, stroke_per_patch, 1
|
435 |
+
param = stroke_param.view(1, patch_num, patch_num, stroke_num, 8).contiguous()
|
436 |
+
decision = stroke_decision.view(1, patch_num, patch_num, stroke_num).contiguous().bool()
|
437 |
+
# param: b, h, w, stroke_per_patch, 8
|
438 |
+
# decision: b, h, w, stroke_per_patch
|
439 |
+
param[..., :2] = param[..., :2] / 2 + 0.25
|
440 |
+
param[..., 2:4] = param[..., 2:4] / 2
|
441 |
+
if serial:
|
442 |
+
final_result = param2img_serial(param, decision, meta_brushes, final_result,
|
443 |
+
frame_dir, False, original_h, original_w)
|
444 |
+
else:
|
445 |
+
final_result = param2img_parallel(param, decision, meta_brushes, final_result)
|
446 |
+
|
447 |
+
border_size = original_img_pad_size // (2 * patch_num)
|
448 |
+
img = F.interpolate(original_img_pad, (patch_size * (2 ** layer), patch_size * (2 ** layer)))
|
449 |
+
result = F.interpolate(final_result, (patch_size * (2 ** layer), patch_size * (2 ** layer)))
|
450 |
+
img = F.pad(img, [patch_size // 2, patch_size // 2, patch_size // 2, patch_size // 2,
|
451 |
+
0, 0, 0, 0])
|
452 |
+
result = F.pad(result, [patch_size // 2, patch_size // 2, patch_size // 2, patch_size // 2,
|
453 |
+
0, 0, 0, 0])
|
454 |
+
img_patch = F.unfold(img, (patch_size, patch_size), stride=(patch_size, patch_size))
|
455 |
+
result_patch = F.unfold(result, (patch_size, patch_size), stride=(patch_size, patch_size))
|
456 |
+
final_result = F.pad(final_result, [border_size, border_size, border_size, border_size, 0, 0, 0, 0])
|
457 |
+
h = (img.shape[2] - patch_size) // patch_size + 1
|
458 |
+
w = (img.shape[3] - patch_size) // patch_size + 1
|
459 |
+
# img_patch, result_patch: b, 3 * output_size * output_size, h * w
|
460 |
+
img_patch = img_patch.permute(0, 2, 1).contiguous().view(-1, 3, patch_size, patch_size).contiguous()
|
461 |
+
result_patch = result_patch.permute(0, 2, 1).contiguous().view(-1, 3, patch_size, patch_size).contiguous()
|
462 |
+
shape_param, stroke_decision = net_g(img_patch, result_patch)
|
463 |
+
|
464 |
+
grid = shape_param[:, :, :2].view(img_patch.shape[0] * stroke_num, 1, 1, 2).contiguous()
|
465 |
+
img_temp = img_patch.unsqueeze(1).contiguous().repeat(1, stroke_num, 1, 1, 1).view(
|
466 |
+
img_patch.shape[0] * stroke_num, 3, patch_size, patch_size).contiguous()
|
467 |
+
color = F.grid_sample(img_temp, 2 * grid - 1, align_corners=False).view(
|
468 |
+
img_patch.shape[0], stroke_num, 3).contiguous()
|
469 |
+
stroke_param = torch.cat([shape_param, color], dim=-1)
|
470 |
+
# stroke_param: b * h * w, stroke_per_patch, param_per_stroke
|
471 |
+
# stroke_decision: b * h * w, stroke_per_patch, 1
|
472 |
+
param = stroke_param.view(1, h, w, stroke_num, 8).contiguous()
|
473 |
+
decision = stroke_decision.view(1, h, w, stroke_num).contiguous().bool()
|
474 |
+
# param: b, h, w, stroke_per_patch, 8
|
475 |
+
# decision: b, h, w, stroke_per_patch
|
476 |
+
param[..., :2] = param[..., :2] / 2 + 0.25
|
477 |
+
param[..., 2:4] = param[..., 2:4] / 2
|
478 |
+
if serial:
|
479 |
+
final_result = param2img_serial(param, decision, meta_brushes, final_result,
|
480 |
+
frame_dir, True, original_h, original_w)
|
481 |
+
else:
|
482 |
+
final_result = param2img_parallel(param, decision, meta_brushes, final_result)
|
483 |
+
final_result = final_result[:, :, border_size:-border_size, border_size:-border_size]
|
484 |
+
|
485 |
+
final_result = crop(final_result, original_h, original_w)
|
486 |
+
save_img(final_result[0], output_path)
|
487 |
+
|
488 |
+
|
489 |
+
if __name__ == '__main__':
|
490 |
+
main(input_path='input/chicago.jpg',
|
491 |
+
model_path='model.pth',
|
492 |
+
output_dir='output/',
|
493 |
+
need_animation=False, # whether need intermediate results for animation.
|
494 |
+
resize_h=None, # resize original input to this size. None means do not resize.
|
495 |
+
resize_w=None, # resize original input to this size. None means do not resize.
|
496 |
+
serial=False) # if need animation, serial must be True.
|
inference/input/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
inference/input/temp.txt
ADDED
File without changes
|
inference/morphology.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class Erosion2d(nn.Module):
|
7 |
+
|
8 |
+
def __init__(self, m=1):
|
9 |
+
super(Erosion2d, self).__init__()
|
10 |
+
self.m = m
|
11 |
+
self.pad = [m, m, m, m]
|
12 |
+
self.unfold = nn.Unfold(2 * m + 1, padding=0, stride=1)
|
13 |
+
|
14 |
+
def forward(self, x):
|
15 |
+
batch_size, c, h, w = x.shape
|
16 |
+
x_pad = F.pad(x, pad=self.pad, mode='constant', value=1e9)
|
17 |
+
channel = self.unfold(x_pad).view(batch_size, c, -1, h, w)
|
18 |
+
result = torch.min(channel, dim=2)[0]
|
19 |
+
return result
|
20 |
+
|
21 |
+
|
22 |
+
def erosion(x, m=1):
|
23 |
+
b, c, h, w = x.shape
|
24 |
+
x_pad = F.pad(x, pad=[m, m, m, m], mode='constant', value=1e9)
|
25 |
+
channel = nn.functional.unfold(x_pad, 2 * m + 1, padding=0, stride=1).view(b, c, -1, h, w)
|
26 |
+
result = torch.min(channel, dim=2)[0]
|
27 |
+
return result
|
28 |
+
|
29 |
+
|
30 |
+
class Dilation2d(nn.Module):
|
31 |
+
|
32 |
+
def __init__(self, m=1):
|
33 |
+
super(Dilation2d, self).__init__()
|
34 |
+
self.m = m
|
35 |
+
self.pad = [m, m, m, m]
|
36 |
+
self.unfold = nn.Unfold(2 * m + 1, padding=0, stride=1)
|
37 |
+
|
38 |
+
def forward(self, x):
|
39 |
+
batch_size, c, h, w = x.shape
|
40 |
+
x_pad = F.pad(x, pad=self.pad, mode='constant', value=-1e9)
|
41 |
+
channel = self.unfold(x_pad).view(batch_size, c, -1, h, w)
|
42 |
+
result = torch.max(channel, dim=2)[0]
|
43 |
+
return result
|
44 |
+
|
45 |
+
|
46 |
+
def dilation(x, m=1):
|
47 |
+
b, c, h, w = x.shape
|
48 |
+
x_pad = F.pad(x, pad=[m, m, m, m], mode='constant', value=-1e9)
|
49 |
+
channel = nn.functional.unfold(x_pad, 2 * m + 1, padding=0, stride=1).view(b, c, -1, h, w)
|
50 |
+
result = torch.max(channel, dim=2)[0]
|
51 |
+
return result
|
inference/network.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
class SignWithSigmoidGrad(torch.autograd.Function):
|
6 |
+
|
7 |
+
@staticmethod
|
8 |
+
def forward(ctx, x):
|
9 |
+
result = (x > 0).float()
|
10 |
+
sigmoid_result = torch.sigmoid(x)
|
11 |
+
ctx.save_for_backward(sigmoid_result)
|
12 |
+
return result
|
13 |
+
|
14 |
+
@staticmethod
|
15 |
+
def backward(ctx, grad_result):
|
16 |
+
(sigmoid_result,) = ctx.saved_tensors
|
17 |
+
if ctx.needs_input_grad[0]:
|
18 |
+
grad_input = grad_result * sigmoid_result * (1 - sigmoid_result)
|
19 |
+
else:
|
20 |
+
grad_input = None
|
21 |
+
return grad_input
|
22 |
+
|
23 |
+
|
24 |
+
class Painter(nn.Module):
|
25 |
+
|
26 |
+
def __init__(self, param_per_stroke, total_strokes, hidden_dim, n_heads=8, n_enc_layers=3, n_dec_layers=3):
|
27 |
+
super().__init__()
|
28 |
+
self.enc_img = nn.Sequential(
|
29 |
+
nn.ReflectionPad2d(1),
|
30 |
+
nn.Conv2d(3, 32, 3, 1),
|
31 |
+
nn.BatchNorm2d(32),
|
32 |
+
nn.ReLU(True),
|
33 |
+
nn.ReflectionPad2d(1),
|
34 |
+
nn.Conv2d(32, 64, 3, 2),
|
35 |
+
nn.BatchNorm2d(64),
|
36 |
+
nn.ReLU(True),
|
37 |
+
nn.ReflectionPad2d(1),
|
38 |
+
nn.Conv2d(64, 128, 3, 2),
|
39 |
+
nn.BatchNorm2d(128),
|
40 |
+
nn.ReLU(True))
|
41 |
+
self.enc_canvas = nn.Sequential(
|
42 |
+
nn.ReflectionPad2d(1),
|
43 |
+
nn.Conv2d(3, 32, 3, 1),
|
44 |
+
nn.BatchNorm2d(32),
|
45 |
+
nn.ReLU(True),
|
46 |
+
nn.ReflectionPad2d(1),
|
47 |
+
nn.Conv2d(32, 64, 3, 2),
|
48 |
+
nn.BatchNorm2d(64),
|
49 |
+
nn.ReLU(True),
|
50 |
+
nn.ReflectionPad2d(1),
|
51 |
+
nn.Conv2d(64, 128, 3, 2),
|
52 |
+
nn.BatchNorm2d(128),
|
53 |
+
nn.ReLU(True))
|
54 |
+
self.conv = nn.Conv2d(128 * 2, hidden_dim, 1)
|
55 |
+
self.transformer = nn.Transformer(hidden_dim, n_heads, n_enc_layers, n_dec_layers)
|
56 |
+
self.linear_param = nn.Sequential(
|
57 |
+
nn.Linear(hidden_dim, hidden_dim),
|
58 |
+
nn.ReLU(True),
|
59 |
+
nn.Linear(hidden_dim, hidden_dim),
|
60 |
+
nn.ReLU(True),
|
61 |
+
nn.Linear(hidden_dim, param_per_stroke))
|
62 |
+
self.linear_decider = nn.Linear(hidden_dim, 1)
|
63 |
+
self.query_pos = nn.Parameter(torch.rand(total_strokes, hidden_dim))
|
64 |
+
self.row_embed = nn.Parameter(torch.rand(8, hidden_dim // 2))
|
65 |
+
self.col_embed = nn.Parameter(torch.rand(8, hidden_dim // 2))
|
66 |
+
|
67 |
+
def forward(self, img, canvas):
|
68 |
+
b, _, H, W = img.shape
|
69 |
+
img_feat = self.enc_img(img)
|
70 |
+
canvas_feat = self.enc_canvas(canvas)
|
71 |
+
h, w = img_feat.shape[-2:]
|
72 |
+
feat = torch.cat([img_feat, canvas_feat], dim=1)
|
73 |
+
feat_conv = self.conv(feat)
|
74 |
+
|
75 |
+
pos_embed = torch.cat([
|
76 |
+
self.col_embed[:w].unsqueeze(0).contiguous().repeat(h, 1, 1),
|
77 |
+
self.row_embed[:h].unsqueeze(1).contiguous().repeat(1, w, 1),
|
78 |
+
], dim=-1).flatten(0, 1).unsqueeze(1)
|
79 |
+
hidden_state = self.transformer(pos_embed + feat_conv.flatten(2).permute(2, 0, 1).contiguous(),
|
80 |
+
self.query_pos.unsqueeze(1).contiguous().repeat(1, b, 1))
|
81 |
+
hidden_state = hidden_state.permute(1, 0, 2).contiguous()
|
82 |
+
param = self.linear_param(hidden_state)
|
83 |
+
decision = self.linear_decider(hidden_state)
|
84 |
+
return param, decision
|
train/brush/brush_large_horizontal.png
ADDED
train/brush/brush_large_vertical.png
ADDED
train/brush/brush_small_horizontal.png
ADDED
train/brush/brush_small_vertical.png
ADDED
train/data/__init__.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""This package includes all the modules related to data loading and preprocessing
|
2 |
+
|
3 |
+
To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
|
4 |
+
You need to implement four functions:
|
5 |
+
-- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
|
6 |
+
-- <__len__>: return the size of dataset.
|
7 |
+
-- <__getitem__>: get a data point from data loader.
|
8 |
+
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
|
9 |
+
|
10 |
+
Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
|
11 |
+
See our template dataset class 'template_dataset.py' for more details.
|
12 |
+
"""
|
13 |
+
import importlib
|
14 |
+
import torch.utils.data
|
15 |
+
from data.base_dataset import BaseDataset
|
16 |
+
|
17 |
+
|
18 |
+
def find_dataset_using_name(dataset_name):
|
19 |
+
"""Import the module "data/[dataset_name]_dataset.py".
|
20 |
+
|
21 |
+
In the file, the class called DatasetNameDataset() will
|
22 |
+
be instantiated. It has to be a subclass of BaseDataset,
|
23 |
+
and it is case-insensitive.
|
24 |
+
"""
|
25 |
+
dataset_filename = "data." + dataset_name + "_dataset"
|
26 |
+
datasetlib = importlib.import_module(dataset_filename)
|
27 |
+
|
28 |
+
dataset = None
|
29 |
+
target_dataset_name = dataset_name.replace('_', '') + 'dataset'
|
30 |
+
for name, cls in datasetlib.__dict__.items():
|
31 |
+
if name.lower() == target_dataset_name.lower() \
|
32 |
+
and issubclass(cls, BaseDataset):
|
33 |
+
dataset = cls
|
34 |
+
|
35 |
+
if dataset is None:
|
36 |
+
raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
|
37 |
+
|
38 |
+
return dataset
|
39 |
+
|
40 |
+
|
41 |
+
def get_option_setter(dataset_name):
|
42 |
+
"""Return the static method <modify_commandline_options> of the dataset class."""
|
43 |
+
dataset_class = find_dataset_using_name(dataset_name)
|
44 |
+
return dataset_class.modify_commandline_options
|
45 |
+
|
46 |
+
|
47 |
+
def create_dataset(opt):
|
48 |
+
"""Create a dataset given the option.
|
49 |
+
|
50 |
+
This function wraps the class CustomDatasetDataLoader.
|
51 |
+
This is the main interface between this package and 'train.py'/'test.py'
|
52 |
+
|
53 |
+
Example:
|
54 |
+
>>> from data import create_dataset
|
55 |
+
>>> dataset = create_dataset(opt)
|
56 |
+
"""
|
57 |
+
data_loader = CustomDatasetDataLoader(opt)
|
58 |
+
dataset = data_loader.load_data()
|
59 |
+
return dataset
|
60 |
+
|
61 |
+
|
62 |
+
class CustomDatasetDataLoader():
|
63 |
+
"""Wrapper class of Dataset class that performs multi-threaded data loading"""
|
64 |
+
|
65 |
+
def __init__(self, opt):
|
66 |
+
"""Initialize this class
|
67 |
+
|
68 |
+
Step 1: create a dataset instance given the name [dataset_mode]
|
69 |
+
Step 2: create a multi-threaded data loader.
|
70 |
+
"""
|
71 |
+
self.opt = opt
|
72 |
+
dataset_class = find_dataset_using_name(opt.dataset_mode)
|
73 |
+
self.dataset = dataset_class(opt)
|
74 |
+
print("dataset [%s] was created" % type(self.dataset).__name__)
|
75 |
+
self.dataloader = torch.utils.data.DataLoader(
|
76 |
+
self.dataset,
|
77 |
+
batch_size=opt.batch_size,
|
78 |
+
shuffle=not opt.serial_batches,
|
79 |
+
num_workers=int(opt.num_threads),
|
80 |
+
drop_last=True)
|
81 |
+
|
82 |
+
def load_data(self):
|
83 |
+
return self
|
84 |
+
|
85 |
+
def __len__(self):
|
86 |
+
"""Return the number of data in the dataset"""
|
87 |
+
return min(len(self.dataset), self.opt.max_dataset_size)
|
88 |
+
|
89 |
+
def __iter__(self):
|
90 |
+
"""Return a batch of data"""
|
91 |
+
for i, data in enumerate(self.dataloader):
|
92 |
+
if i * self.opt.batch_size >= self.opt.max_dataset_size:
|
93 |
+
break
|
94 |
+
yield data
|
train/data/base_dataset.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
|
2 |
+
|
3 |
+
It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
|
4 |
+
"""
|
5 |
+
import random
|
6 |
+
import numpy as np
|
7 |
+
import torch.utils.data as data
|
8 |
+
from PIL import Image
|
9 |
+
import torchvision.transforms as transforms
|
10 |
+
from abc import ABC, abstractmethod
|
11 |
+
|
12 |
+
|
13 |
+
class BaseDataset(data.Dataset, ABC):
|
14 |
+
"""This class is an abstract base class (ABC) for datasets.
|
15 |
+
|
16 |
+
To create a subclass, you need to implement the following four functions:
|
17 |
+
-- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
|
18 |
+
-- <__len__>: return the size of dataset.
|
19 |
+
-- <__getitem__>: get a data point.
|
20 |
+
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(self, opt):
|
24 |
+
"""Initialize the class; save the options in the class
|
25 |
+
|
26 |
+
Parameters:
|
27 |
+
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
|
28 |
+
"""
|
29 |
+
self.opt = opt
|
30 |
+
self.root = opt.dataroot
|
31 |
+
|
32 |
+
@staticmethod
|
33 |
+
def modify_commandline_options(parser, is_train):
|
34 |
+
"""Add new dataset-specific options, and rewrite default values for existing options.
|
35 |
+
|
36 |
+
Parameters:
|
37 |
+
parser -- original option parser
|
38 |
+
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
|
39 |
+
|
40 |
+
Returns:
|
41 |
+
the modified parser.
|
42 |
+
"""
|
43 |
+
return parser
|
44 |
+
|
45 |
+
@abstractmethod
|
46 |
+
def __len__(self):
|
47 |
+
"""Return the total number of images in the dataset."""
|
48 |
+
return 0
|
49 |
+
|
50 |
+
@abstractmethod
|
51 |
+
def __getitem__(self, index):
|
52 |
+
"""Return a data point and its metadata information.
|
53 |
+
|
54 |
+
Parameters:
|
55 |
+
index - - a random integer for data indexing
|
56 |
+
|
57 |
+
Returns:
|
58 |
+
a dictionary of data with their names. It ususally contains the data itself and its metadata information.
|
59 |
+
"""
|
60 |
+
pass
|
61 |
+
|
62 |
+
|
63 |
+
def get_params(opt, size):
|
64 |
+
w, h = size
|
65 |
+
new_h = h
|
66 |
+
new_w = w
|
67 |
+
if opt.preprocess == 'resize_and_crop':
|
68 |
+
new_h = new_w = opt.load_size
|
69 |
+
elif opt.preprocess == 'scale_width_and_crop':
|
70 |
+
new_w = opt.load_size
|
71 |
+
new_h = opt.load_size * h // w
|
72 |
+
|
73 |
+
x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
|
74 |
+
y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
|
75 |
+
|
76 |
+
flip = random.random() > 0.5
|
77 |
+
|
78 |
+
return {'crop_pos': (x, y), 'flip': flip}
|
79 |
+
|
80 |
+
|
81 |
+
def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
|
82 |
+
transform_list = []
|
83 |
+
if grayscale:
|
84 |
+
transform_list.append(transforms.Grayscale(1))
|
85 |
+
if 'resize' in opt.preprocess:
|
86 |
+
osize = [opt.load_size, opt.load_size]
|
87 |
+
transform_list.append(transforms.Resize(osize, method))
|
88 |
+
elif 'scale_width' in opt.preprocess:
|
89 |
+
transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method)))
|
90 |
+
|
91 |
+
if 'crop' in opt.preprocess:
|
92 |
+
if params is None:
|
93 |
+
transform_list.append(transforms.RandomCrop(opt.crop_size))
|
94 |
+
else:
|
95 |
+
transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
|
96 |
+
|
97 |
+
if opt.preprocess == 'none':
|
98 |
+
transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))
|
99 |
+
|
100 |
+
if not opt.no_flip:
|
101 |
+
if params is None:
|
102 |
+
transform_list.append(transforms.RandomHorizontalFlip())
|
103 |
+
elif params['flip']:
|
104 |
+
transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
|
105 |
+
|
106 |
+
if convert:
|
107 |
+
transform_list += [transforms.ToTensor()]
|
108 |
+
return transforms.Compose(transform_list)
|
109 |
+
|
110 |
+
|
111 |
+
def __make_power_2(img, base, method=Image.BICUBIC):
|
112 |
+
ow, oh = img.size
|
113 |
+
h = int(round(oh / base) * base)
|
114 |
+
w = int(round(ow / base) * base)
|
115 |
+
if h == oh and w == ow:
|
116 |
+
return img
|
117 |
+
|
118 |
+
__print_size_warning(ow, oh, w, h)
|
119 |
+
return img.resize((w, h), method)
|
120 |
+
|
121 |
+
|
122 |
+
def __scale_width(img, target_size, crop_size, method=Image.BICUBIC):
|
123 |
+
ow, oh = img.size
|
124 |
+
if ow == target_size and oh >= crop_size:
|
125 |
+
return img
|
126 |
+
w = target_size
|
127 |
+
h = int(max(target_size * oh / ow, crop_size))
|
128 |
+
return img.resize((w, h), method)
|
129 |
+
|
130 |
+
|
131 |
+
def __crop(img, pos, size):
|
132 |
+
ow, oh = img.size
|
133 |
+
x1, y1 = pos
|
134 |
+
tw = th = size
|
135 |
+
if (ow > tw or oh > th):
|
136 |
+
return img.crop((x1, y1, x1 + tw, y1 + th))
|
137 |
+
return img
|
138 |
+
|
139 |
+
|
140 |
+
def __flip(img, flip):
|
141 |
+
if flip:
|
142 |
+
return img.transpose(Image.FLIP_LEFT_RIGHT)
|
143 |
+
return img
|
144 |
+
|
145 |
+
|
146 |
+
def __print_size_warning(ow, oh, w, h):
|
147 |
+
"""Print warning information about image size(only print once)"""
|
148 |
+
if not hasattr(__print_size_warning, 'has_printed'):
|
149 |
+
print("The image size needs to be a multiple of 4. "
|
150 |
+
"The loaded image size was (%d, %d), so it was adjusted to "
|
151 |
+
"(%d, %d). This adjustment will be done to all images "
|
152 |
+
"whose sizes are not multiples of 4" % (ow, oh, w, h))
|
153 |
+
__print_size_warning.has_printed = True
|
train/data/null_dataset.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from data.base_dataset import BaseDataset
|
2 |
+
import os
|
3 |
+
|
4 |
+
|
5 |
+
class NullDataset(BaseDataset):
|
6 |
+
|
7 |
+
def __init__(self, opt):
|
8 |
+
BaseDataset.__init__(self, opt)
|
9 |
+
|
10 |
+
def __getitem__(self, index):
|
11 |
+
return {'A_paths': os.path.join(self.opt.dataroot, '%d.jpg' % index)}
|
12 |
+
|
13 |
+
def __len__(self):
|
14 |
+
"""Return the total number of images in the dataset."""
|
15 |
+
return self.opt.max_dataset_size
|
train/models/__init__.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""This package contains modules related to objective functions, optimizations, and network architectures.
|
2 |
+
|
3 |
+
To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
|
4 |
+
You need to implement the following five functions:
|
5 |
+
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
|
6 |
+
-- <set_input>: unpack data from dataset and apply preprocessing.
|
7 |
+
-- <forward>: produce intermediate results.
|
8 |
+
-- <optimize_parameters>: calculate loss, gradients, and update network weights.
|
9 |
+
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
|
10 |
+
|
11 |
+
In the function <__init__>, you need to define four lists:
|
12 |
+
-- self.loss_names (str list): specify the training losses that you want to plot and save.
|
13 |
+
-- self.model_names (str list): define networks used in our training.
|
14 |
+
-- self.visual_names (str list): specify the images that you want to display and save.
|
15 |
+
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage.
|
16 |
+
|
17 |
+
Now you can use the model class by specifying flag '--model dummy'.
|
18 |
+
See our template model class 'template_model.py' for more details.
|
19 |
+
"""
|
20 |
+
|
21 |
+
import importlib
|
22 |
+
from models.base_model import BaseModel
|
23 |
+
|
24 |
+
|
25 |
+
def find_model_using_name(model_name):
|
26 |
+
"""Import the module "models/[model_name]_model.py".
|
27 |
+
|
28 |
+
In the file, the class called DatasetNameModel() will
|
29 |
+
be instantiated. It has to be a subclass of BaseModel,
|
30 |
+
and it is case-insensitive.
|
31 |
+
"""
|
32 |
+
model_filename = "models." + model_name + "_model"
|
33 |
+
modellib = importlib.import_module(model_filename)
|
34 |
+
model = None
|
35 |
+
target_model_name = model_name.replace('_', '') + 'model'
|
36 |
+
for name, cls in modellib.__dict__.items():
|
37 |
+
if name.lower() == target_model_name.lower() \
|
38 |
+
and issubclass(cls, BaseModel):
|
39 |
+
model = cls
|
40 |
+
|
41 |
+
if model is None:
|
42 |
+
print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
|
43 |
+
exit(0)
|
44 |
+
|
45 |
+
return model
|
46 |
+
|
47 |
+
|
48 |
+
def get_option_setter(model_name):
|
49 |
+
"""Return the static method <modify_commandline_options> of the model class."""
|
50 |
+
model_class = find_model_using_name(model_name)
|
51 |
+
return model_class.modify_commandline_options
|
52 |
+
|
53 |
+
|
54 |
+
def create_model(opt):
|
55 |
+
"""Create a model given the option.
|
56 |
+
|
57 |
+
This function warps the class CustomDatasetDataLoader.
|
58 |
+
This is the main interface between this package and 'train.py'/'test.py'
|
59 |
+
|
60 |
+
Example:
|
61 |
+
>>> from models import create_model
|
62 |
+
>>> model = create_model(opt)
|
63 |
+
"""
|
64 |
+
model = find_model_using_name(opt.model)
|
65 |
+
instance = model(opt)
|
66 |
+
print("model [%s] was created" % type(instance).__name__)
|
67 |
+
return instance
|
train/models/base_model.py
ADDED
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
from collections import OrderedDict
|
4 |
+
from abc import ABC, abstractmethod
|
5 |
+
from . import networks
|
6 |
+
|
7 |
+
|
8 |
+
class BaseModel(ABC):
|
9 |
+
"""This class is an abstract base class (ABC) for models.
|
10 |
+
To create a subclass, you need to implement the following five functions:
|
11 |
+
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
|
12 |
+
-- <set_input>: unpack data from dataset and apply preprocessing.
|
13 |
+
-- <forward>: produce intermediate results.
|
14 |
+
-- <optimize_parameters>: calculate losses, gradients, and update network weights.
|
15 |
+
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, opt):
|
19 |
+
"""Initialize the BaseModel class.
|
20 |
+
|
21 |
+
Parameters:
|
22 |
+
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
|
23 |
+
|
24 |
+
When creating your custom class, you need to implement your own initialization.
|
25 |
+
In this function, you should first call <BaseModel.__init__(self, opt)>
|
26 |
+
Then, you need to define four lists:
|
27 |
+
-- self.loss_names (str list): specify the training losses that you want to plot and save.
|
28 |
+
-- self.model_names (str list): define networks used in our training.
|
29 |
+
-- self.visual_names (str list): specify the images that you want to display and save.
|
30 |
+
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
|
31 |
+
"""
|
32 |
+
self.opt = opt
|
33 |
+
self.gpu_ids = opt.gpu_ids
|
34 |
+
self.isTrain = opt.isTrain
|
35 |
+
self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
|
36 |
+
self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
|
37 |
+
if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
|
38 |
+
torch.backends.cudnn.benchmark = True
|
39 |
+
self.loss_names = []
|
40 |
+
self.model_names = []
|
41 |
+
self.visual_names = []
|
42 |
+
self.optimizers = []
|
43 |
+
self.image_paths = []
|
44 |
+
self.metric = 0 # used for learning rate policy 'plateau'
|
45 |
+
|
46 |
+
@staticmethod
|
47 |
+
def modify_commandline_options(parser, is_train):
|
48 |
+
"""Add new model-specific options, and rewrite default values for existing options.
|
49 |
+
|
50 |
+
Parameters:
|
51 |
+
parser -- original option parser
|
52 |
+
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
|
53 |
+
|
54 |
+
Returns:
|
55 |
+
the modified parser.
|
56 |
+
"""
|
57 |
+
return parser
|
58 |
+
|
59 |
+
@abstractmethod
|
60 |
+
def set_input(self, input):
|
61 |
+
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
|
62 |
+
|
63 |
+
Parameters:
|
64 |
+
input (dict): includes the data itself and its metadata information.
|
65 |
+
"""
|
66 |
+
pass
|
67 |
+
|
68 |
+
@abstractmethod
|
69 |
+
def forward(self):
|
70 |
+
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
|
71 |
+
pass
|
72 |
+
|
73 |
+
@abstractmethod
|
74 |
+
def optimize_parameters(self):
|
75 |
+
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
|
76 |
+
pass
|
77 |
+
|
78 |
+
def setup(self, opt):
|
79 |
+
"""Load and print networks; create schedulers
|
80 |
+
|
81 |
+
Parameters:
|
82 |
+
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
|
83 |
+
"""
|
84 |
+
if self.isTrain:
|
85 |
+
self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
|
86 |
+
if not self.isTrain or opt.continue_train:
|
87 |
+
load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
|
88 |
+
self.load_networks(load_suffix)
|
89 |
+
self.print_networks(opt.verbose)
|
90 |
+
|
91 |
+
def eval(self):
|
92 |
+
"""Make models eval mode during test time"""
|
93 |
+
for name in self.model_names:
|
94 |
+
if isinstance(name, str):
|
95 |
+
net = getattr(self, 'net_' + name)
|
96 |
+
net.eval()
|
97 |
+
|
98 |
+
def test(self):
|
99 |
+
"""Forward function used in test time.
|
100 |
+
|
101 |
+
This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
|
102 |
+
It also calls <compute_visuals> to produce additional visualization results
|
103 |
+
"""
|
104 |
+
with torch.no_grad():
|
105 |
+
self.forward()
|
106 |
+
self.compute_visuals()
|
107 |
+
|
108 |
+
def compute_visuals(self):
|
109 |
+
"""Calculate additional output images for visdom and HTML visualization"""
|
110 |
+
pass
|
111 |
+
|
112 |
+
def get_image_paths(self):
|
113 |
+
""" Return image paths that are used to load current data"""
|
114 |
+
return self.image_paths
|
115 |
+
|
116 |
+
def update_learning_rate(self):
|
117 |
+
"""Update learning rates for all the networks; called at the end of every epoch"""
|
118 |
+
old_lr = self.optimizers[0].param_groups[0]['lr']
|
119 |
+
for scheduler in self.schedulers:
|
120 |
+
if self.opt.lr_policy == 'plateau':
|
121 |
+
scheduler.step(self.metric)
|
122 |
+
else:
|
123 |
+
scheduler.step()
|
124 |
+
|
125 |
+
lr = self.optimizers[0].param_groups[0]['lr']
|
126 |
+
print('learning rate %.7f -> %.7f' % (old_lr, lr))
|
127 |
+
|
128 |
+
def get_current_visuals(self):
|
129 |
+
"""Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
|
130 |
+
visual_ret = OrderedDict()
|
131 |
+
for name in self.visual_names:
|
132 |
+
if isinstance(name, str):
|
133 |
+
visual_ret[name] = getattr(self, name)
|
134 |
+
return visual_ret
|
135 |
+
|
136 |
+
def get_current_losses(self):
|
137 |
+
"""Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
|
138 |
+
errors_ret = OrderedDict()
|
139 |
+
for name in self.loss_names:
|
140 |
+
if isinstance(name, str):
|
141 |
+
errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number
|
142 |
+
return errors_ret
|
143 |
+
|
144 |
+
def save_networks(self, epoch):
|
145 |
+
"""Save all the networks to the disk.
|
146 |
+
|
147 |
+
Parameters:
|
148 |
+
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
|
149 |
+
"""
|
150 |
+
for name in self.model_names:
|
151 |
+
if isinstance(name, str):
|
152 |
+
save_filename = '%s_net_%s.pth' % (epoch, name)
|
153 |
+
save_path = os.path.join(self.save_dir, save_filename)
|
154 |
+
net = getattr(self, 'net_' + name)
|
155 |
+
|
156 |
+
if len(self.gpu_ids) > 0 and torch.cuda.is_available():
|
157 |
+
torch.save(net.module.cpu().state_dict(), save_path)
|
158 |
+
net.cuda(self.gpu_ids[0])
|
159 |
+
else:
|
160 |
+
torch.save(net.cpu().state_dict(), save_path)
|
161 |
+
|
162 |
+
def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
|
163 |
+
"""Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
|
164 |
+
key = keys[i]
|
165 |
+
if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
|
166 |
+
if module.__class__.__name__.startswith('InstanceNorm') and \
|
167 |
+
(key == 'running_mean' or key == 'running_var'):
|
168 |
+
if getattr(module, key) is None:
|
169 |
+
state_dict.pop('.'.join(keys))
|
170 |
+
if module.__class__.__name__.startswith('InstanceNorm') and \
|
171 |
+
(key == 'num_batches_tracked'):
|
172 |
+
state_dict.pop('.'.join(keys))
|
173 |
+
else:
|
174 |
+
self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
|
175 |
+
|
176 |
+
def load_networks(self, epoch):
|
177 |
+
"""Load all the networks from the disk.
|
178 |
+
|
179 |
+
Parameters:
|
180 |
+
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
|
181 |
+
"""
|
182 |
+
for name in self.model_names:
|
183 |
+
if isinstance(name, str):
|
184 |
+
load_filename = '%s_net_%s.pth' % (epoch, name)
|
185 |
+
load_path = os.path.join(self.save_dir, load_filename)
|
186 |
+
net = getattr(self, 'net_' + name)
|
187 |
+
if isinstance(net, torch.nn.DataParallel):
|
188 |
+
net = net.module
|
189 |
+
print('loading the model from %s' % load_path)
|
190 |
+
# if you are using PyTorch newer than 0.4 (e.g., built from
|
191 |
+
# GitHub source), you can remove str() on self.device
|
192 |
+
state_dict = torch.load(load_path, map_location=str(self.device))
|
193 |
+
if hasattr(state_dict, '_metadata'):
|
194 |
+
del state_dict._metadata
|
195 |
+
|
196 |
+
# patch InstanceNorm checkpoints prior to 0.4
|
197 |
+
for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
|
198 |
+
self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
|
199 |
+
net.load_state_dict(state_dict)
|
200 |
+
|
201 |
+
def print_networks(self, verbose):
|
202 |
+
"""Print the total number of parameters in the network and (if verbose) network architecture
|
203 |
+
|
204 |
+
Parameters:
|
205 |
+
verbose (bool) -- if verbose: print the network architecture
|
206 |
+
"""
|
207 |
+
print('---------- Networks initialized -------------')
|
208 |
+
for name in self.model_names:
|
209 |
+
if isinstance(name, str):
|
210 |
+
net = getattr(self, 'net_' + name)
|
211 |
+
num_params = 0
|
212 |
+
for param in net.parameters():
|
213 |
+
num_params += param.numel()
|
214 |
+
if verbose:
|
215 |
+
print(net)
|
216 |
+
print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
|
217 |
+
print('-----------------------------------------------')
|
218 |
+
|
219 |
+
def set_requires_grad(self, nets, requires_grad=False):
|
220 |
+
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations
|
221 |
+
Parameters:
|
222 |
+
nets (network list) -- a list of networks
|
223 |
+
requires_grad (bool) -- whether the networks require gradients or not
|
224 |
+
"""
|
225 |
+
if not isinstance(nets, list):
|
226 |
+
nets = [nets]
|
227 |
+
for net in nets:
|
228 |
+
if net is not None:
|
229 |
+
for param in net.parameters():
|
230 |
+
param.requires_grad = requires_grad
|
train/models/networks.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch.nn import init
|
4 |
+
from torch.optim import lr_scheduler
|
5 |
+
|
6 |
+
|
7 |
+
def get_scheduler(optimizer, opt):
|
8 |
+
if opt.lr_policy == 'linear':
|
9 |
+
def lambda_rule(epoch):
|
10 |
+
# lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
|
11 |
+
lr_l = 0.3 ** max(0, (epoch + opt.epoch_count - opt.n_epochs) // 5)
|
12 |
+
return lr_l
|
13 |
+
|
14 |
+
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
|
15 |
+
elif opt.lr_policy == 'step':
|
16 |
+
scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
|
17 |
+
elif opt.lr_policy == 'plateau':
|
18 |
+
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
|
19 |
+
elif opt.lr_policy == 'cosine':
|
20 |
+
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
|
21 |
+
else:
|
22 |
+
return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
|
23 |
+
return scheduler
|
24 |
+
|
25 |
+
|
26 |
+
def init_weights(net, init_type='normal', init_gain=0.02):
|
27 |
+
def init_func(m):
|
28 |
+
classname = m.__class__.__name__
|
29 |
+
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
|
30 |
+
if init_type == 'normal':
|
31 |
+
init.normal_(m.weight.data, 0.0, init_gain)
|
32 |
+
elif init_type == 'xavier':
|
33 |
+
init.xavier_normal_(m.weight.data, gain=init_gain)
|
34 |
+
elif init_type == 'kaiming':
|
35 |
+
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
|
36 |
+
elif init_type == 'orthogonal':
|
37 |
+
init.orthogonal_(m.weight.data, gain=init_gain)
|
38 |
+
else:
|
39 |
+
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
|
40 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
41 |
+
init.constant_(m.bias.data, 0.0)
|
42 |
+
elif classname.find('BatchNorm2d') != -1:
|
43 |
+
init.normal_(m.weight.data, 1.0, init_gain)
|
44 |
+
init.constant_(m.bias.data, 0.0)
|
45 |
+
|
46 |
+
print('initialize network with %s' % init_type)
|
47 |
+
net.apply(init_func)
|
48 |
+
|
49 |
+
|
50 |
+
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=()):
|
51 |
+
if len(gpu_ids) > 0:
|
52 |
+
assert (torch.cuda.is_available())
|
53 |
+
net.to(gpu_ids[0])
|
54 |
+
net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
|
55 |
+
init_weights(net, init_type, init_gain=init_gain)
|
56 |
+
return net
|
57 |
+
|
58 |
+
|
59 |
+
class SignWithSigmoidGrad(torch.autograd.Function):
|
60 |
+
|
61 |
+
@staticmethod
|
62 |
+
def forward(ctx, x):
|
63 |
+
result = (x > 0).float()
|
64 |
+
sigmoid_result = torch.sigmoid(x)
|
65 |
+
ctx.save_for_backward(sigmoid_result)
|
66 |
+
return result
|
67 |
+
|
68 |
+
@staticmethod
|
69 |
+
def backward(ctx, grad_result):
|
70 |
+
(sigmoid_result,) = ctx.saved_tensors
|
71 |
+
if ctx.needs_input_grad[0]:
|
72 |
+
grad_input = grad_result * sigmoid_result * (1 - sigmoid_result)
|
73 |
+
else:
|
74 |
+
grad_input = None
|
75 |
+
return grad_input
|
76 |
+
|
77 |
+
|
78 |
+
class Painter(nn.Module):
|
79 |
+
|
80 |
+
def __init__(self, param_per_stroke, total_strokes, hidden_dim, n_heads=8, n_enc_layers=3, n_dec_layers=3):
|
81 |
+
super().__init__()
|
82 |
+
self.enc_img = nn.Sequential(
|
83 |
+
nn.ReflectionPad2d(1),
|
84 |
+
nn.Conv2d(3, 32, 3, 1),
|
85 |
+
nn.BatchNorm2d(32),
|
86 |
+
nn.ReLU(True),
|
87 |
+
nn.ReflectionPad2d(1),
|
88 |
+
nn.Conv2d(32, 64, 3, 2),
|
89 |
+
nn.BatchNorm2d(64),
|
90 |
+
nn.ReLU(True),
|
91 |
+
nn.ReflectionPad2d(1),
|
92 |
+
nn.Conv2d(64, 128, 3, 2),
|
93 |
+
nn.BatchNorm2d(128),
|
94 |
+
nn.ReLU(True))
|
95 |
+
self.enc_canvas = nn.Sequential(
|
96 |
+
nn.ReflectionPad2d(1),
|
97 |
+
nn.Conv2d(3, 32, 3, 1),
|
98 |
+
nn.BatchNorm2d(32),
|
99 |
+
nn.ReLU(True),
|
100 |
+
nn.ReflectionPad2d(1),
|
101 |
+
nn.Conv2d(32, 64, 3, 2),
|
102 |
+
nn.BatchNorm2d(64),
|
103 |
+
nn.ReLU(True),
|
104 |
+
nn.ReflectionPad2d(1),
|
105 |
+
nn.Conv2d(64, 128, 3, 2),
|
106 |
+
nn.BatchNorm2d(128),
|
107 |
+
nn.ReLU(True))
|
108 |
+
self.conv = nn.Conv2d(128 * 2, hidden_dim, 1)
|
109 |
+
self.transformer = nn.Transformer(hidden_dim, n_heads, n_enc_layers, n_dec_layers)
|
110 |
+
self.linear_param = nn.Sequential(
|
111 |
+
nn.Linear(hidden_dim, hidden_dim),
|
112 |
+
nn.ReLU(True),
|
113 |
+
nn.Linear(hidden_dim, hidden_dim),
|
114 |
+
nn.ReLU(True),
|
115 |
+
nn.Linear(hidden_dim, param_per_stroke))
|
116 |
+
self.linear_decider = nn.Linear(hidden_dim, 1)
|
117 |
+
self.query_pos = nn.Parameter(torch.rand(total_strokes, hidden_dim))
|
118 |
+
self.row_embed = nn.Parameter(torch.rand(8, hidden_dim // 2))
|
119 |
+
self.col_embed = nn.Parameter(torch.rand(8, hidden_dim // 2))
|
120 |
+
|
121 |
+
def forward(self, img, canvas):
|
122 |
+
b, _, H, W = img.shape
|
123 |
+
img_feat = self.enc_img(img)
|
124 |
+
canvas_feat = self.enc_canvas(canvas)
|
125 |
+
h, w = img_feat.shape[-2:]
|
126 |
+
feat = torch.cat([img_feat, canvas_feat], dim=1)
|
127 |
+
feat_conv = self.conv(feat)
|
128 |
+
|
129 |
+
pos_embed = torch.cat([
|
130 |
+
self.col_embed[:w].unsqueeze(0).contiguous().repeat(h, 1, 1),
|
131 |
+
self.row_embed[:h].unsqueeze(1).contiguous().repeat(1, w, 1),
|
132 |
+
], dim=-1).flatten(0, 1).unsqueeze(1)
|
133 |
+
hidden_state = self.transformer(pos_embed + feat_conv.flatten(2).permute(2, 0, 1).contiguous(),
|
134 |
+
self.query_pos.unsqueeze(1).contiguous().repeat(1, b, 1))
|
135 |
+
hidden_state = hidden_state.permute(1, 0, 2).contiguous()
|
136 |
+
param = self.linear_param(hidden_state)
|
137 |
+
s = hidden_state.shape[1]
|
138 |
+
grid = param[:, :, :2].view(b * s, 1, 1, 2).contiguous()
|
139 |
+
img_temp = img.unsqueeze(1).contiguous().repeat(1, s, 1, 1, 1).view(b * s, 3, H, W).contiguous()
|
140 |
+
color = nn.functional.grid_sample(img_temp, 2 * grid - 1, align_corners=False).view(b, s, 3).contiguous()
|
141 |
+
decision = self.linear_decider(hidden_state)
|
142 |
+
return torch.cat([param, color, color, torch.rand(b, s, 1, device=img.device)], dim=-1), decision
|
143 |
+
|
train/models/painter_model.py
ADDED
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from .base_model import BaseModel
|
4 |
+
from . import networks
|
5 |
+
from util import morphology
|
6 |
+
from scipy.optimize import linear_sum_assignment
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
|
10 |
+
class PainterModel(BaseModel):
|
11 |
+
|
12 |
+
@staticmethod
|
13 |
+
def modify_commandline_options(parser, is_train=True):
|
14 |
+
parser.set_defaults(dataset_mode='null')
|
15 |
+
parser.add_argument('--used_strokes', type=int, default=8,
|
16 |
+
help='actually generated strokes number')
|
17 |
+
parser.add_argument('--num_blocks', type=int, default=3,
|
18 |
+
help='number of transformer blocks for stroke generator')
|
19 |
+
parser.add_argument('--lambda_w', type=float, default=10.0, help='weight for w loss of stroke shape')
|
20 |
+
parser.add_argument('--lambda_pixel', type=float, default=10.0, help='weight for pixel-level L1 loss')
|
21 |
+
parser.add_argument('--lambda_gt', type=float, default=1.0, help='weight for ground-truth loss')
|
22 |
+
parser.add_argument('--lambda_decision', type=float, default=10.0, help='weight for stroke decision loss')
|
23 |
+
parser.add_argument('--lambda_recall', type=float, default=10.0, help='weight of recall for stroke decision loss')
|
24 |
+
return parser
|
25 |
+
|
26 |
+
def __init__(self, opt):
|
27 |
+
BaseModel.__init__(self, opt)
|
28 |
+
self.loss_names = ['pixel', 'gt', 'w', 'decision']
|
29 |
+
self.visual_names = ['old', 'render', 'rec']
|
30 |
+
self.model_names = ['g']
|
31 |
+
self.d = 12 # xc, yc, w, h, theta, R0, G0, B0, R2, G2, B2, A
|
32 |
+
self.d_shape = 5
|
33 |
+
|
34 |
+
def read_img(img_path, img_type='RGB'):
|
35 |
+
img = Image.open(img_path).convert(img_type)
|
36 |
+
img = np.array(img)
|
37 |
+
if img.ndim == 2:
|
38 |
+
img = np.expand_dims(img, axis=-1)
|
39 |
+
img = img.transpose((2, 0, 1))
|
40 |
+
img = torch.from_numpy(img).unsqueeze(0).float() / 255.
|
41 |
+
return img
|
42 |
+
|
43 |
+
brush_large_vertical = read_img('brush/brush_large_vertical.png', 'L').to(self.device)
|
44 |
+
brush_large_horizontal = read_img('brush/brush_large_horizontal.png', 'L').to(self.device)
|
45 |
+
self.meta_brushes = torch.cat(
|
46 |
+
[brush_large_vertical, brush_large_horizontal], dim=0)
|
47 |
+
net_g = networks.Painter(self.d_shape, opt.used_strokes, opt.ngf,
|
48 |
+
n_enc_layers=opt.num_blocks, n_dec_layers=opt.num_blocks)
|
49 |
+
self.net_g = networks.init_net(net_g, opt.init_type, opt.init_gain, self.gpu_ids)
|
50 |
+
self.old = None
|
51 |
+
self.render = None
|
52 |
+
self.rec = None
|
53 |
+
self.gt_param = None
|
54 |
+
self.pred_param = None
|
55 |
+
self.gt_decision = None
|
56 |
+
self.pred_decision = None
|
57 |
+
self.patch_size = 32
|
58 |
+
self.loss_pixel = torch.tensor(0., device=self.device)
|
59 |
+
self.loss_gt = torch.tensor(0., device=self.device)
|
60 |
+
self.loss_w = torch.tensor(0., device=self.device)
|
61 |
+
self.loss_decision = torch.tensor(0., device=self.device)
|
62 |
+
self.criterion_pixel = torch.nn.L1Loss().to(self.device)
|
63 |
+
self.criterion_decision = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(opt.lambda_recall)).to(self.device)
|
64 |
+
if self.isTrain:
|
65 |
+
self.optimizer = torch.optim.Adam(self.net_g.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
|
66 |
+
self.optimizers.append(self.optimizer)
|
67 |
+
|
68 |
+
def param2stroke(self, param, H, W):
|
69 |
+
# param: b, 12
|
70 |
+
b = param.shape[0]
|
71 |
+
param_list = torch.split(param, 1, dim=1)
|
72 |
+
x0, y0, w, h, theta = [item.squeeze(-1) for item in param_list[:5]]
|
73 |
+
R0, G0, B0, R2, G2, B2, _ = param_list[5:]
|
74 |
+
sin_theta = torch.sin(torch.acos(torch.tensor(-1., device=param.device)) * theta)
|
75 |
+
cos_theta = torch.cos(torch.acos(torch.tensor(-1., device=param.device)) * theta)
|
76 |
+
index = torch.full((b,), -1, device=param.device)
|
77 |
+
index[h > w] = 0
|
78 |
+
index[h <= w] = 1
|
79 |
+
brush = self.meta_brushes[index.long()]
|
80 |
+
alphas = torch.cat([brush, brush, brush], dim=1)
|
81 |
+
alphas = (alphas > 0).float()
|
82 |
+
t = torch.arange(0, brush.shape[2], device=param.device).unsqueeze(0) / brush.shape[2]
|
83 |
+
color_map = torch.stack([R0 * (1 - t) + R2 * t, G0 * (1 - t) + G2 * t, B0 * (1 - t) + B2 * t], dim=1)
|
84 |
+
color_map = color_map.unsqueeze(-1).repeat(1, 1, 1, brush.shape[3])
|
85 |
+
brush = brush * color_map
|
86 |
+
|
87 |
+
warp_00 = cos_theta / w
|
88 |
+
warp_01 = sin_theta * H / (W * w)
|
89 |
+
warp_02 = (1 - 2 * x0) * cos_theta / w + (1 - 2 * y0) * sin_theta * H / (W * w)
|
90 |
+
warp_10 = -sin_theta * W / (H * h)
|
91 |
+
warp_11 = cos_theta / h
|
92 |
+
warp_12 = (1 - 2 * y0) * cos_theta / h - (1 - 2 * x0) * sin_theta * W / (H * h)
|
93 |
+
warp_0 = torch.stack([warp_00, warp_01, warp_02], dim=1)
|
94 |
+
warp_1 = torch.stack([warp_10, warp_11, warp_12], dim=1)
|
95 |
+
warp = torch.stack([warp_0, warp_1], dim=1)
|
96 |
+
grid = torch.nn.functional.affine_grid(warp, torch.Size((b, 3, H, W)), align_corners=False)
|
97 |
+
brush = torch.nn.functional.grid_sample(brush, grid, align_corners=False)
|
98 |
+
alphas = torch.nn.functional.grid_sample(alphas, grid, align_corners=False)
|
99 |
+
|
100 |
+
return brush, alphas
|
101 |
+
|
102 |
+
def set_input(self, input_dict):
|
103 |
+
self.image_paths = input_dict['A_paths']
|
104 |
+
with torch.no_grad():
|
105 |
+
old_param = torch.rand(self.opt.batch_size // 4, self.opt.used_strokes, self.d, device=self.device)
|
106 |
+
old_param[:, :, :4] = old_param[:, :, :4] * 0.5 + 0.2
|
107 |
+
old_param[:, :, -4:-1] = old_param[:, :, -7:-4]
|
108 |
+
old_param = old_param.view(-1, self.d).contiguous()
|
109 |
+
foregrounds, alphas = self.param2stroke(old_param, self.patch_size * 2, self.patch_size * 2)
|
110 |
+
foregrounds = morphology.Dilation2d(m=1)(foregrounds)
|
111 |
+
alphas = morphology.Erosion2d(m=1)(alphas)
|
112 |
+
foregrounds = foregrounds.view(self.opt.batch_size // 4, self.opt.used_strokes, 3, self.patch_size * 2,
|
113 |
+
self.patch_size * 2).contiguous()
|
114 |
+
alphas = alphas.view(self.opt.batch_size // 4, self.opt.used_strokes, 3, self.patch_size * 2,
|
115 |
+
self.patch_size * 2).contiguous()
|
116 |
+
old = torch.zeros(self.opt.batch_size // 4, 3, self.patch_size * 2, self.patch_size * 2, device=self.device)
|
117 |
+
for i in range(self.opt.used_strokes):
|
118 |
+
foreground = foregrounds[:, i, :, :, :]
|
119 |
+
alpha = alphas[:, i, :, :, :]
|
120 |
+
old = foreground * alpha + old * (1 - alpha)
|
121 |
+
old = old.view(self.opt.batch_size // 4, 3, 2, self.patch_size, 2, self.patch_size).contiguous()
|
122 |
+
old = old.permute(0, 2, 4, 1, 3, 5).contiguous()
|
123 |
+
self.old = old.view(self.opt.batch_size, 3, self.patch_size, self.patch_size).contiguous()
|
124 |
+
|
125 |
+
gt_param = torch.rand(self.opt.batch_size, self.opt.used_strokes, self.d, device=self.device)
|
126 |
+
gt_param[:, :, :4] = gt_param[:, :, :4] * 0.5 + 0.2
|
127 |
+
gt_param[:, :, -4:-1] = gt_param[:, :, -7:-4]
|
128 |
+
self.gt_param = gt_param[:, :, :self.d_shape]
|
129 |
+
gt_param = gt_param.view(-1, self.d).contiguous()
|
130 |
+
foregrounds, alphas = self.param2stroke(gt_param, self.patch_size, self.patch_size)
|
131 |
+
foregrounds = morphology.Dilation2d(m=1)(foregrounds)
|
132 |
+
alphas = morphology.Erosion2d(m=1)(alphas)
|
133 |
+
foregrounds = foregrounds.view(self.opt.batch_size, self.opt.used_strokes, 3, self.patch_size,
|
134 |
+
self.patch_size).contiguous()
|
135 |
+
alphas = alphas.view(self.opt.batch_size, self.opt.used_strokes, 3, self.patch_size,
|
136 |
+
self.patch_size).contiguous()
|
137 |
+
self.render = self.old.clone()
|
138 |
+
gt_decision = torch.ones(self.opt.batch_size, self.opt.used_strokes, device=self.device)
|
139 |
+
for i in range(self.opt.used_strokes):
|
140 |
+
foreground = foregrounds[:, i, :, :, :]
|
141 |
+
alpha = alphas[:, i, :, :, :]
|
142 |
+
for j in range(i):
|
143 |
+
iou = (torch.sum(alpha * alphas[:, j, :, :, :], dim=(-3, -2, -1)) + 1e-5) / (
|
144 |
+
torch.sum(alphas[:, j, :, :, :], dim=(-3, -2, -1)) + 1e-5)
|
145 |
+
gt_decision[:, i] = ((iou < 0.75) | (~gt_decision[:, j].bool())).float() * gt_decision[:, i]
|
146 |
+
decision = gt_decision[:, i].view(self.opt.batch_size, 1, 1, 1).contiguous()
|
147 |
+
self.render = foreground * alpha * decision + self.render * (1 - alpha * decision)
|
148 |
+
self.gt_decision = gt_decision
|
149 |
+
|
150 |
+
def forward(self):
|
151 |
+
param, decisions = self.net_g(self.render, self.old)
|
152 |
+
# stroke_param: b, stroke_per_patch, param_per_stroke
|
153 |
+
# decision: b, stroke_per_patch, 1
|
154 |
+
self.pred_decision = decisions.view(-1, self.opt.used_strokes).contiguous()
|
155 |
+
self.pred_param = param[:, :, :self.d_shape]
|
156 |
+
param = param.view(-1, self.d).contiguous()
|
157 |
+
foregrounds, alphas = self.param2stroke(param, self.patch_size, self.patch_size)
|
158 |
+
foregrounds = morphology.Dilation2d(m=1)(foregrounds)
|
159 |
+
alphas = morphology.Erosion2d(m=1)(alphas)
|
160 |
+
# foreground, alpha: b * stroke_per_patch, 3, output_size, output_size
|
161 |
+
foregrounds = foregrounds.view(-1, self.opt.used_strokes, 3, self.patch_size, self.patch_size)
|
162 |
+
alphas = alphas.view(-1, self.opt.used_strokes, 3, self.patch_size, self.patch_size)
|
163 |
+
# foreground, alpha: b, stroke_per_patch, 3, output_size, output_size
|
164 |
+
decisions = networks.SignWithSigmoidGrad.apply(decisions.view(-1, self.opt.used_strokes, 1, 1, 1).contiguous())
|
165 |
+
self.rec = self.old.clone()
|
166 |
+
for j in range(foregrounds.shape[1]):
|
167 |
+
foreground = foregrounds[:, j, :, :, :]
|
168 |
+
alpha = alphas[:, j, :, :, :]
|
169 |
+
decision = decisions[:, j, :, :, :]
|
170 |
+
self.rec = foreground * alpha * decision + self.rec * (1 - alpha * decision)
|
171 |
+
|
172 |
+
@staticmethod
|
173 |
+
def get_sigma_sqrt(w, h, theta):
|
174 |
+
sigma_00 = w * (torch.cos(theta) ** 2) / 2 + h * (torch.sin(theta) ** 2) / 2
|
175 |
+
sigma_01 = (w - h) * torch.cos(theta) * torch.sin(theta) / 2
|
176 |
+
sigma_11 = h * (torch.cos(theta) ** 2) / 2 + w * (torch.sin(theta) ** 2) / 2
|
177 |
+
sigma_0 = torch.stack([sigma_00, sigma_01], dim=-1)
|
178 |
+
sigma_1 = torch.stack([sigma_01, sigma_11], dim=-1)
|
179 |
+
sigma = torch.stack([sigma_0, sigma_1], dim=-2)
|
180 |
+
return sigma
|
181 |
+
|
182 |
+
@staticmethod
|
183 |
+
def get_sigma(w, h, theta):
|
184 |
+
sigma_00 = w * w * (torch.cos(theta) ** 2) / 4 + h * h * (torch.sin(theta) ** 2) / 4
|
185 |
+
sigma_01 = (w * w - h * h) * torch.cos(theta) * torch.sin(theta) / 4
|
186 |
+
sigma_11 = h * h * (torch.cos(theta) ** 2) / 4 + w * w * (torch.sin(theta) ** 2) / 4
|
187 |
+
sigma_0 = torch.stack([sigma_00, sigma_01], dim=-1)
|
188 |
+
sigma_1 = torch.stack([sigma_01, sigma_11], dim=-1)
|
189 |
+
sigma = torch.stack([sigma_0, sigma_1], dim=-2)
|
190 |
+
return sigma
|
191 |
+
|
192 |
+
def gaussian_w_distance(self, param_1, param_2):
|
193 |
+
mu_1, w_1, h_1, theta_1 = torch.split(param_1, (2, 1, 1, 1), dim=-1)
|
194 |
+
w_1 = w_1.squeeze(-1)
|
195 |
+
h_1 = h_1.squeeze(-1)
|
196 |
+
theta_1 = torch.acos(torch.tensor(-1., device=param_1.device)) * theta_1.squeeze(-1)
|
197 |
+
trace_1 = (w_1 ** 2 + h_1 ** 2) / 4
|
198 |
+
mu_2, w_2, h_2, theta_2 = torch.split(param_2, (2, 1, 1, 1), dim=-1)
|
199 |
+
w_2 = w_2.squeeze(-1)
|
200 |
+
h_2 = h_2.squeeze(-1)
|
201 |
+
theta_2 = torch.acos(torch.tensor(-1., device=param_2.device)) * theta_2.squeeze(-1)
|
202 |
+
trace_2 = (w_2 ** 2 + h_2 ** 2) / 4
|
203 |
+
sigma_1_sqrt = self.get_sigma_sqrt(w_1, h_1, theta_1)
|
204 |
+
sigma_2 = self.get_sigma(w_2, h_2, theta_2)
|
205 |
+
trace_12 = torch.matmul(torch.matmul(sigma_1_sqrt, sigma_2), sigma_1_sqrt)
|
206 |
+
trace_12 = torch.sqrt(trace_12[..., 0, 0] + trace_12[..., 1, 1] + 2 * torch.sqrt(
|
207 |
+
trace_12[..., 0, 0] * trace_12[..., 1, 1] - trace_12[..., 0, 1] * trace_12[..., 1, 0]))
|
208 |
+
return torch.sum((mu_1 - mu_2) ** 2, dim=-1) + trace_1 + trace_2 - 2 * trace_12
|
209 |
+
|
210 |
+
def optimize_parameters(self):
|
211 |
+
self.forward()
|
212 |
+
self.loss_pixel = self.criterion_pixel(self.rec, self.render) * self.opt.lambda_pixel
|
213 |
+
cur_valid_gt_size = 0
|
214 |
+
with torch.no_grad():
|
215 |
+
r_idx = []
|
216 |
+
c_idx = []
|
217 |
+
for i in range(self.gt_param.shape[0]):
|
218 |
+
is_valid_gt = self.gt_decision[i].bool()
|
219 |
+
valid_gt_param = self.gt_param[i, is_valid_gt]
|
220 |
+
cost_matrix_l1 = torch.cdist(self.pred_param[i], valid_gt_param, p=1)
|
221 |
+
pred_param_broad = self.pred_param[i].unsqueeze(1).contiguous().repeat(
|
222 |
+
1, valid_gt_param.shape[0], 1)
|
223 |
+
valid_gt_param_broad = valid_gt_param.unsqueeze(0).contiguous().repeat(
|
224 |
+
self.pred_param.shape[1], 1, 1)
|
225 |
+
cost_matrix_w = self.gaussian_w_distance(pred_param_broad, valid_gt_param_broad)
|
226 |
+
decision = self.pred_decision[i]
|
227 |
+
cost_matrix_decision = (1 - decision).unsqueeze(-1).repeat(1, valid_gt_param.shape[0])
|
228 |
+
r, c = linear_sum_assignment((cost_matrix_l1 + cost_matrix_w + cost_matrix_decision).cpu())
|
229 |
+
r_idx.append(torch.tensor(r + self.pred_param.shape[1] * i, device=self.device))
|
230 |
+
c_idx.append(torch.tensor(c + cur_valid_gt_size, device=self.device))
|
231 |
+
cur_valid_gt_size += valid_gt_param.shape[0]
|
232 |
+
r_idx = torch.cat(r_idx, dim=0)
|
233 |
+
c_idx = torch.cat(c_idx, dim=0)
|
234 |
+
paired_gt_decision = torch.zeros(self.gt_decision.shape[0] * self.gt_decision.shape[1], device=self.device)
|
235 |
+
paired_gt_decision[r_idx] = 1.
|
236 |
+
all_valid_gt_param = self.gt_param[self.gt_decision.bool(), :]
|
237 |
+
all_pred_param = self.pred_param.view(-1, self.pred_param.shape[2]).contiguous()
|
238 |
+
all_pred_decision = self.pred_decision.view(-1).contiguous()
|
239 |
+
paired_gt_param = all_valid_gt_param[c_idx, :]
|
240 |
+
paired_pred_param = all_pred_param[r_idx, :]
|
241 |
+
self.loss_gt = self.criterion_pixel(paired_pred_param, paired_gt_param) * self.opt.lambda_gt
|
242 |
+
self.loss_w = self.gaussian_w_distance(paired_pred_param, paired_gt_param).mean() * self.opt.lambda_w
|
243 |
+
self.loss_decision = self.criterion_decision(all_pred_decision, paired_gt_decision) * self.opt.lambda_decision
|
244 |
+
loss = self.loss_pixel + self.loss_gt + self.loss_w + self.loss_decision
|
245 |
+
loss.backward()
|
246 |
+
self.optimizer.step()
|
247 |
+
self.optimizer.zero_grad()
|
train/options/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
"""This package options includes option modules: training options, test options, and basic options (used in both training and test)."""
|
train/options/base_options.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
from util import util
|
4 |
+
import torch
|
5 |
+
import models
|
6 |
+
import data
|
7 |
+
|
8 |
+
|
9 |
+
class BaseOptions:
|
10 |
+
"""This class defines options used during both training and test time.
|
11 |
+
|
12 |
+
It also implements several helper functions such as parsing, printing, and saving the options.
|
13 |
+
It also gathers additional options defined in <modify_commandline_options> functions
|
14 |
+
in both dataset class and model class.
|
15 |
+
"""
|
16 |
+
|
17 |
+
def __init__(self):
|
18 |
+
"""Reset the class; indicates the class hasn't been initialized"""
|
19 |
+
self.initialized = False
|
20 |
+
|
21 |
+
def initialize(self, parser):
|
22 |
+
"""Define the common options that are used in both training and test."""
|
23 |
+
# basic parameters
|
24 |
+
parser.add_argument('--dataroot', default='.',
|
25 |
+
help='path to images (should have sub-folders trainA, trainB, valA, valB, etc)')
|
26 |
+
parser.add_argument('--name', type=str, default='experiment_name',
|
27 |
+
help='name of the experiment. It decides where to store samples and models')
|
28 |
+
parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
|
29 |
+
parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
|
30 |
+
# model parameters
|
31 |
+
parser.add_argument('--model', type=str, default='painter',
|
32 |
+
help='chooses which model to use.')
|
33 |
+
parser.add_argument('--input_nc', type=int, default=3,
|
34 |
+
help='# of input image channels: 3 for RGB and 1 for grayscale')
|
35 |
+
parser.add_argument('--output_nc', type=int, default=3,
|
36 |
+
help='# of output image channels: 3 for RGB and 1 for grayscale')
|
37 |
+
parser.add_argument('--ngf', type=int, default=256, help='# of gen filters in the first conv layer')
|
38 |
+
parser.add_argument('--layer_num', type=int, default=2, help='# of resnet block for generator')
|
39 |
+
parser.add_argument('--init_type', type=str, default='normal',
|
40 |
+
help='network initialization [normal | xavier | kaiming | orthogonal]')
|
41 |
+
parser.add_argument('--init_gain', type=float, default=0.02,
|
42 |
+
help='scaling factor for normal, xavier and orthogonal.')
|
43 |
+
# dataset parameters
|
44 |
+
parser.add_argument('--dataset_mode', type=str, default='single',
|
45 |
+
help='chooses how datasets are loaded.')
|
46 |
+
parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA')
|
47 |
+
parser.add_argument('--serial_batches', action='store_true',
|
48 |
+
help='if true, takes images in order to make batches, otherwise takes them randomly')
|
49 |
+
parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
|
50 |
+
parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
|
51 |
+
parser.add_argument('--load_size', type=int, default=286, help='scale images to this size')
|
52 |
+
parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size')
|
53 |
+
parser.add_argument('--max_dataset_size', type=int, default=float("inf"),
|
54 |
+
help='Maximum number of samples allowed per dataset. If the dataset directory contains '
|
55 |
+
'more than max_dataset_size, only a subset is loaded.')
|
56 |
+
parser.add_argument('--preprocess', type=str, default='resize_and_crop',
|
57 |
+
help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | '
|
58 |
+
'scale_width_and_crop | none]')
|
59 |
+
parser.add_argument('--no_flip', action='store_true',
|
60 |
+
help='if specified, do not flip the images for data augmentation')
|
61 |
+
parser.add_argument('--display_winsize', type=int, default=256,
|
62 |
+
help='display window size for both visdom and HTML')
|
63 |
+
# additional parameters
|
64 |
+
parser.add_argument('--epoch', type=str, default='latest',
|
65 |
+
help='which epoch to load? set to latest to use latest cached model')
|
66 |
+
parser.add_argument('--load_iter', type=int, default='0',
|
67 |
+
help='which iteration to load? if load_iter > 0, the code will load models by iter_['
|
68 |
+
'load_iter]; otherwise, the code will load models by [epoch]')
|
69 |
+
parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
|
70 |
+
parser.add_argument('--suffix', default='', type=str,
|
71 |
+
help='customized suffix: opt.name = opt.name + suffix')
|
72 |
+
self.initialized = True
|
73 |
+
return parser
|
74 |
+
|
75 |
+
def gather_options(self):
|
76 |
+
"""Initialize our parser with basic options(only once).
|
77 |
+
Add additional model-specific and dataset-specific options.
|
78 |
+
These options are defined in the <modify_commandline_options> function
|
79 |
+
in model and dataset classes.
|
80 |
+
"""
|
81 |
+
if not self.initialized: # check if it has been initialized
|
82 |
+
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
83 |
+
parser = self.initialize(parser)
|
84 |
+
|
85 |
+
# get the basic options
|
86 |
+
opt, _ = parser.parse_known_args()
|
87 |
+
|
88 |
+
# modify model-related parser options
|
89 |
+
model_name = opt.model
|
90 |
+
model_option_setter = models.get_option_setter(model_name)
|
91 |
+
parser = model_option_setter(parser, self.isTrain)
|
92 |
+
opt, _ = parser.parse_known_args() # parse again with new defaults
|
93 |
+
|
94 |
+
# modify dataset-related parser options
|
95 |
+
dataset_name = opt.dataset_mode
|
96 |
+
dataset_option_setter = data.get_option_setter(dataset_name)
|
97 |
+
parser = dataset_option_setter(parser, self.isTrain)
|
98 |
+
|
99 |
+
# save and return the parser
|
100 |
+
self.parser = parser
|
101 |
+
return parser.parse_args()
|
102 |
+
|
103 |
+
def print_options(self, opt):
|
104 |
+
"""Print and save options
|
105 |
+
|
106 |
+
It will print both current options and default values(if different).
|
107 |
+
It will save options into a text file / [checkpoints_dir] / opt.txt
|
108 |
+
"""
|
109 |
+
message = ''
|
110 |
+
message += '----------------- Options ---------------\n'
|
111 |
+
for k, v in sorted(vars(opt).items()):
|
112 |
+
comment = ''
|
113 |
+
default = self.parser.get_default(k)
|
114 |
+
if v != default:
|
115 |
+
comment = '\t[default: %s]' % str(default)
|
116 |
+
message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
|
117 |
+
message += '----------------- End -------------------'
|
118 |
+
print(message)
|
119 |
+
|
120 |
+
# save to the disk
|
121 |
+
expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
|
122 |
+
util.mkdirs(expr_dir)
|
123 |
+
file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
|
124 |
+
with open(file_name, 'wt') as opt_file:
|
125 |
+
opt_file.write(message)
|
126 |
+
opt_file.write('\n')
|
127 |
+
|
128 |
+
def parse(self):
|
129 |
+
"""Parse our options, create checkpoints directory suffix, and set up gpu device."""
|
130 |
+
opt = self.gather_options()
|
131 |
+
opt.isTrain = self.isTrain # train or test
|
132 |
+
|
133 |
+
# process opt.suffix
|
134 |
+
if opt.suffix:
|
135 |
+
suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
|
136 |
+
opt.name = opt.name + suffix
|
137 |
+
|
138 |
+
self.print_options(opt)
|
139 |
+
|
140 |
+
# set gpu ids
|
141 |
+
str_ids = opt.gpu_ids.split(',')
|
142 |
+
opt.gpu_ids = []
|
143 |
+
for str_id in str_ids:
|
144 |
+
id = int(str_id)
|
145 |
+
if id >= 0:
|
146 |
+
opt.gpu_ids.append(id)
|
147 |
+
if len(opt.gpu_ids) > 0:
|
148 |
+
torch.cuda.set_device(opt.gpu_ids[0])
|
149 |
+
|
150 |
+
self.opt = opt
|
151 |
+
return self.opt
|
train/options/test_options.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .base_options import BaseOptions
|
2 |
+
|
3 |
+
|
4 |
+
class TestOptions(BaseOptions):
|
5 |
+
"""This class includes test options.
|
6 |
+
|
7 |
+
It also includes shared options defined in BaseOptions.
|
8 |
+
"""
|
9 |
+
|
10 |
+
def initialize(self, parser):
|
11 |
+
parser = BaseOptions.initialize(self, parser) # define shared options
|
12 |
+
parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
|
13 |
+
parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
|
14 |
+
parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
|
15 |
+
# Dropout and Batch norm has different behavior during training and test.
|
16 |
+
parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
|
17 |
+
parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
|
18 |
+
# rewrite devalue values
|
19 |
+
parser.set_defaults(model='test')
|
20 |
+
# To avoid cropping, the load_size should be the same as crop_size
|
21 |
+
parser.set_defaults(load_size=parser.get_default('crop_size'))
|
22 |
+
self.isTrain = False
|
23 |
+
return parser
|
train/options/train_options.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .base_options import BaseOptions
|
2 |
+
|
3 |
+
|
4 |
+
class TrainOptions(BaseOptions):
|
5 |
+
"""This class includes training options.
|
6 |
+
|
7 |
+
It also includes shared options defined in BaseOptions.
|
8 |
+
"""
|
9 |
+
|
10 |
+
def initialize(self, parser):
|
11 |
+
parser = BaseOptions.initialize(self, parser)
|
12 |
+
# visdom and HTML visualization parameters
|
13 |
+
parser.add_argument('--display_freq', type=int, default=40,
|
14 |
+
help='frequency of showing training results on screen')
|
15 |
+
parser.add_argument('--display_ncols', type=int, default=4,
|
16 |
+
help='if positive, display all images in a single visdom web panel '
|
17 |
+
'with certain number of images per row.')
|
18 |
+
parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
|
19 |
+
parser.add_argument('--display_server', type=str, default="http://localhost",
|
20 |
+
help='visdom server of the web display')
|
21 |
+
parser.add_argument('--display_env', type=str, default='main',
|
22 |
+
help='visdom display environment name (default is "main")')
|
23 |
+
parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
|
24 |
+
parser.add_argument('--update_html_freq', type=int, default=1000,
|
25 |
+
help='frequency of saving training results to html')
|
26 |
+
parser.add_argument('--print_freq', type=int, default=10,
|
27 |
+
help='frequency of showing training results on console')
|
28 |
+
parser.add_argument('--no_html', action='store_true',
|
29 |
+
help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
|
30 |
+
# network saving and loading parameters
|
31 |
+
parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
|
32 |
+
parser.add_argument('--save_epoch_freq', type=int, default=5,
|
33 |
+
help='frequency of saving checkpoints at the end of epochs')
|
34 |
+
parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration')
|
35 |
+
parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
|
36 |
+
parser.add_argument('--epoch_count', type=int, default=1,
|
37 |
+
help='the starting epoch count, we save the model '
|
38 |
+
'by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
|
39 |
+
parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
|
40 |
+
# training parameters
|
41 |
+
parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate')
|
42 |
+
parser.add_argument('--n_epochs_decay', type=int, default=100,
|
43 |
+
help='number of epochs to linearly decay learning rate to zero')
|
44 |
+
parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
|
45 |
+
parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
|
46 |
+
parser.add_argument('--lr_policy', type=str, default='linear',
|
47 |
+
help='learning rate policy. [linear | step | plateau | cosine]')
|
48 |
+
parser.add_argument('--lr_decay_iters', type=int, default=50,
|
49 |
+
help='multiply by a gamma every lr_decay_iters iterations')
|
50 |
+
|
51 |
+
self.isTrain = True
|
52 |
+
return parser
|
train/train.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
from options.train_options import TrainOptions
|
3 |
+
from data import create_dataset
|
4 |
+
from models import create_model
|
5 |
+
from util.visualizer import Visualizer
|
6 |
+
|
7 |
+
if __name__ == '__main__':
|
8 |
+
opt = TrainOptions().parse() # get training options
|
9 |
+
dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options
|
10 |
+
dataset_size = len(dataset) # get the number of images in the dataset.
|
11 |
+
print('The number of training images = %d' % dataset_size)
|
12 |
+
|
13 |
+
model = create_model(opt) # create a model given opt.model and other options
|
14 |
+
model.setup(opt) # regular setup: load and print networks; create schedulers
|
15 |
+
visualizer = Visualizer(opt) # create a visualizer that display/save images and plots
|
16 |
+
total_iters = 0 # the total number of training iterations
|
17 |
+
|
18 |
+
for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1):
|
19 |
+
epoch_start_time = time.time() # timer for entire epoch
|
20 |
+
iter_data_time = time.time() # timer for data loading per iteration
|
21 |
+
epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch
|
22 |
+
visualizer.reset() # reset visualizer: make sure it saves results to HTML at least once every epoch
|
23 |
+
for i, data in enumerate(dataset): # inner loop within one epoch
|
24 |
+
iter_start_time = time.time() # timer for computation per iteration
|
25 |
+
if total_iters % opt.print_freq == 0:
|
26 |
+
t_data = iter_start_time - iter_data_time
|
27 |
+
|
28 |
+
total_iters += opt.batch_size
|
29 |
+
epoch_iter += opt.batch_size
|
30 |
+
model.set_input(data) # unpack data from dataset and apply preprocessing
|
31 |
+
model.optimize_parameters() # calculate loss functions, get gradients, update network weights
|
32 |
+
|
33 |
+
if total_iters % opt.display_freq == 0: # display images on visdom and save images to a HTML file
|
34 |
+
save_result = total_iters % opt.update_html_freq == 0
|
35 |
+
model.compute_visuals()
|
36 |
+
visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
|
37 |
+
|
38 |
+
if total_iters % opt.print_freq == 0: # print training losses and save logging information to the disk
|
39 |
+
losses = model.get_current_losses()
|
40 |
+
t_comp = (time.time() - iter_start_time) / opt.batch_size
|
41 |
+
visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)
|
42 |
+
if opt.display_id > 0:
|
43 |
+
visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)
|
44 |
+
|
45 |
+
if total_iters % opt.save_latest_freq == 0: # cache our latest model every <save_latest_freq> iterations
|
46 |
+
print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
|
47 |
+
save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
|
48 |
+
model.save_networks(save_suffix)
|
49 |
+
|
50 |
+
iter_data_time = time.time()
|
51 |
+
if epoch % opt.save_epoch_freq == 0: # cache our model every <save_epoch_freq> epochs
|
52 |
+
print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
|
53 |
+
model.save_networks('latest')
|
54 |
+
model.save_networks(epoch)
|
55 |
+
|
56 |
+
print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.n_epochs + opt.n_epochs_decay,
|
57 |
+
time.time() - epoch_start_time))
|
58 |
+
model.update_learning_rate() # update learning rates in the beginning of every epoch.
|
train/train.sh
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
python train.py \
|
2 |
+
--name painter \
|
3 |
+
--gpu_ids 0 \
|
4 |
+
--model painter \
|
5 |
+
--dataset_mode null \
|
6 |
+
--batch_size 64 \
|
7 |
+
--display_freq 25 \
|
8 |
+
--print_freq 25 \
|
9 |
+
--lr 1e-4 \
|
10 |
+
--init_type normal \
|
11 |
+
--n_epochs 200 \
|
12 |
+
--n_epochs_decay 20 \
|
13 |
+
--max_dataset_size 16384 \
|
14 |
+
--save_epoch_freq 20
|
train/util/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
"""This package includes a miscellaneous collection of useful helper functions."""
|
train/util/html.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dominate
|
2 |
+
from dominate.tags import meta, h3, table, tr, td, p, a, img, br
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
class HTML:
|
7 |
+
"""This HTML class allows us to save images and write texts into a single HTML file.
|
8 |
+
|
9 |
+
It consists of functions such as <add_header> (add a text header to the HTML file),
|
10 |
+
<add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
|
11 |
+
It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
|
12 |
+
"""
|
13 |
+
|
14 |
+
def __init__(self, web_dir, title, refresh=0):
|
15 |
+
"""Initialize the HTML classes
|
16 |
+
|
17 |
+
Parameters:
|
18 |
+
web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir/images/
|
19 |
+
title (str) -- the webpage name
|
20 |
+
refresh (int) -- how often the website refresh itself; if 0; no refreshing
|
21 |
+
"""
|
22 |
+
self.title = title
|
23 |
+
self.web_dir = web_dir
|
24 |
+
self.img_dir = os.path.join(self.web_dir, 'images')
|
25 |
+
if not os.path.exists(self.web_dir):
|
26 |
+
os.makedirs(self.web_dir)
|
27 |
+
if not os.path.exists(self.img_dir):
|
28 |
+
os.makedirs(self.img_dir)
|
29 |
+
|
30 |
+
self.doc = dominate.document(title=title)
|
31 |
+
if refresh > 0:
|
32 |
+
with self.doc.head:
|
33 |
+
meta(http_equiv="refresh", content=str(refresh))
|
34 |
+
|
35 |
+
def get_image_dir(self):
|
36 |
+
"""Return the directory that stores images"""
|
37 |
+
return self.img_dir
|
38 |
+
|
39 |
+
def add_header(self, text):
|
40 |
+
"""Insert a header to the HTML file
|
41 |
+
|
42 |
+
Parameters:
|
43 |
+
text (str) -- the header text
|
44 |
+
"""
|
45 |
+
with self.doc:
|
46 |
+
h3(text)
|
47 |
+
|
48 |
+
def add_images(self, ims, txts, links, width=400):
|
49 |
+
"""add images to the HTML file
|
50 |
+
|
51 |
+
Parameters:
|
52 |
+
ims (str list) -- a list of image paths
|
53 |
+
txts (str list) -- a list of image names shown on the website
|
54 |
+
links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
|
55 |
+
"""
|
56 |
+
self.t = table(border=1, style="table-layout: fixed;") # Insert a table
|
57 |
+
self.doc.add(self.t)
|
58 |
+
with self.t:
|
59 |
+
with tr():
|
60 |
+
for im, txt, link in zip(ims, txts, links):
|
61 |
+
with td(style="word-wrap: break-word;", halign="center", valign="top"):
|
62 |
+
with p():
|
63 |
+
with a(href=os.path.join('images', link)):
|
64 |
+
img(style="width:%dpx" % width, src=os.path.join('images', im))
|
65 |
+
br()
|
66 |
+
p(txt)
|
67 |
+
|
68 |
+
def save(self):
|
69 |
+
"""save the current content to the HMTL file"""
|
70 |
+
html_file = '%s/index.html' % self.web_dir
|
71 |
+
f = open(html_file, 'wt')
|
72 |
+
f.write(self.doc.render())
|
73 |
+
f.close()
|
74 |
+
|
75 |
+
|
76 |
+
if __name__ == '__main__': # we show an example usage here.
|
77 |
+
html = HTML('web/', 'test_html')
|
78 |
+
html.add_header('hello world')
|
79 |
+
|
80 |
+
ims, txts, links = [], [], []
|
81 |
+
for n in range(4):
|
82 |
+
ims.append('image_%d.png' % n)
|
83 |
+
txts.append('text_%d' % n)
|
84 |
+
links.append('image_%d.png' % n)
|
85 |
+
html.add_images(ims, txts, links)
|
86 |
+
html.save()
|
train/util/morphology.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class Erosion2d(nn.Module):
|
7 |
+
|
8 |
+
def __init__(self, m=1):
|
9 |
+
super(Erosion2d, self).__init__()
|
10 |
+
self.m = m
|
11 |
+
self.pad = [m, m, m, m]
|
12 |
+
self.unfold = nn.Unfold(2 * m + 1, padding=0, stride=1)
|
13 |
+
|
14 |
+
def forward(self, x):
|
15 |
+
batch_size, c, h, w = x.shape
|
16 |
+
x_pad = F.pad(x, pad=self.pad, mode='constant', value=1e9)
|
17 |
+
for i in range(c):
|
18 |
+
channel = self.unfold(x_pad[:, [i], :, :])
|
19 |
+
channel = torch.min(channel, dim=1, keepdim=True)[0]
|
20 |
+
channel = channel.view([batch_size, 1, h, w])
|
21 |
+
x[:, [i], :, :] = channel
|
22 |
+
|
23 |
+
return x
|
24 |
+
|
25 |
+
|
26 |
+
class Dilation2d(nn.Module):
|
27 |
+
|
28 |
+
def __init__(self, m=1):
|
29 |
+
super(Dilation2d, self).__init__()
|
30 |
+
self.m = m
|
31 |
+
self.pad = [m, m, m, m]
|
32 |
+
self.unfold = nn.Unfold(2 * m + 1, padding=0, stride=1)
|
33 |
+
|
34 |
+
def forward(self, x):
|
35 |
+
batch_size, c, h, w = x.shape
|
36 |
+
x_pad = F.pad(x, pad=self.pad, mode='constant', value=-1e9)
|
37 |
+
for i in range(c):
|
38 |
+
channel = self.unfold(x_pad[:, [i], :, :])
|
39 |
+
channel = torch.max(channel, dim=1, keepdim=True)[0]
|
40 |
+
channel = channel.view([batch_size, 1, h, w])
|
41 |
+
x[:, [i], :, :] = channel
|
42 |
+
|
43 |
+
return x
|
train/util/util.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""This module contains simple helper functions """
|
2 |
+
from __future__ import print_function
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
import os
|
7 |
+
|
8 |
+
|
9 |
+
def tensor2im(input_image, imtype=np.uint8):
|
10 |
+
""""Converts a Tensor array into a numpy image array.
|
11 |
+
|
12 |
+
Parameters:
|
13 |
+
input_image (tensor) -- the input image tensor array
|
14 |
+
imtype (type) -- the desired type of the converted numpy array
|
15 |
+
"""
|
16 |
+
if not isinstance(input_image, np.ndarray):
|
17 |
+
if isinstance(input_image, torch.Tensor): # get the data from a variable
|
18 |
+
image_tensor = input_image.data
|
19 |
+
else:
|
20 |
+
return input_image
|
21 |
+
image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array
|
22 |
+
if image_numpy.shape[0] == 1: # grayscale to RGB
|
23 |
+
image_numpy = np.tile(image_numpy, (3, 1, 1))
|
24 |
+
image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 # post-processing: transpose and scaling
|
25 |
+
else: # if it is a numpy array
|
26 |
+
image_numpy = input_image * 255.
|
27 |
+
return image_numpy.astype(imtype)
|
28 |
+
|
29 |
+
|
30 |
+
def diagnose_network(net, name='network'):
|
31 |
+
"""Calculate and print the mean of average absolute(gradients)
|
32 |
+
|
33 |
+
Parameters:
|
34 |
+
net (torch network) -- Torch network
|
35 |
+
name (str) -- the name of the network
|
36 |
+
"""
|
37 |
+
mean = 0.0
|
38 |
+
count = 0
|
39 |
+
for param in net.parameters():
|
40 |
+
if param.grad is not None:
|
41 |
+
mean += torch.mean(torch.abs(param.grad.data))
|
42 |
+
count += 1
|
43 |
+
if count > 0:
|
44 |
+
mean = mean / count
|
45 |
+
print(name)
|
46 |
+
print(mean)
|
47 |
+
|
48 |
+
|
49 |
+
def save_image(image_numpy, image_path, aspect_ratio=1.0):
|
50 |
+
"""Save a numpy image to the disk
|
51 |
+
|
52 |
+
Parameters:
|
53 |
+
image_numpy (numpy array) -- input numpy array
|
54 |
+
image_path (str) -- the path of the image
|
55 |
+
"""
|
56 |
+
|
57 |
+
image_pil = Image.fromarray(image_numpy)
|
58 |
+
h, w, _ = image_numpy.shape
|
59 |
+
|
60 |
+
if aspect_ratio > 1.0:
|
61 |
+
image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
|
62 |
+
if aspect_ratio < 1.0:
|
63 |
+
image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
|
64 |
+
image_pil.save(image_path)
|
65 |
+
|
66 |
+
|
67 |
+
def print_numpy(x, val=True, shp=False):
|
68 |
+
"""Print the mean, min, max, median, std, and size of a numpy array
|
69 |
+
|
70 |
+
Parameters:
|
71 |
+
val (bool) -- if print the values of the numpy array
|
72 |
+
shp (bool) -- if print the shape of the numpy array
|
73 |
+
"""
|
74 |
+
x = x.astype(np.float64)
|
75 |
+
if shp:
|
76 |
+
print('shape,', x.shape)
|
77 |
+
if val:
|
78 |
+
x = x.flatten()
|
79 |
+
print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
|
80 |
+
np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
|
81 |
+
|
82 |
+
|
83 |
+
def mkdirs(paths):
|
84 |
+
"""create empty directories if they don't exist
|
85 |
+
|
86 |
+
Parameters:
|
87 |
+
paths (str list) -- a list of directory paths
|
88 |
+
"""
|
89 |
+
if isinstance(paths, list) and not isinstance(paths, str):
|
90 |
+
for path in paths:
|
91 |
+
mkdir(path)
|
92 |
+
else:
|
93 |
+
mkdir(paths)
|
94 |
+
|
95 |
+
|
96 |
+
def mkdir(path):
|
97 |
+
"""create a single empty directory if it didn't exist
|
98 |
+
|
99 |
+
Parameters:
|
100 |
+
path (str) -- a single directory path
|
101 |
+
"""
|
102 |
+
if not os.path.exists(path):
|
103 |
+
os.makedirs(path)
|
train/util/visualizer.py
ADDED
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import ntpath
|
5 |
+
import time
|
6 |
+
from . import util, html
|
7 |
+
from subprocess import Popen, PIPE
|
8 |
+
|
9 |
+
|
10 |
+
if sys.version_info[0] == 2:
|
11 |
+
VisdomExceptionBase = Exception
|
12 |
+
else:
|
13 |
+
VisdomExceptionBase = ConnectionError
|
14 |
+
|
15 |
+
|
16 |
+
def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
|
17 |
+
"""Save images to the disk.
|
18 |
+
|
19 |
+
Parameters:
|
20 |
+
webpage (the HTML class) -- the HTML webpage class that stores these imaegs (see html.py for more details)
|
21 |
+
visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs
|
22 |
+
image_path (str) -- the string is used to create image paths
|
23 |
+
aspect_ratio (float) -- the aspect ratio of saved images
|
24 |
+
width (int) -- the images will be resized to width x width
|
25 |
+
|
26 |
+
This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
|
27 |
+
"""
|
28 |
+
image_dir = webpage.get_image_dir()
|
29 |
+
short_path = ntpath.basename(image_path[0])
|
30 |
+
name = os.path.splitext(short_path)[0]
|
31 |
+
|
32 |
+
webpage.add_header(name)
|
33 |
+
ims, txts, links = [], [], []
|
34 |
+
|
35 |
+
for label, im_data in visuals.items():
|
36 |
+
im = util.tensor2im(im_data)
|
37 |
+
image_name = '%s_%s.png' % (name, label)
|
38 |
+
save_path = os.path.join(image_dir, image_name)
|
39 |
+
util.save_image(im, save_path, aspect_ratio=aspect_ratio)
|
40 |
+
ims.append(image_name)
|
41 |
+
txts.append(label)
|
42 |
+
links.append(image_name)
|
43 |
+
webpage.add_images(ims, txts, links, width=width)
|
44 |
+
|
45 |
+
|
46 |
+
class Visualizer:
|
47 |
+
"""This class includes several functions that can display/save images and print/save logging information.
|
48 |
+
|
49 |
+
It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating
|
50 |
+
HTML files with images.
|
51 |
+
"""
|
52 |
+
|
53 |
+
def __init__(self, opt):
|
54 |
+
"""Initialize the Visualizer class
|
55 |
+
|
56 |
+
Parameters:
|
57 |
+
opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
|
58 |
+
Step 1: Cache the training/test options
|
59 |
+
Step 2: connect to a visdom server
|
60 |
+
Step 3: create an HTML object for saveing HTML filters
|
61 |
+
Step 4: create a logging file to store training losses
|
62 |
+
"""
|
63 |
+
self.opt = opt # cache the option
|
64 |
+
self.display_id = opt.display_id
|
65 |
+
self.use_html = opt.isTrain and not opt.no_html
|
66 |
+
self.win_size = opt.display_winsize
|
67 |
+
self.name = opt.name
|
68 |
+
self.port = opt.display_port
|
69 |
+
self.saved = False
|
70 |
+
if self.display_id > 0: # connect to a visdom server given <display_port> and <display_server>
|
71 |
+
import visdom
|
72 |
+
self.ncols = opt.display_ncols
|
73 |
+
self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env)
|
74 |
+
if not self.vis.check_connection():
|
75 |
+
self.create_visdom_connections()
|
76 |
+
|
77 |
+
if self.use_html: # create an HTML object at <checkpoints_dir>/web/; images will be saved under
|
78 |
+
# <checkpoints_dir>/web/images/
|
79 |
+
self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
|
80 |
+
self.img_dir = os.path.join(self.web_dir, 'images')
|
81 |
+
print('create web directory %s...' % self.web_dir)
|
82 |
+
util.mkdirs([self.web_dir, self.img_dir])
|
83 |
+
# create a logging file to store training losses
|
84 |
+
self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
|
85 |
+
with open(self.log_name, "a") as log_file:
|
86 |
+
now = time.strftime("%c")
|
87 |
+
log_file.write('================ Training Loss (%s) ================\n' % now)
|
88 |
+
|
89 |
+
def reset(self):
|
90 |
+
"""Reset the self.saved status"""
|
91 |
+
self.saved = False
|
92 |
+
|
93 |
+
def create_visdom_connections(self):
|
94 |
+
"""If the program could not connect to Visdom server, this function will start a new server at port <
|
95 |
+
self.port > """
|
96 |
+
cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
|
97 |
+
print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
|
98 |
+
print('Command: %s' % cmd)
|
99 |
+
Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
|
100 |
+
|
101 |
+
def display_current_results(self, visuals, epoch, save_result):
|
102 |
+
"""Display current results on visdom; save current results to an HTML file.
|
103 |
+
|
104 |
+
Parameters:
|
105 |
+
visuals (OrderedDict) - - dictionary of images to display or save
|
106 |
+
epoch (int) - - the current epoch
|
107 |
+
save_result (bool) - - if save the current results to an HTML file
|
108 |
+
"""
|
109 |
+
if self.display_id > 0: # show images in the browser using visdom
|
110 |
+
ncols = self.ncols
|
111 |
+
if ncols > 0: # show all the images in one visdom panel
|
112 |
+
ncols = min(ncols, len(visuals))
|
113 |
+
h, w = next(iter(visuals.values())).shape[:2]
|
114 |
+
table_css = """<style>
|
115 |
+
table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
|
116 |
+
table td {width: % dpx; height: % dpx; padding: 4px; outline: 4px solid black}
|
117 |
+
</style>""" % (w, h) # create a table css
|
118 |
+
# create a table of images.
|
119 |
+
title = self.name
|
120 |
+
label_html = ''
|
121 |
+
label_html_row = ''
|
122 |
+
images = []
|
123 |
+
idx = 0
|
124 |
+
for label, image in visuals.items():
|
125 |
+
image_numpy = util.tensor2im(image)
|
126 |
+
label_html_row += '<td>%s</td>' % label
|
127 |
+
images.append(image_numpy.transpose([2, 0, 1]))
|
128 |
+
idx += 1
|
129 |
+
if idx % ncols == 0:
|
130 |
+
label_html += '<tr>%s</tr>' % label_html_row
|
131 |
+
label_html_row = ''
|
132 |
+
white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
|
133 |
+
while idx % ncols != 0:
|
134 |
+
images.append(white_image)
|
135 |
+
label_html_row += '<td></td>'
|
136 |
+
idx += 1
|
137 |
+
if label_html_row != '':
|
138 |
+
label_html += '<tr>%s</tr>' % label_html_row
|
139 |
+
try:
|
140 |
+
self.vis.images(images, nrow=ncols, win=self.display_id + 1,
|
141 |
+
padding=2, opts=dict(title=title + ' images'))
|
142 |
+
label_html = '<table>%s</table>' % label_html
|
143 |
+
self.vis.text(table_css + label_html, win=self.display_id + 2,
|
144 |
+
opts=dict(title=title + ' labels'))
|
145 |
+
except VisdomExceptionBase:
|
146 |
+
self.create_visdom_connections()
|
147 |
+
|
148 |
+
else: # show each image in a separate visdom panel;
|
149 |
+
idx = 1
|
150 |
+
try:
|
151 |
+
for label, image in visuals.items():
|
152 |
+
image_numpy = util.tensor2im(image)
|
153 |
+
self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
|
154 |
+
win=self.display_id + idx)
|
155 |
+
idx += 1
|
156 |
+
except VisdomExceptionBase:
|
157 |
+
self.create_visdom_connections()
|
158 |
+
|
159 |
+
if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved.
|
160 |
+
self.saved = True
|
161 |
+
# save images to the disk
|
162 |
+
for label, image in visuals.items():
|
163 |
+
image_numpy = util.tensor2im(image)
|
164 |
+
img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
|
165 |
+
util.save_image(image_numpy, img_path)
|
166 |
+
|
167 |
+
# update website
|
168 |
+
webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
|
169 |
+
for n in range(epoch, 0, -1):
|
170 |
+
webpage.add_header('epoch [%d]' % n)
|
171 |
+
ims, txts, links = [], [], []
|
172 |
+
|
173 |
+
for label, image_numpy in visuals.items():
|
174 |
+
image_numpy = util.tensor2im(image)
|
175 |
+
img_path = 'epoch%.3d_%s.png' % (n, label)
|
176 |
+
ims.append(img_path)
|
177 |
+
txts.append(label)
|
178 |
+
links.append(img_path)
|
179 |
+
webpage.add_images(ims, txts, links, width=self.win_size)
|
180 |
+
webpage.save()
|
181 |
+
|
182 |
+
def plot_current_losses(self, epoch, counter_ratio, losses):
|
183 |
+
"""display the current losses on visdom display: dictionary of error labels and values
|
184 |
+
|
185 |
+
Parameters:
|
186 |
+
epoch (int) -- current epoch
|
187 |
+
counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1
|
188 |
+
losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
|
189 |
+
"""
|
190 |
+
if not hasattr(self, 'plot_data'):
|
191 |
+
self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
|
192 |
+
self.plot_data['X'].append(epoch + counter_ratio)
|
193 |
+
self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
|
194 |
+
try:
|
195 |
+
self.vis.line(
|
196 |
+
X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
|
197 |
+
Y=np.array(self.plot_data['Y']),
|
198 |
+
opts={
|
199 |
+
'title': self.name + ' loss over time',
|
200 |
+
'legend': self.plot_data['legend'],
|
201 |
+
'xlabel': 'epoch',
|
202 |
+
'ylabel': 'loss'},
|
203 |
+
win=self.display_id)
|
204 |
+
except VisdomExceptionBase:
|
205 |
+
self.create_visdom_connections()
|
206 |
+
|
207 |
+
# losses: same format as |losses| of plot_current_losses
|
208 |
+
def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
|
209 |
+
"""print current losses on console; also save the losses to the disk
|
210 |
+
|
211 |
+
Parameters:
|
212 |
+
epoch (int) -- current epoch
|
213 |
+
iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
|
214 |
+
losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
|
215 |
+
t_comp (float) -- computational time per data point (normalized by batch_size)
|
216 |
+
t_data (float) -- data loading time per data point (normalized by batch_size)
|
217 |
+
"""
|
218 |
+
message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
|
219 |
+
for k, v in losses.items():
|
220 |
+
message += '%s: %.3f ' % (k, v)
|
221 |
+
|
222 |
+
print(message) # print the message
|
223 |
+
with open(self.log_name, "a") as log_file:
|
224 |
+
log_file.write('%s\n' % message) # save the message
|