# python -u main.py --divider 5.0 --weight_dim 256 --sample 5 --device 0 --num_layers 3 --num_writer 1 --lr 0.001 --VALIDATION 1 --datadir 2 --TYPE_B 0 --TYPE_C 0
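#
# Training loop for the handwriting SynthesisNetwork: optimizes sentence-,
# word-, and segment-level losses, logs scalars and sampled stroke images to
# TensorBoard, and periodically checkpoints to ./model.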

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from tensorboardX import SummaryWriter
from PIL import Image, ImageDraw, ImageFont
from DataLoader import DataLoader
import pickle
from config.GlobalVariables import *
import os
import argparse
from SynthesisNetwork import SynthesisNetwork

def main(params):
    # Tag TensorBoard runs with the name of the current working directory.
    cwd = os.path.basename(os.getcwd())

    divider     = params.divider
    weight_dim  = params.weight_dim
    num_samples = params.sample
    did         = params.device
    num_layers  = params.num_layers
    num_writer  = params.num_writer
    lr          = params.lr
    no_char     = params.no_char  # parsed but unused in this script

    # The writer dataset location is fixed here; it is not read from the CLI.
    datadir = './data/writers'

    # Interpret the integer CLI flags as booleans (1 = enabled, 0 = disabled).
    VALIDATION = params.VALIDATION == 1
    LOAD_FROM_CHECKPOINT = params.CHECKPOINT == 1

    sentence_loss = params.sentence_loss == 1
    if sentence_loss:
        writer_sentence = SummaryWriter(logdir='./runs/sentence-' + cwd)
        if VALIDATION:
            valid_writer_sentence = SummaryWriter(logdir='./runs/valid-sentence-' + cwd)
    word_loss = params.word_loss == 1
    if word_loss:
        writer_word = SummaryWriter(logdir='./runs/word-' + cwd)
        if VALIDATION:
            valid_writer_word = SummaryWriter(logdir='./runs/valid-word-' + cwd)
    segment_loss = params.segment_loss == 1
    if segment_loss:
        writer_segment = SummaryWriter(logdir='./runs/segment-' + cwd)
        if VALIDATION:
            valid_writer_segment = SummaryWriter(logdir='./runs/valid-segment-' + cwd)

    # Pathway and objective toggles forwarded to SynthesisNetwork.
    TYPE_A = params.TYPE_A == 1
    TYPE_B = params.TYPE_B == 1
    TYPE_C = params.TYPE_C == 1
    TYPE_D = params.TYPE_D == 1
    ORIGINAL = params.ORIGINAL == 1
    REC = params.REC == 1


    timestep  = 0
    grad_clip = 10.0
    device    = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cuda":
        torch.cuda.set_device(did)
    else:
        # No GPU available: shrink the effective batch so CPU training stays tractable.
        num_writer  = 1
        num_samples = 3

    writer_all = SummaryWriter(logdir='./runs/all-' + cwd)
    if VALIDATION:
        valid_writer_all = SummaryWriter(logdir='./runs/valid-all-' + cwd)

    print("Losses enabled (sentence, word, segment):", sentence_loss, word_loss, segment_loss)
    net = SynthesisNetwork(weight_dim=weight_dim, num_layers=num_layers,
                           sentence_loss=sentence_loss, word_loss=word_loss, segment_loss=segment_loss,
                           TYPE_A=TYPE_A, TYPE_B=TYPE_B, TYPE_C=TYPE_C, TYPE_D=TYPE_D,
                           ORIGINAL=ORIGINAL, REC=REC)
    net.to(device)

    # Initialize all parameters from N(0, 0.075^2); these values may be
    # overwritten by the checkpoint load below.
    for param in net.parameters():
        nn.init.normal_(param, mean=0.0, std=0.075)

    dl = DataLoader(num_writer=num_writer, num_samples=num_samples, divider=divider, datadir=datadir)

    optimizer = optim.Adam(net.parameters(), lr=lr)
    # Decay the learning rate by 1% roughly every 10,000 training examples,
    # regardless of how many examples each optimizer step consumes.
    step_size = int(10000 / (num_writer * num_samples))
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.99)

    if LOAD_FROM_CHECKPOINT:
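        # Resume from the newest checkpoint in ./model (files named <timestep>.pt).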
        checkpoints = os.listdir("./model")
        max_timestep = max(int(filename[:-3]) for filename in checkpoints if filename.endswith(".pt"))
        checkpoint = torch.load(f"./model/{max_timestep}.pt", map_location=torch.device('cpu'))
        net.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        timestep = checkpoint['timestep']
        print(f"Resumed from checkpoint at timestep {timestep}")

    while True:
        optimizer.zero_grad()
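        # timestep counts training examples seen, not optimizer steps.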
        timestep += num_writer * num_samples

        [all_sentence_level_stroke_in, all_sentence_level_stroke_out, all_sentence_level_stroke_length, all_sentence_level_term,
        all_sentence_level_char, all_sentence_level_char_length, all_word_level_stroke_in, all_word_level_stroke_out,
        all_word_level_stroke_length, all_word_level_term, all_word_level_char, all_word_level_char_length,
        all_segment_level_stroke_in, all_segment_level_stroke_out, all_segment_level_stroke_length, all_segment_level_term,
        all_segment_level_char, all_segment_level_char_length] = dl.next_batch(TYPE='TRAIN')

        # Sequences vary in length, so each sample moves to the device as its
        # own tensor (a Python list per level) rather than one stacked batch.
        batch_sentence_level_stroke_in = [torch.FloatTensor(a).to(device) for a in all_sentence_level_stroke_in]
        batch_sentence_level_stroke_out = [torch.FloatTensor(a).to(device) for a in all_sentence_level_stroke_out]
        batch_sentence_level_stroke_length = [torch.LongTensor(a).to(device).unsqueeze(-1) for a in all_sentence_level_stroke_length]
        batch_sentence_level_term = [torch.FloatTensor(a).to(device) for a in all_sentence_level_term]
        batch_sentence_level_char = [torch.LongTensor(a).to(device) for a in all_sentence_level_char]
        batch_sentence_level_char_length = [torch.LongTensor(a).to(device).unsqueeze(-1) for a in all_sentence_level_char_length]
        batch_word_level_stroke_in = [torch.FloatTensor(a).to(device) for a in all_word_level_stroke_in]
        batch_word_level_stroke_out = [torch.FloatTensor(a).to(device) for a in all_word_level_stroke_out]
        batch_word_level_stroke_length = [torch.LongTensor(a).to(device).unsqueeze(-1) for a in all_word_level_stroke_length]
        batch_word_level_term = [torch.FloatTensor(a).to(device) for a in all_word_level_term]
        batch_word_level_char = [torch.LongTensor(a).to(device) for a in all_word_level_char]
        batch_word_level_char_length = [torch.LongTensor(a).to(device).unsqueeze(-1) for a in all_word_level_char_length]
        batch_segment_level_stroke_in = [[torch.FloatTensor(a).to(device) for a in b] for b in all_segment_level_stroke_in]
        batch_segment_level_stroke_out = [[torch.FloatTensor(a).to(device) for a in b] for b in all_segment_level_stroke_out]
        batch_segment_level_stroke_length = [[torch.LongTensor(a).to(device).unsqueeze(-1) for a in b] for b in all_segment_level_stroke_length]
        batch_segment_level_term = [[torch.FloatTensor(a).to(device) for a in b] for b in all_segment_level_term]
        batch_segment_level_char = [[torch.LongTensor(a).to(device) for a in b] for b in all_segment_level_char]
        batch_segment_level_char_length = [[torch.LongTensor(a).to(device).unsqueeze(-1) for a in b] for b in all_segment_level_char_length]

        # One forward pass returns the combined loss plus per-level loss breakdowns.
        res = net([batch_sentence_level_stroke_in, batch_sentence_level_stroke_out, batch_sentence_level_stroke_length,
                batch_sentence_level_term, batch_sentence_level_char, batch_sentence_level_char_length,
                batch_word_level_stroke_in, batch_word_level_stroke_out, batch_word_level_stroke_length,
                batch_word_level_term, batch_word_level_char, batch_word_level_char_length, batch_segment_level_stroke_in,
                batch_segment_level_stroke_out, batch_segment_level_stroke_length, batch_segment_level_term,
                batch_segment_level_char, batch_segment_level_char_length])

        total_loss, sentence_losses, word_losses, segment_losses = res

        print ("Step :", timestep, "\tLoss :", total_loss.item(), "\tlr :", optimizer.param_groups[0]['lr'])

        writer_all.add_scalar('ALL/total_loss', total_loss, timestep)

        if sentence_loss:
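            # Sentence-level losses, in the order emitted by SynthesisNetwork.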
            [total_sentence_loss, mean_sentence_W_consistency_loss,
             mean_ORIGINAL_sentence_termination_loss, mean_ORIGINAL_sentence_loc_reconstruct_loss,
             mean_ORIGINAL_sentence_touch_reconstruct_loss,
             mean_TYPE_A_sentence_termination_loss, mean_TYPE_A_sentence_loc_reconstruct_loss,
             mean_TYPE_A_sentence_touch_reconstruct_loss,
             mean_TYPE_B_sentence_termination_loss, mean_TYPE_B_sentence_loc_reconstruct_loss,
             mean_TYPE_B_sentence_touch_reconstruct_loss,
             mean_TYPE_A_sentence_WC_reconstruct_loss, mean_TYPE_B_sentence_WC_reconstruct_loss] = sentence_losses

            writer_all.add_scalar('ALL/total_sentence_loss', total_sentence_loss, timestep)
            writer_sentence.add_scalar('Loss/mean_W_consistency_loss', mean_sentence_W_consistency_loss, timestep)
            if ORIGINAL:
                writer_sentence.add_scalar('Loss/mean_ORIGINAL_loss', mean_ORIGINAL_sentence_termination_loss + mean_ORIGINAL_sentence_loc_reconstruct_loss + mean_ORIGINAL_sentence_touch_reconstruct_loss, timestep)
                writer_sentence.add_scalar('Z_LOSS/mean_ORIGINAL_termination_loss', mean_ORIGINAL_sentence_termination_loss, timestep)
                writer_sentence.add_scalar('Loss_Loc/mean_ORIGINAL_loc_reconstruct_loss', mean_ORIGINAL_sentence_loc_reconstruct_loss, timestep)
                writer_sentence.add_scalar('Z_LOSS/mean_ORIGINAL_touch_reconstruct_loss', mean_ORIGINAL_sentence_touch_reconstruct_loss, timestep)
            if TYPE_A:
                writer_sentence.add_scalar('Loss/mean_TYPE_A_loss', mean_TYPE_A_sentence_termination_loss + mean_TYPE_A_sentence_loc_reconstruct_loss + mean_TYPE_A_sentence_touch_reconstruct_loss, timestep)
                writer_sentence.add_scalar('Z_LOSS/mean_TYPE_A_termination_loss', mean_TYPE_A_sentence_termination_loss, timestep)
                writer_sentence.add_scalar('Loss_Loc/mean_TYPE_A_loc_reconstruct_loss', mean_TYPE_A_sentence_loc_reconstruct_loss, timestep)
                writer_sentence.add_scalar('Z_LOSS/mean_TYPE_A_touch_reconstruct_loss', mean_TYPE_A_sentence_touch_reconstruct_loss, timestep)
                writer_sentence.add_scalar('Z_LOSS/mean_TYPE_A_WC_reconstruct_loss', mean_TYPE_A_sentence_WC_reconstruct_loss, timestep)
            if TYPE_B:
                writer_sentence.add_scalar('Loss/mean_TYPE_B_loss', mean_TYPE_B_sentence_termination_loss + mean_TYPE_B_sentence_loc_reconstruct_loss + mean_TYPE_B_sentence_touch_reconstruct_loss, timestep)
                writer_sentence.add_scalar('Z_LOSS/mean_TYPE_B_termination_loss', mean_TYPE_B_sentence_termination_loss, timestep)
                writer_sentence.add_scalar('Loss_Loc/mean_TYPE_B_loc_reconstruct_loss', mean_TYPE_B_sentence_loc_reconstruct_loss, timestep)
                writer_sentence.add_scalar('Z_LOSS/mean_TYPE_B_touch_reconstruct_loss', mean_TYPE_B_sentence_touch_reconstruct_loss, timestep)
                writer_sentence.add_scalar('Z_LOSS/mean_TYPE_B_WC_reconstruct_loss', mean_TYPE_B_sentence_WC_reconstruct_loss, timestep)

        if word_loss:
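            # Word-level losses; only this level includes the TYPE_C and TYPE_D terms.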
            [total_word_loss, mean_word_W_consistency_loss,
             mean_ORIGINAL_word_termination_loss, mean_ORIGINAL_word_loc_reconstruct_loss,
             mean_ORIGINAL_word_touch_reconstruct_loss,
             mean_TYPE_A_word_termination_loss, mean_TYPE_A_word_loc_reconstruct_loss,
             mean_TYPE_A_word_touch_reconstruct_loss,
             mean_TYPE_B_word_termination_loss, mean_TYPE_B_word_loc_reconstruct_loss,
             mean_TYPE_B_word_touch_reconstruct_loss,
             mean_TYPE_C_word_termination_loss, mean_TYPE_C_word_loc_reconstruct_loss,
             mean_TYPE_C_word_touch_reconstruct_loss,
             mean_TYPE_D_word_termination_loss, mean_TYPE_D_word_loc_reconstruct_loss,
             mean_TYPE_D_word_touch_reconstruct_loss,
             mean_TYPE_A_word_WC_reconstruct_loss, mean_TYPE_B_word_WC_reconstruct_loss,
             mean_TYPE_C_word_WC_reconstruct_loss, mean_TYPE_D_word_WC_reconstruct_loss] = word_losses
            writer_all.add_scalar('ALL/total_word_loss', total_word_loss, timestep)
            writer_word.add_scalar('Loss/mean_W_consistency_loss', mean_word_W_consistency_loss, timestep)

            if ORIGINAL:
                writer_word.add_scalar('Loss/mean_ORIGINAL_loss', mean_ORIGINAL_word_termination_loss + mean_ORIGINAL_word_loc_reconstruct_loss + mean_ORIGINAL_word_touch_reconstruct_loss, timestep)
                writer_word.add_scalar('Z_LOSS/mean_ORIGINAL_termination_loss', mean_ORIGINAL_word_termination_loss, timestep)
                writer_word.add_scalar('Loss_Loc/mean_ORIGINAL_loc_reconstruct_loss', mean_ORIGINAL_word_loc_reconstruct_loss, timestep)
                writer_word.add_scalar('Z_LOSS/mean_ORIGINAL_touch_reconstruct_loss', mean_ORIGINAL_word_touch_reconstruct_loss, timestep)
            if TYPE_A:
                writer_word.add_scalar('Loss/mean_TYPE_A_loss', mean_TYPE_A_word_termination_loss + mean_TYPE_A_word_loc_reconstruct_loss + mean_TYPE_A_word_touch_reconstruct_loss, timestep)
                writer_word.add_scalar('Z_LOSS/mean_TYPE_A_termination_loss', mean_TYPE_A_word_termination_loss, timestep)
                writer_word.add_scalar('Loss_Loc/mean_TYPE_A_loc_reconstruct_loss', mean_TYPE_A_word_loc_reconstruct_loss, timestep)
                writer_word.add_scalar('Z_LOSS/mean_TYPE_A_touch_reconstruct_loss', mean_TYPE_A_word_touch_reconstruct_loss, timestep)
                writer_word.add_scalar('Z_LOSS/mean_TYPE_A_WC_reconstruct_loss', mean_TYPE_A_word_WC_reconstruct_loss, timestep)
            if TYPE_B:
                writer_word.add_scalar('Loss/mean_TYPE_B_loss', mean_TYPE_B_word_termination_loss + mean_TYPE_B_word_loc_reconstruct_loss + mean_TYPE_B_word_touch_reconstruct_loss, timestep)
                writer_word.add_scalar('Z_LOSS/mean_TYPE_B_termination_loss', mean_TYPE_B_word_termination_loss, timestep)
                writer_word.add_scalar('Loss_Loc/mean_TYPE_B_loc_reconstruct_loss', mean_TYPE_B_word_loc_reconstruct_loss, timestep)
                writer_word.add_scalar('Z_LOSS/mean_TYPE_B_touch_reconstruct_loss', mean_TYPE_B_word_touch_reconstruct_loss, timestep)
                writer_word.add_scalar('Z_LOSS/mean_TYPE_B_WC_reconstruct_loss', mean_TYPE_B_word_WC_reconstruct_loss, timestep)
            if TYPE_C:
                writer_word.add_scalar('Loss/mean_TYPE_C_loss', mean_TYPE_C_word_termination_loss + mean_TYPE_C_word_loc_reconstruct_loss + mean_TYPE_C_word_touch_reconstruct_loss, timestep)
                writer_word.add_scalar('Z_LOSS/mean_TYPE_C_termination_loss', mean_TYPE_C_word_termination_loss, timestep)
                writer_word.add_scalar('Loss_Loc/mean_TYPE_C_loc_reconstruct_loss', mean_TYPE_C_word_loc_reconstruct_loss, timestep)
                writer_word.add_scalar('Z_LOSS/mean_TYPE_C_touch_reconstruct_loss', mean_TYPE_C_word_touch_reconstruct_loss, timestep)
                writer_word.add_scalar('Z_LOSS/mean_TYPE_C_WC_reconstruct_loss', mean_TYPE_C_word_WC_reconstruct_loss, timestep)
            if TYPE_D:
                writer_word.add_scalar('Loss/mean_TYPE_D_loss', mean_TYPE_D_word_termination_loss + mean_TYPE_D_word_loc_reconstruct_loss + mean_TYPE_D_word_touch_reconstruct_loss, timestep)
                writer_word.add_scalar('Z_LOSS/mean_TYPE_D_termination_loss', mean_TYPE_D_word_termination_loss, timestep)
                writer_word.add_scalar('Loss_Loc/mean_TYPE_D_loc_reconstruct_loss', mean_TYPE_D_word_loc_reconstruct_loss, timestep)
                writer_word.add_scalar('Z_LOSS/mean_TYPE_D_touch_reconstruct_loss', mean_TYPE_D_word_touch_reconstruct_loss, timestep)
                writer_word.add_scalar('Z_LOSS/mean_TYPE_D_WC_reconstruct_loss', mean_TYPE_D_word_WC_reconstruct_loss, timestep)

        if segment_loss:
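            # Segment-level losses (no TYPE_C/TYPE_D terms at this level).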
            [total_segment_loss, mean_segment_W_consistency_loss,
             mean_ORIGINAL_segment_termination_loss, mean_ORIGINAL_segment_loc_reconstruct_loss,
             mean_ORIGINAL_segment_touch_reconstruct_loss,
             mean_TYPE_A_segment_termination_loss, mean_TYPE_A_segment_loc_reconstruct_loss,
             mean_TYPE_A_segment_touch_reconstruct_loss,
             mean_TYPE_B_segment_termination_loss, mean_TYPE_B_segment_loc_reconstruct_loss,
             mean_TYPE_B_segment_touch_reconstruct_loss,
             mean_TYPE_A_segment_WC_reconstruct_loss, mean_TYPE_B_segment_WC_reconstruct_loss] = segment_losses
            writer_all.add_scalar('ALL/total_segment_loss', total_segment_loss, timestep)
            writer_segment.add_scalar('Loss/mean_W_consistency_loss', mean_segment_W_consistency_loss, timestep)
            if ORIGINAL:
                writer_segment.add_scalar('Loss/mean_ORIGINAL_loss', mean_ORIGINAL_segment_termination_loss + mean_ORIGINAL_segment_loc_reconstruct_loss + mean_ORIGINAL_segment_touch_reconstruct_loss, timestep)
                writer_segment.add_scalar('Z_LOSS/mean_ORIGINAL_termination_loss', mean_ORIGINAL_segment_termination_loss, timestep)
                writer_segment.add_scalar('Loss_Loc/mean_ORIGINAL_loc_reconstruct_loss', mean_ORIGINAL_segment_loc_reconstruct_loss, timestep)
                writer_segment.add_scalar('Z_LOSS/mean_ORIGINAL_touch_reconstruct_loss', mean_ORIGINAL_segment_touch_reconstruct_loss, timestep)
            if TYPE_A:
                writer_segment.add_scalar('Loss/mean_TYPE_A_loss', mean_TYPE_A_segment_termination_loss + mean_TYPE_A_segment_loc_reconstruct_loss + mean_TYPE_A_segment_touch_reconstruct_loss, timestep)
                writer_segment.add_scalar('Z_LOSS/mean_TYPE_A_termination_loss', mean_TYPE_A_segment_termination_loss, timestep)
                writer_segment.add_scalar('Loss_Loc/mean_TYPE_A_loc_reconstruct_loss', mean_TYPE_A_segment_loc_reconstruct_loss, timestep)
                writer_segment.add_scalar('Z_LOSS/mean_TYPE_A_touch_reconstruct_loss', mean_TYPE_A_segment_touch_reconstruct_loss, timestep)
                writer_segment.add_scalar('Z_LOSS/mean_TYPE_A_WC_reconstruct_loss', mean_TYPE_A_segment_WC_reconstruct_loss, timestep)
            if TYPE_B:
                writer_segment.add_scalar('Loss/mean_TYPE_B_loss', mean_TYPE_B_segment_termination_loss + mean_TYPE_B_segment_loc_reconstruct_loss + mean_TYPE_B_segment_touch_reconstruct_loss, timestep)
                writer_segment.add_scalar('Z_LOSS/mean_TYPE_B_termination_loss', mean_TYPE_B_segment_termination_loss, timestep)
                writer_segment.add_scalar('Loss_Loc/mean_TYPE_B_loc_reconstruct_loss', mean_TYPE_B_segment_loc_reconstruct_loss, timestep)
                writer_segment.add_scalar('Z_LOSS/mean_TYPE_B_touch_reconstruct_loss', mean_TYPE_B_segment_touch_reconstruct_loss, timestep)
                writer_segment.add_scalar('Z_LOSS/mean_TYPE_B_WC_reconstruct_loss', mean_TYPE_B_segment_WC_reconstruct_loss, timestep)

        total_loss.backward()

        torch.nn.utils.clip_grad_norm_(net.parameters(), grad_clip)
        # Apply the clipped gradients exactly once, via Adam; a manual
        # p.data.add_(p.grad.data, alpha=-lr) update here would double-apply
        # every gradient on top of optimizer.step().
        optimizer.step()
        scheduler.step()  # advance the StepLR learning-rate decay

        if timestep % (num_writer * num_samples) == 0:
            # net.sample returns six stroke command sequences: t, o (ORIGINAL),
            # and a..d (TYPE_A..TYPE_D), each a list of (dx, dy, touch) triples.
            commands_list = net.sample([batch_word_level_stroke_in, batch_word_level_stroke_out, batch_word_level_stroke_length,
                                        batch_word_level_term, batch_word_level_char, batch_word_level_char_length,
                                        batch_segment_level_stroke_in, batch_segment_level_stroke_out, batch_segment_level_stroke_length,
                                        batch_segment_level_term, batch_segment_level_char, batch_segment_level_char_length])

            def render_commands(commands):
                # Trace the pen offsets onto a blank canvas, drawing a line
                # segment whenever the pen is down (t == 0).
                im = Image.fromarray(np.zeros([160, 750]))
                dr = ImageDraw.Draw(im)
                px, py = 30, 100
                for dx, dy, t in commands:
                    x = px + dx * 5
                    y = py + dy * 5
                    if t == 0:
                        dr.line((px, py, x, y), 255, 1)
                    px, py = x, y
                return im

            # Stack the six renderings vertically and log them as one image.
            dst = Image.new('RGB', (750, 960))
            for row, commands in enumerate(commands_list):
                dst.paste(render_commands(commands), (0, 160 * row))
            writer_all.add_image('Res/Results', np.asarray(dst.convert("RGB")), timestep, dataformats='HWC')

        if VALIDATION:
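            # Score one validation batch with the same pipeline and log it to
            # the valid-* writers. Note this forward pass is not wrapped in
            # torch.no_grad(), so it still builds an autograd graph.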
            [all_sentence_level_stroke_in, all_sentence_level_stroke_out, all_sentence_level_stroke_length, all_sentence_level_term,
            all_sentence_level_char, all_sentence_level_char_length, all_word_level_stroke_in, all_word_level_stroke_out,
            all_word_level_stroke_length, all_word_level_term, all_word_level_char, all_word_level_char_length,
            all_segment_level_stroke_in, all_segment_level_stroke_out, all_segment_level_stroke_length, all_segment_level_term,
            all_segment_level_char, all_segment_level_char_length] = dl.next_batch(TYPE='VALID')

            batch_sentence_level_stroke_in = [torch.FloatTensor(a).to(device) for a in all_sentence_level_stroke_in]
            batch_sentence_level_stroke_out = [torch.FloatTensor(a).to(device) for a in all_sentence_level_stroke_out]
            batch_sentence_level_stroke_length = [torch.LongTensor(a).to(device).unsqueeze(-1) for a in all_sentence_level_stroke_length]
            batch_sentence_level_term = [torch.FloatTensor(a).to(device) for a in all_sentence_level_term]
            batch_sentence_level_char = [torch.LongTensor(a).to(device) for a in all_sentence_level_char]
            batch_sentence_level_char_length = [torch.LongTensor(a).to(device).unsqueeze(-1) for a in all_sentence_level_char_length]
            batch_word_level_stroke_in = [torch.FloatTensor(a).to(device) for a in all_word_level_stroke_in]
            batch_word_level_stroke_out = [torch.FloatTensor(a).to(device) for a in all_word_level_stroke_out]
            batch_word_level_stroke_length = [torch.LongTensor(a).to(device).unsqueeze(-1) for a in all_word_level_stroke_length]
            batch_word_level_term = [torch.FloatTensor(a).to(device) for a in all_word_level_term]
            batch_word_level_char = [torch.LongTensor(a).to(device) for a in all_word_level_char]
            batch_word_level_char_length = [torch.LongTensor(a).to(device).unsqueeze(-1) for a in all_word_level_char_length]
            batch_segment_level_stroke_in = [[torch.FloatTensor(a).to(device) for a in b] for b in all_segment_level_stroke_in]
            batch_segment_level_stroke_out = [[torch.FloatTensor(a).to(device) for a in b] for b in all_segment_level_stroke_out]
            batch_segment_level_stroke_length = [[torch.LongTensor(a).to(device).unsqueeze(-1) for a in b] for b in all_segment_level_stroke_length]
            batch_segment_level_term = [[torch.FloatTensor(a).to(device) for a in b] for b in all_segment_level_term]
            batch_segment_level_char = [[torch.LongTensor(a).to(device) for a in b] for b in all_segment_level_char]
            batch_segment_level_char_length = [[torch.LongTensor(a).to(device).unsqueeze(-1) for a in b] for b in all_segment_level_char_length]

            res = net([batch_sentence_level_stroke_in, batch_sentence_level_stroke_out, batch_sentence_level_stroke_length,
                    batch_sentence_level_term, batch_sentence_level_char, batch_sentence_level_char_length,
                    batch_word_level_stroke_in, batch_word_level_stroke_out, batch_word_level_stroke_length,
                    batch_word_level_term, batch_word_level_char, batch_word_level_char_length, batch_segment_level_stroke_in,
                    batch_segment_level_stroke_out, batch_segment_level_stroke_length, batch_segment_level_term,
                    batch_segment_level_char, batch_segment_level_char_length])

            total_loss, sentence_losses, word_losses, segment_losses = res

            valid_writer_all.add_scalar('ALL/total_loss', total_loss, timestep)

            if sentence_loss:
                [total_sentence_loss, mean_sentence_W_consistency_loss,
                 mean_ORIGINAL_sentence_termination_loss, mean_ORIGINAL_sentence_loc_reconstruct_loss,
                 mean_ORIGINAL_sentence_touch_reconstruct_loss,
                 mean_TYPE_A_sentence_termination_loss, mean_TYPE_A_sentence_loc_reconstruct_loss,
                 mean_TYPE_A_sentence_touch_reconstruct_loss,
                 mean_TYPE_B_sentence_termination_loss, mean_TYPE_B_sentence_loc_reconstruct_loss,
                 mean_TYPE_B_sentence_touch_reconstruct_loss,
                 mean_TYPE_A_sentence_WC_reconstruct_loss, mean_TYPE_B_sentence_WC_reconstruct_loss] = sentence_losses

                valid_writer_all.add_scalar('ALL/total_sentence_loss', total_sentence_loss, timestep)
                valid_writer_sentence.add_scalar('Loss/mean_W_consistency_loss', mean_sentence_W_consistency_loss, timestep)
                if ORIGINAL:
                    valid_writer_sentence.add_scalar('Loss/mean_ORIGINAL_loss', mean_ORIGINAL_sentence_termination_loss + mean_ORIGINAL_sentence_loc_reconstruct_loss + mean_ORIGINAL_sentence_touch_reconstruct_loss, timestep)
                    valid_writer_sentence.add_scalar('Z_LOSS/mean_ORIGINAL_termination_loss', mean_ORIGINAL_sentence_termination_loss, timestep)
                    valid_writer_sentence.add_scalar('Loss_Loc/mean_ORIGINAL_loc_reconstruct_loss', mean_ORIGINAL_sentence_loc_reconstruct_loss, timestep)
                    valid_writer_sentence.add_scalar('Z_LOSS/mean_ORIGINAL_touch_reconstruct_loss', mean_ORIGINAL_sentence_touch_reconstruct_loss, timestep)
                if TYPE_A:
                    valid_writer_sentence.add_scalar('Loss/mean_TYPE_A_loss', mean_TYPE_A_sentence_termination_loss + mean_TYPE_A_sentence_loc_reconstruct_loss + mean_TYPE_A_sentence_touch_reconstruct_loss, timestep)
                    valid_writer_sentence.add_scalar('Z_LOSS/mean_TYPE_A_termination_loss', mean_TYPE_A_sentence_termination_loss, timestep)
                    valid_writer_sentence.add_scalar('Loss_Loc/mean_TYPE_A_loc_reconstruct_loss', mean_TYPE_A_sentence_loc_reconstruct_loss, timestep)
                    valid_writer_sentence.add_scalar('Z_LOSS/mean_TYPE_A_touch_reconstruct_loss', mean_TYPE_A_sentence_touch_reconstruct_loss, timestep)
                    valid_writer_sentence.add_scalar('Z_LOSS/mean_TYPE_A_WC_reconstruct_loss', mean_TYPE_A_sentence_WC_reconstruct_loss, timestep)
                if TYPE_B:
                    valid_writer_sentence.add_scalar('Loss/mean_TYPE_B_loss', mean_TYPE_B_sentence_termination_loss + mean_TYPE_B_sentence_loc_reconstruct_loss + mean_TYPE_B_sentence_touch_reconstruct_loss, timestep)
                    valid_writer_sentence.add_scalar('Z_LOSS/mean_TYPE_B_termination_loss', mean_TYPE_B_sentence_termination_loss, timestep)
                    valid_writer_sentence.add_scalar('Loss_Loc/mean_TYPE_B_loc_reconstruct_loss', mean_TYPE_B_sentence_loc_reconstruct_loss, timestep)
                    valid_writer_sentence.add_scalar('Z_LOSS/mean_TYPE_B_touch_reconstruct_loss', mean_TYPE_B_sentence_touch_reconstruct_loss, timestep)
                    valid_writer_sentence.add_scalar('Z_LOSS/mean_TYPE_B_WC_reconstruct_loss', mean_TYPE_B_sentence_WC_reconstruct_loss, timestep)

            if word_loss:
                [total_word_loss, mean_word_W_consistency_loss,
                 mean_ORIGINAL_word_termination_loss, mean_ORIGINAL_word_loc_reconstruct_loss,
                 mean_ORIGINAL_word_touch_reconstruct_loss,
                 mean_TYPE_A_word_termination_loss, mean_TYPE_A_word_loc_reconstruct_loss,
                 mean_TYPE_A_word_touch_reconstruct_loss,
                 mean_TYPE_B_word_termination_loss, mean_TYPE_B_word_loc_reconstruct_loss,
                 mean_TYPE_B_word_touch_reconstruct_loss,
                 mean_TYPE_C_word_termination_loss, mean_TYPE_C_word_loc_reconstruct_loss,
                 mean_TYPE_C_word_touch_reconstruct_loss,
                 mean_TYPE_D_word_termination_loss, mean_TYPE_D_word_loc_reconstruct_loss,
                 mean_TYPE_D_word_touch_reconstruct_loss,
                 mean_TYPE_A_word_WC_reconstruct_loss, mean_TYPE_B_word_WC_reconstruct_loss,
                 mean_TYPE_C_word_WC_reconstruct_loss, mean_TYPE_D_word_WC_reconstruct_loss] = word_losses
                valid_writer_all.add_scalar('ALL/total_word_loss', total_word_loss, timestep)
                valid_writer_word.add_scalar('Loss/mean_W_consistency_loss', mean_word_W_consistency_loss, timestep)

                if ORIGINAL:
                    valid_writer_word.add_scalar('Loss/mean_ORIGINAL_loss', mean_ORIGINAL_word_termination_loss + mean_ORIGINAL_word_loc_reconstruct_loss + mean_ORIGINAL_word_touch_reconstruct_loss, timestep)
                    valid_writer_word.add_scalar('Z_LOSS/mean_ORIGINAL_termination_loss', mean_ORIGINAL_word_termination_loss, timestep)
                    valid_writer_word.add_scalar('Loss_Loc/mean_ORIGINAL_loc_reconstruct_loss', mean_ORIGINAL_word_loc_reconstruct_loss, timestep)
                    valid_writer_word.add_scalar('Z_LOSS/mean_ORIGINAL_touch_reconstruct_loss', mean_ORIGINAL_word_touch_reconstruct_loss, timestep)
                if TYPE_A:
                    valid_writer_word.add_scalar('Loss/mean_TYPE_A_loss', mean_TYPE_A_word_termination_loss + mean_TYPE_A_word_loc_reconstruct_loss + mean_TYPE_A_word_touch_reconstruct_loss, timestep)
                    valid_writer_word.add_scalar('Z_LOSS/mean_TYPE_A_termination_loss', mean_TYPE_A_word_termination_loss, timestep)
                    valid_writer_word.add_scalar('Loss_Loc/mean_TYPE_A_loc_reconstruct_loss', mean_TYPE_A_word_loc_reconstruct_loss, timestep)
                    valid_writer_word.add_scalar('Z_LOSS/mean_TYPE_A_touch_reconstruct_loss', mean_TYPE_A_word_touch_reconstruct_loss, timestep)
                    valid_writer_word.add_scalar('Z_LOSS/mean_TYPE_A_WC_reconstruct_loss', mean_TYPE_A_word_WC_reconstruct_loss, timestep)
                if TYPE_B:
                    valid_writer_word.add_scalar('Loss/mean_TYPE_B_loss', mean_TYPE_B_word_termination_loss + mean_TYPE_B_word_loc_reconstruct_loss + mean_TYPE_B_word_touch_reconstruct_loss, timestep)
                    valid_writer_word.add_scalar('Z_LOSS/mean_TYPE_B_termination_loss', mean_TYPE_B_word_termination_loss, timestep)
                    valid_writer_word.add_scalar('Loss_Loc/mean_TYPE_B_loc_reconstruct_loss', mean_TYPE_B_word_loc_reconstruct_loss, timestep)
                    valid_writer_word.add_scalar('Z_LOSS/mean_TYPE_B_touch_reconstruct_loss', mean_TYPE_B_word_touch_reconstruct_loss, timestep)
                    valid_writer_word.add_scalar('Z_LOSS/mean_TYPE_B_WC_reconstruct_loss', mean_TYPE_B_word_WC_reconstruct_loss, timestep)
                if TYPE_C:
                    valid_writer_word.add_scalar('Loss/mean_TYPE_C_loss', mean_TYPE_C_word_termination_loss + mean_TYPE_C_word_loc_reconstruct_loss + mean_TYPE_C_word_touch_reconstruct_loss, timestep)
                    valid_writer_word.add_scalar('Z_LOSS/mean_TYPE_C_termination_loss', mean_TYPE_C_word_termination_loss, timestep)
                    valid_writer_word.add_scalar('Loss_Loc/mean_TYPE_C_loc_reconstruct_loss', mean_TYPE_C_word_loc_reconstruct_loss, timestep)
                    valid_writer_word.add_scalar('Z_LOSS/mean_TYPE_C_touch_reconstruct_loss', mean_TYPE_C_word_touch_reconstruct_loss, timestep)
                    valid_writer_word.add_scalar('Z_LOSS/mean_TYPE_C_WC_reconstruct_loss', mean_TYPE_C_word_WC_reconstruct_loss, timestep)
                if TYPE_D:
                    valid_writer_word.add_scalar('Loss/mean_TYPE_D_loss', mean_TYPE_D_word_termination_loss + mean_TYPE_D_word_loc_reconstruct_loss + mean_TYPE_D_word_touch_reconstruct_loss, timestep)
                    valid_writer_word.add_scalar('Z_LOSS/mean_TYPE_D_termination_loss', mean_TYPE_D_word_termination_loss, timestep)
                    valid_writer_word.add_scalar('Loss_Loc/mean_TYPE_D_loc_reconstruct_loss', mean_TYPE_D_word_loc_reconstruct_loss, timestep)
                    valid_writer_word.add_scalar('Z_LOSS/mean_TYPE_D_touch_reconstruct_loss', mean_TYPE_D_word_touch_reconstruct_loss, timestep)
                    valid_writer_word.add_scalar('Z_LOSS/mean_TYPE_D_WC_reconstruct_loss', mean_TYPE_D_word_WC_reconstruct_loss, timestep)

            if segment_loss:
                [total_segment_loss, mean_segment_W_consistency_loss,
                 mean_ORIGINAL_segment_termination_loss, mean_ORIGINAL_segment_loc_reconstruct_loss,
                 mean_ORIGINAL_segment_touch_reconstruct_loss,
                 mean_TYPE_A_segment_termination_loss, mean_TYPE_A_segment_loc_reconstruct_loss,
                 mean_TYPE_A_segment_touch_reconstruct_loss,
                 mean_TYPE_B_segment_termination_loss, mean_TYPE_B_segment_loc_reconstruct_loss,
                 mean_TYPE_B_segment_touch_reconstruct_loss,
                 mean_TYPE_A_segment_WC_reconstruct_loss, mean_TYPE_B_segment_WC_reconstruct_loss] = segment_losses
                valid_writer_all.add_scalar('ALL/total_segment_loss', total_segment_loss, timestep)
                valid_writer_segment.add_scalar('Loss/mean_W_consistency_loss', mean_segment_W_consistency_loss, timestep)
                if ORIGINAL:
                    valid_writer_segment.add_scalar('Loss/mean_ORIGINAL_loss', mean_ORIGINAL_segment_termination_loss + mean_ORIGINAL_segment_loc_reconstruct_loss + mean_ORIGINAL_segment_touch_reconstruct_loss, timestep)
                    valid_writer_segment.add_scalar('Z_LOSS/mean_ORIGINAL_termination_loss', mean_ORIGINAL_segment_termination_loss, timestep)
                    valid_writer_segment.add_scalar('Loss_Loc/mean_ORIGINAL_loc_reconstruct_loss', mean_ORIGINAL_segment_loc_reconstruct_loss, timestep)
                    valid_writer_segment.add_scalar('Z_LOSS/mean_ORIGINAL_touch_reconstruct_loss', mean_ORIGINAL_segment_touch_reconstruct_loss, timestep)
                if TYPE_A:
                    valid_writer_segment.add_scalar('Loss/mean_TYPE_A_loss', mean_TYPE_A_segment_termination_loss + mean_TYPE_A_segment_loc_reconstruct_loss + mean_TYPE_A_segment_touch_reconstruct_loss, timestep)
                    valid_writer_segment.add_scalar('Z_LOSS/mean_TYPE_A_termination_loss', mean_TYPE_A_segment_termination_loss, timestep)
                    valid_writer_segment.add_scalar('Loss_Loc/mean_TYPE_A_loc_reconstruct_loss', mean_TYPE_A_segment_loc_reconstruct_loss, timestep)
                    valid_writer_segment.add_scalar('Z_LOSS/mean_TYPE_A_touch_reconstruct_loss', mean_TYPE_A_segment_touch_reconstruct_loss, timestep)
                    valid_writer_segment.add_scalar('Z_LOSS/mean_TYPE_A_WC_reconstruct_loss', mean_TYPE_A_segment_WC_reconstruct_loss, timestep)
                if TYPE_B:
                    valid_writer_segment.add_scalar('Loss/mean_TYPE_B_loss', mean_TYPE_B_segment_termination_loss + mean_TYPE_B_segment_loc_reconstruct_loss + mean_TYPE_B_segment_touch_reconstruct_loss, timestep)
                    valid_writer_segment.add_scalar('Z_LOSS/mean_TYPE_B_termination_loss', mean_TYPE_B_segment_termination_loss, timestep)
                    valid_writer_segment.add_scalar('Loss_Loc/mean_TYPE_B_loc_reconstruct_loss', mean_TYPE_B_segment_loc_reconstruct_loss, timestep)
                    valid_writer_segment.add_scalar('Z_LOSS/mean_TYPE_B_touch_reconstruct_loss', mean_TYPE_B_segment_touch_reconstruct_loss, timestep)
                    valid_writer_segment.add_scalar('Z_LOSS/mean_TYPE_B_WC_reconstruct_loss', mean_TYPE_B_segment_WC_reconstruct_loss, timestep)

        if timestep % (num_writer * num_samples * 1000) == 0:
            # Checkpoint every 1000 optimizer steps.
            os.makedirs('model', exist_ok=True)
            torch.save({'timestep': timestep,
                        'model_state_dict': net.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': total_loss.item(),
                        }, 'model/' + str(timestep) + '.pt')

    writer_all.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Arguments for training the handwriting synthesis network.')
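    # Binary options are exposed as ints: 1 enables, 0 disables.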

    parser.add_argument('--divider', type=float, default=5.0)
    parser.add_argument('--weight_dim', type=int, default=256)
    parser.add_argument('--sample', type=int, default=2)
    parser.add_argument('--device', type=int, default=1)
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--num_writer', type=int, default=1)
    parser.add_argument('--lr', type=float, default=0.001)

    parser.add_argument('--sentence_loss', type=int, default=1)
    parser.add_argument('--word_loss', type=int, default=1)
    parser.add_argument('--segment_loss', type=int, default=1)

    parser.add_argument('--TYPE_A', type=int, default=1)
    parser.add_argument('--TYPE_B', type=int, default=1)
    parser.add_argument('--TYPE_C', type=int, default=1)
    parser.add_argument('--TYPE_D', type=int, default=1)
    parser.add_argument('--ORIGINAL', type=int, default=1)

    parser.add_argument('--VALIDATION', type=int, default=1)
    parser.add_argument('--no_char', type=int, default=0)
    parser.add_argument('--REC', type=int, default=1)
    parser.add_argument('--CHECKPOINT', type=int, default=0)

    main(parser.parse_args())