Spaces:
Build error
Build error
File size: 38,603 Bytes
b65c5e3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 |
# Example: python -u main.py --divider 5.0 --weight_dim 256 --sample 5 --device 0 --num_layers 3 --num_writer 1 --lr 0.001 --VALIDATION 1 --TYPE_B 0 --TYPE_C 0
# (There is no --datadir option; training data is read from ./data/writers.)
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from tensorboardX import SummaryWriter
from PIL import Image, ImageDraw, ImageFont
from DataLoader import DataLoader
import pickle
from config.GlobalVariables import *
import os
import argparse
from SynthesisNetwork import SynthesisNetwork
# Per-level field layout of DataLoader.next_batch output: for each of the six
# fields the tensor constructor to use and whether to append a trailing
# size-1 dimension.
_LEVEL_FIELD_SPECS = (
    (torch.FloatTensor, False),  # stroke_in
    (torch.FloatTensor, False),  # stroke_out
    (torch.LongTensor, True),    # stroke_length
    (torch.FloatTensor, False),  # term
    (torch.LongTensor, False),   # char
    (torch.LongTensor, True),    # char_length
)

# Loss variants reported for each level, in the order the network packs them.
_SENTENCE_VARIANTS = ('ORIGINAL', 'TYPE_A', 'TYPE_B')
_WORD_VARIANTS = ('ORIGINAL', 'TYPE_A', 'TYPE_B', 'TYPE_C', 'TYPE_D')
_SEGMENT_VARIANTS = ('ORIGINAL', 'TYPE_A', 'TYPE_B')

# (level name, variants) in the order the levels appear in the network output.
_LEVELS = (
    ('sentence', _SENTENCE_VARIANTS),
    ('word', _WORD_VARIANTS),
    ('segment', _SEGMENT_VARIANTS),
)


def _to_device_batch(raw_batch, device):
    """Convert one ``DataLoader.next_batch`` result into tensors on ``device``.

    ``raw_batch`` must hold 18 entries: six fields each for the sentence, word
    and segment levels, every level in ``_LEVEL_FIELD_SPECS`` order. Sentence-
    and word-level entries are sequences of array-likes; segment-level entries
    are nested one level deeper. Returns an 18-element list in the same order,
    ready to pass to the network.

    Raises:
        ValueError: if ``raw_batch`` does not have exactly 18 entries.
    """
    raw_batch = list(raw_batch)
    fields_per_level = len(_LEVEL_FIELD_SPECS)
    if len(raw_batch) != 3 * fields_per_level:
        raise ValueError(f'expected {3 * fields_per_level} batch fields, got {len(raw_batch)}')

    def convert(array, make_tensor, add_dim):
        tensor = make_tensor(array).to(device)
        return tensor.unsqueeze(-1) if add_dim else tensor

    batch = []
    for index, entries in enumerate(raw_batch):
        make_tensor, add_dim = _LEVEL_FIELD_SPECS[index % fields_per_level]
        if index < 2 * fields_per_level:
            # Sentence and word levels: one tensor per entry.
            batch.append([convert(a, make_tensor, add_dim) for a in entries])
        else:
            # Segment level: a list of tensors per entry.
            batch.append([[convert(a, make_tensor, add_dim) for a in group] for group in entries])
    return batch


def _log_level_losses(all_writer, level_writer, level, losses, variants, variant_enabled, timestep):
    """Write one level's loss breakdown to TensorBoard.

    ``losses`` is laid out as ``[total, W_consistency, (termination, loc, touch)
    for each name in variants..., WC for each variant except ORIGINAL...]``.
    Variants whose flag in ``variant_enabled`` is False are skipped.

    Raises:
        ValueError: if ``losses`` does not match that layout's length.
    """
    losses = list(losses)
    n = len(variants)
    expected = 2 + 3 * n + (n - 1)
    if len(losses) != expected:
        raise ValueError(f'expected {expected} {level} losses, got {len(losses)}')
    total, w_consistency = losses[0], losses[1]
    triples = losses[2:2 + 3 * n]
    wc_losses = losses[2 + 3 * n:]

    all_writer.add_scalar(f'ALL/total_{level}_loss', total, timestep)
    level_writer.add_scalar('Loss/mean_W_consistency_loss', w_consistency, timestep)
    for i, name in enumerate(variants):
        if not variant_enabled[name]:
            continue
        termination, loc, touch = triples[3 * i:3 * i + 3]
        level_writer.add_scalar(f'Loss/mean_{name}_loss', termination + loc + touch, timestep)
        level_writer.add_scalar(f'Z_LOSS/mean_{name}_termination_loss', termination, timestep)
        level_writer.add_scalar(f'Loss_Loc/mean_{name}_loc_reconstruct_loss', loc, timestep)
        level_writer.add_scalar(f'Z_LOSS/mean_{name}_touch_reconstruct_loss', touch, timestep)
        if name != 'ORIGINAL':
            # ORIGINAL is first and has no WC loss, so variant i maps to wc_losses[i - 1].
            level_writer.add_scalar(f'Z_LOSS/mean_{name}_WC_reconstruct_loss', wc_losses[i - 1], timestep)


def _log_losses(all_writer, level_writers, res, variant_enabled, timestep):
    """Log the network's total loss and each enabled level's breakdown.

    ``res`` is the network output ``(total_loss, sentence_losses, word_losses,
    segment_losses)``. ``level_writers`` maps level name -> SummaryWriter and
    contains only the levels whose loss is enabled.
    """
    total_loss, sentence_losses, word_losses, segment_losses = res
    all_writer.add_scalar('ALL/total_loss', total_loss, timestep)
    for (level, variants), losses in zip(_LEVELS, (sentence_losses, word_losses, segment_losses)):
        level_writer = level_writers.get(level)
        if level_writer is not None:
            _log_level_losses(all_writer, level_writer, level, losses, variants, variant_enabled, timestep)


def _render_commands(commands):
    """Draw pen commands as 255-valued strokes on a zeroed 160x750 single-channel image.

    Each command is ``(dx, dy, t)``: the pen starts at (30, 100) and moves by
    ``(dx * 5, dy * 5)``; a line segment is drawn only when ``t == 0``.
    """
    image = Image.fromarray(np.zeros([160, 750]))
    draw = ImageDraw.Draw(image)
    x, y = 30, 100
    for dx, dy, t in commands:
        next_x, next_y = x + dx * 5, y + dy * 5
        if t == 0:
            draw.line((x, y, next_x, next_y), 255, 1)
        x, y = next_x, next_y
    return image


def _compose_samples(commands_list):
    """Stack the six rendered command sequences from ``net.sample`` into one 750x960 RGB image.

    Raises:
        ValueError: if ``commands_list`` does not hold exactly six sequences.
    """
    rows = list(commands_list)
    if len(rows) != 6:
        raise ValueError(f'expected 6 command sequences, got {len(rows)}')
    canvas = Image.new('RGB', (750, 960))
    for row, commands in enumerate(rows):
        canvas.paste(_render_commands(commands), (0, 160 * row))
    return canvas


def _latest_checkpoint_path(model_dir='./model'):
    """Return the path of the ``<timestep>.pt`` file with the largest timestep in ``model_dir``.

    Files whose stem is not an integer are ignored.

    Raises:
        FileNotFoundError: if ``model_dir`` is missing or holds no such checkpoint.
    """
    steps = [int(name[:-3]) for name in os.listdir(model_dir)
             if name.endswith('.pt') and name[:-3].isdigit()]
    if not steps:
        raise FileNotFoundError(f"no '<timestep>.pt' checkpoints found in {model_dir}")
    return f"{model_dir}/{max(steps)}.pt"


def main(params):
    """Train SynthesisNetwork indefinitely, logging to TensorBoard and checkpointing.

    ``params`` is the argparse namespace built in the ``__main__`` block. Integer
    flags equal to 1 enable the corresponding feature:

    - VALIDATION: run a (gradient-free) validation pass after every step.
    - CHECKPOINT: resume from the newest ``./model/<timestep>.pt``.
    - sentence_loss / word_loss / segment_loss, TYPE_A..TYPE_D, ORIGINAL, REC:
      forwarded to SynthesisNetwork and used to select what gets logged.

    The loop never returns on its own; all TensorBoard writers are closed when
    it is interrupted (e.g. KeyboardInterrupt).
    """
    cwd = os.path.basename(os.getcwd())
    divider = params.divider
    weight_dim = params.weight_dim
    num_samples = params.sample
    did = params.device
    num_layers = params.num_layers
    num_writer = params.num_writer
    lr = params.lr
    datadir = './data/writers'
    VALIDATION = params.VALIDATION == 1
    LOAD_FROM_CHECKPOINT = params.CHECKPOINT == 1

    writers = []  # every SummaryWriter created, so all can be closed on exit

    def make_writer(prefix):
        writer = SummaryWriter(logdir=f'./runs/{prefix}-{cwd}')
        writers.append(writer)
        return writer

    level_enabled = {level: getattr(params, f'{level}_loss') == 1 for level, _ in _LEVELS}
    train_level_writers = {}
    valid_level_writers = {}
    for level, _ in _LEVELS:
        if level_enabled[level]:
            train_level_writers[level] = make_writer(level)
            if VALIDATION:
                valid_level_writers[level] = make_writer(f'valid-{level}')

    variant_enabled = {name: getattr(params, name) == 1
                       for name in ('ORIGINAL', 'TYPE_A', 'TYPE_B', 'TYPE_C', 'TYPE_D')}
    REC = params.REC == 1

    grad_clip = 10.0
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cuda":
        torch.cuda.set_device(did)
    else:
        # CPU fallback: force a small fixed batch.
        num_writer = 1
        num_samples = 3
    writer_all = make_writer('all')
    valid_writer_all = make_writer('valid-all') if VALIDATION else None
    print(level_enabled['sentence'], level_enabled['word'], level_enabled['segment'])

    net = SynthesisNetwork(weight_dim=weight_dim, num_layers=num_layers,
                           sentence_loss=level_enabled['sentence'],
                           word_loss=level_enabled['word'],
                           segment_loss=level_enabled['segment'],
                           TYPE_A=variant_enabled['TYPE_A'], TYPE_B=variant_enabled['TYPE_B'],
                           TYPE_C=variant_enabled['TYPE_C'], TYPE_D=variant_enabled['TYPE_D'],
                           ORIGINAL=variant_enabled['ORIGINAL'], REC=REC)
    net.to(device)
    for param in net.parameters():
        nn.init.normal_(param, mean=0.0, std=0.075)
    dl = DataLoader(num_writer=num_writer, num_samples=num_samples, divider=divider, datadir=datadir)
    optimizer = optim.Adam(net.parameters(), lr=lr)
    batch_size = num_writer * num_samples
    step_size = int(10000 / batch_size)
    # NOTE(review): scheduler.step() is never called, so the learning rate stays
    # constant and the scheduler is not saved in checkpoints. Confirm whether
    # decay was intended.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.99)

    timestep = 0
    if LOAD_FROM_CHECKPOINT:
        checkpoint = torch.load(_latest_checkpoint_path(), map_location=torch.device('cpu'))
        net.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        timestep = checkpoint['timestep']
        print(f"Loaded from checkpoint {timestep}")

    try:
        while True:
            optimizer.zero_grad()
            timestep += batch_size
            train_batch = _to_device_batch(dl.next_batch(TYPE='TRAIN'), device)
            res = net(train_batch)
            total_loss = res[0]
            print("Step :", timestep, "\tLoss :", total_loss.item(), "\tlr :", optimizer.param_groups[0]['lr'])
            _log_losses(writer_all, train_level_writers, res, variant_enabled, timestep)

            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), grad_clip)
            # NOTE(review): this manual SGD update is applied *in addition to* the
            # Adam step below, so each parameter is updated twice per iteration.
            # Kept to preserve the original training dynamics; confirm it is intended.
            for p in net.parameters():
                if p.grad is not None:
                    p.data.add_(p.grad.data, alpha=-lr)
            optimizer.step()

            # Always true unless resumed from a checkpoint taken with a different batch size.
            if timestep % batch_size == 0:
                # net.sample takes the word- and segment-level inputs (fields 6..17).
                commands_list = net.sample(train_batch[6:])
                canvas = _compose_samples(commands_list)
                writer_all.add_image('Res/Results', np.asarray(canvas.convert("RGB")), timestep, dataformats='HWC')

            # The checkpoint records the most recent forward loss: validation when enabled.
            checkpoint_loss = total_loss
            if VALIDATION:
                valid_batch = _to_device_batch(dl.next_batch(TYPE='VALID'), device)
                # No backward pass follows, so skip building an autograd graph.
                with torch.no_grad():
                    valid_res = net(valid_batch)
                _log_losses(valid_writer_all, valid_level_writers, valid_res, variant_enabled, timestep)
                checkpoint_loss = valid_res[0]

            if timestep % (batch_size * 1000) == 0:
                os.makedirs('model', exist_ok=True)
                torch.save({'timestep': timestep,
                            'model_state_dict': net.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'loss': checkpoint_loss.item(),
                            }, f'model/{timestep}.pt')
    finally:
        for writer in writers:
            writer.close()
if __name__ == '__main__':
    # Command-line options as (flag, type, default), declared in this order so
    # that --help output matches. Integer options equal to 1 enable a feature.
    option_specs = (
        ('--divider', float, 5.0),
        ('--weight_dim', int, 256),
        ('--sample', int, 2),
        ('--device', int, 1),
        ('--num_layers', int, 3),
        ('--num_writer', int, 1),
        ('--lr', float, 0.001),
        ('--sentence_loss', int, 1),
        ('--word_loss', int, 1),
        ('--segment_loss', int, 1),
        ('--TYPE_A', int, 1),
        ('--TYPE_B', int, 1),
        ('--TYPE_C', int, 1),
        ('--TYPE_D', int, 1),
        ('--ORIGINAL', int, 1),
        ('--VALIDATION', int, 1),
        ('--no_char', int, 0),
        ('--REC', int, 1),
        ('--CHECKPOINT', int, 0),
    )
    arg_parser = argparse.ArgumentParser(description='Arguments for training the handwriting synthesis network.')
    for flag, value_type, default_value in option_specs:
        arg_parser.add_argument(flag, type=value_type, default=default_value)
    main(arg_parser.parse_args())
|