import os
import torch
import pickle
import argparse
import numpy as np
from helper import *
# FIX: ImageDraw is used below (sample / sample2) but was never imported.
from PIL import Image, ImageDraw
import torch.nn as nn
import torch.optim as optim
from config.GlobalVariables import *
from tensorboardX import SummaryWriter
from SynthesisNetwork import SynthesisNetwork
from DataLoader import DataLoader


def main(params):
    """Interactively generate handwriting images for typed sentences.

    Loads a pretrained SynthesisNetwork plus a small batch of reference
    samples for one writer, derives the writer's mean global style vector,
    then loops on stdin rendering one image per entered sentence.

    Args:
        params: argparse.Namespace with writer_id, num_samples,
            generating_default and direct_use (see the __main__ block).
    """
    np.random.seed(0)
    torch.manual_seed(0)
    device = 'cpu'

    dl = DataLoader(num_writer=1, num_samples=10, divider=5.0, datadir='./data/writers')
    net = SynthesisNetwork(weight_dim=256, num_layers=3).to(device)

    if not torch.cuda.is_available():
        try:
            # retrained model also contains loss in dict
            net.load_state_dict(torch.load('./model/250000.pt', map_location=torch.device('cpu'))["model_state_dict"])
        except (KeyError, TypeError):
            # FIX: was a bare `except:`. A plain state_dict checkpoint raises
            # KeyError on the lookup above (TypeError if it is not a mapping);
            # anything else (missing file, corrupt pickle) should propagate.
            net.load_state_dict(torch.load('./model/250000.pt', map_location=torch.device('cpu')))

    [_, _, _, _, _, _,
     all_word_level_stroke_in, all_word_level_stroke_out, all_word_level_stroke_length,
     all_word_level_term, all_word_level_char, all_word_level_char_length,
     all_segment_level_stroke_in, all_segment_level_stroke_out, all_segment_level_stroke_length,
     all_segment_level_term, all_segment_level_char, all_segment_level_char_length] = dl.next_batch(
        TYPE='TRAIN', uid=params.writer_id, tids=[i for i in range(params.num_samples)])

    # Move every reference sample to torch tensors on the target device.
    batch_word_level_stroke_in = [torch.FloatTensor(a).to(device) for a in all_word_level_stroke_in]
    batch_word_level_stroke_out = [torch.FloatTensor(a).to(device) for a in all_word_level_stroke_out]
    batch_word_level_stroke_length = [torch.LongTensor(a).to(device).unsqueeze(-1) for a in all_word_level_stroke_length]
    batch_word_level_term = [torch.FloatTensor(a).to(device) for a in all_word_level_term]
    batch_word_level_char = [torch.LongTensor(a).to(device) for a in all_word_level_char]
    batch_word_level_char_length = [torch.LongTensor(a).to(device).unsqueeze(-1) for a in all_word_level_char_length]
    batch_segment_level_stroke_in = [[torch.FloatTensor(a).to(device) for a in b] for b in all_segment_level_stroke_in]
    batch_segment_level_stroke_out = [[torch.FloatTensor(a).to(device) for a in b] for b in all_segment_level_stroke_out]
    batch_segment_level_stroke_length = [[torch.LongTensor(a).to(device).unsqueeze(-1) for a in b] for b in all_segment_level_stroke_length]
    batch_segment_level_term = [[torch.FloatTensor(a).to(device) for a in b] for b in all_segment_level_term]
    batch_segment_level_char = [[torch.LongTensor(a).to(device) for a in b] for b in all_segment_level_char]
    batch_segment_level_char_length = [[torch.LongTensor(a).to(device).unsqueeze(-1) for a in b] for b in all_segment_level_char_length]

    if params.generating_default == 1:
        # Sanity check: resynthesize the reference samples themselves and
        # save a strip of the decoding variants the network produces.
        with torch.no_grad():
            commands_list = net.sample([
                batch_word_level_stroke_in, batch_word_level_stroke_out, batch_word_level_stroke_length,
                batch_word_level_term, batch_word_level_char, batch_word_level_char_length,
                batch_segment_level_stroke_in, batch_segment_level_stroke_out, batch_segment_level_stroke_length,
                batch_segment_level_term, batch_segment_level_char, batch_segment_level_char_length])
        [t_commands, o_commands, a_commands, b_commands, c_commands, d_commands] = commands_list
        dst = Image.new('RGB', (750, 640))
        dst.paste(draw_commands(t_commands), (0, 0))
        dst.paste(draw_commands(o_commands), (0, 160))
        dst.paste(draw_commands(a_commands), (0, 320))
        dst.paste(draw_commands(d_commands), (0, 480))
        dst.save(f'results/default.png')

    # Derive the writer's mean global style vector W from the first word-level
    # sample: encode the strokes, invert the per-character transformation
    # matrices, and average the recovered style vectors.
    with torch.no_grad():
        word_inf_state_out = net.inf_state_fc1(batch_word_level_stroke_out[0])
        word_inf_state_out = net.inf_state_relu(word_inf_state_out)
        word_inf_state_out, _ = net.inf_state_lstm(word_inf_state_out)

        user_word_level_char = batch_word_level_char[0]
        user_word_level_term = batch_word_level_term[0]

        original_Wc = []
        word_batch_id = 0

        curr_seq_len = batch_word_level_stroke_length[0][word_batch_id][0]
        curr_char_len = batch_word_level_char_length[0][word_batch_id][0]

        # One-hot encode the characters of this word.
        char_vector = torch.eye(len(CHARACTERS))[user_word_level_char[word_batch_id][:curr_char_len]].to(device)
        current_term = user_word_level_term[word_batch_id][:curr_seq_len].unsqueeze(-1)
        # Timesteps where a character terminates — W_c is read at these points.
        split_ids = torch.nonzero(current_term)[:, 0]

        char_vector_1 = net.char_vec_fc_1(char_vector)
        char_vector_1 = net.char_vec_relu_1(char_vector_1)
        char_out_1 = char_vector_1.unsqueeze(0)
        char_out_1, (c, h) = net.char_lstm_1(char_out_1)
        char_out_1 = char_out_1.squeeze(0)
        char_out_1 = net.char_vec_fc2_1(char_out_1)
        char_matrix_1 = char_out_1.view([-1, 1, 256, 256])
        char_matrix_1 = char_matrix_1.squeeze(1)
        char_matrix_inv_1 = torch.inverse(char_matrix_1)

        W_c_t = word_inf_state_out[word_batch_id][:curr_seq_len]
        W_c = torch.stack([W_c_t[i] for i in split_ids])
        original_Wc.append(W_c)

        # W = M^-1 @ W_c recovers the style vector per character; average them.
        W = torch.bmm(char_matrix_inv_1, W_c.unsqueeze(2)).squeeze(-1)
        mean_global_W = torch.mean(W, 0)

    def sample_word(target_word):
        """Return absolute [x, y, pen] drawing commands for one word.

        Greedily matches the longest substrings of `target_word` against
        stroke segments the writer actually produced; characters with no
        reference segment fall back to a W_c synthesized from the writer's
        mean style vector (TYPE A).
        """
        # Index every reference segment string -> list of [strokes, split_ids].
        available_segments = {}
        for sid, sentence in enumerate(all_segment_level_char[0]):
            for wid, word in enumerate(sentence):
                segment = ''.join([CHARACTERS[i] for i in word])
                split_ids = np.asarray(np.nonzero(all_segment_level_term[0][sid][wid]))
                entry = [all_segment_level_stroke_out[0][sid][wid][:all_segment_level_stroke_length[0][sid][wid]], split_ids]
                if segment in available_segments:
                    available_segments[segment].append(entry)
                else:
                    available_segments[segment] = [entry]

        index = 0
        all_W_c = []

        # FIX: was `while index <= len(target_word)`, which re-entered the loop
        # at index == len(target_word) whenever the final character took the
        # fallback branch, and then raised IndexError on target_word[index].
        while index < len(target_word):
            available = False
            # Greedy longest match against the reference segments.
            for end_index in range(len(target_word), index, -1):
                segment = target_word[index:end_index]
                if segment in available_segments:
                    available = True
                    candidates = available_segments[segment]
                    segment_level_stroke_out, split_ids = candidates[np.random.randint(len(candidates))]
                    out = net.inf_state_fc1(torch.FloatTensor(segment_level_stroke_out).to(device).unsqueeze(0))
                    out = net.inf_state_relu(out)
                    seg_W_c, _ = net.inf_state_lstm(out)
                    tmp = []
                    for id in split_ids[0]:
                        tmp.append(seg_W_c[0, id].squeeze())
                    all_W_c.append(tmp)
                    index = end_index
                    if index == len(target_word):
                        break
            if not available:
                # No reference segment covers this position: synthesize W_c for
                # the single character from the writer's mean style vector.
                character = target_word[index]
                char_vector = torch.eye(len(CHARACTERS))[CHARACTERS.index(character)].to(device).unsqueeze(0)
                out = net.char_vec_fc_1(char_vector)
                out = net.char_vec_relu_1(out)
                out, _ = net.char_lstm_1(out.unsqueeze(0))
                out = out.squeeze(0)
                out = net.char_vec_fc2_1(out)
                char_matrix = out.view([-1, 256, 256])
                TYPE_A_WC = torch.bmm(char_matrix, mean_global_W.repeat(char_matrix.size(0), 1).unsqueeze(2)).squeeze()
                all_W_c.append([TYPE_A_WC])
                index += 1

        all_commands = []
        current_id = 0
        while True:
            # TYPE D blending: condition each remaining per-character W_c on the
            # trailing W_c of the segments already consumed, via magic_lstm.
            word_Wc_rec_TYPE_D = []
            TYPE_D_REF = []
            cid = 0
            for segment_batch_id in range(len(all_W_c)):
                if len(TYPE_D_REF) == 0:
                    for each_segment_Wc in all_W_c[segment_batch_id]:
                        if cid >= current_id:
                            word_Wc_rec_TYPE_D.append(each_segment_Wc)
                        cid += 1
                    if len(word_Wc_rec_TYPE_D) > 0:
                        TYPE_D_REF.append(all_W_c[segment_batch_id][-1])
                else:
                    for each_segment_Wc in all_W_c[segment_batch_id]:
                        magic_inp = torch.cat([torch.stack(TYPE_D_REF, 0), each_segment_Wc.unsqueeze(0)], 0)
                        magic_inp = magic_inp.unsqueeze(0)
                        TYPE_D_out, (c, h) = net.magic_lstm(magic_inp)
                        TYPE_D_out = TYPE_D_out.squeeze(0)
                        word_Wc_rec_TYPE_D.append(TYPE_D_out[-1])
                    TYPE_D_REF.append(all_W_c[segment_batch_id][-1])
            WC_ = torch.stack(word_Wc_rec_TYPE_D)
            tmp_commands, res = net.sample_from_w_fix(WC_, target_word)
            current_id = current_id + res
            if len(all_commands) == 0:
                all_commands.append(tmp_commands)
            else:
                # Drop the duplicated first command of each continuation chunk.
                all_commands.append(tmp_commands[1:])
            # res < 0 signals failure; otherwise stop once the word is covered.
            if res < 0 or current_id >= len(target_word):
                break

        # Integrate relative offsets into absolute canvas coordinates.
        commands = []
        px, py = 0, 100
        for coms in all_commands:
            for i, [dx, dy, t] in enumerate(coms):
                x = px + dx * 5
                y = py + dy * 5
                commands.append([x, y, t])
                px, py = x, y
        commands = np.asarray(commands)
        commands[:, 0] -= np.min(commands[:, 0])  # left-align the word
        return commands

    def sample(target_sentence):
        """Render a whole sentence with sample_word() into results/hello.png."""
        words = target_sentence.split(' ')
        im = Image.fromarray(np.zeros([160, 750]))
        dr = ImageDraw.Draw(im)
        width = 50  # running x-offset of the current word on the canvas
        for word in words:
            all_commands = sample_word(word)
            # FIX: px/py were read before assignment on the first pen-down
            # command of each word; start the pen at the word's first point.
            px, py = all_commands[0][0], all_commands[0][1]
            for [x, y, t] in all_commands:
                if t == 0:  # pen down: connect previous point to this one
                    dr.line((px + width, py, x + width, y), 255, 1)
                px, py = x, y
            width += np.max(all_commands[:, 0]) + 25
        im.convert("RGB").save(f'results/hello.png')

    def sample_word2(target_word):
        """Like sample_word(), but renders each matched segment / fallback
        character independently and returns a list of command arrays."""
        # Index every reference segment string -> [stroke_in, stroke_out, split_ids].
        available_segments = {}
        for sid, sentence in enumerate(all_segment_level_char[0]):
            for wid, word in enumerate(sentence):
                segment = ''.join([CHARACTERS[i] for i in word])
                split_ids = np.asarray(np.nonzero(all_segment_level_term[0][sid][wid]))
                entry = [all_segment_level_stroke_in[0][sid][wid][:all_segment_level_stroke_length[0][sid][wid]],
                         all_segment_level_stroke_out[0][sid][wid][:all_segment_level_stroke_length[0][sid][wid]],
                         split_ids]
                if segment in available_segments:
                    available_segments[segment].append(entry)
                else:
                    available_segments[segment] = [entry]

        index = 0
        all_W_c = []
        all_commands = []
        know_chars = []
        # FIX: `<=` indexed past the end of target_word (see sample_word).
        while index < len(target_word):
            available = False
            for end_index in range(len(target_word), index, -1):
                segment = target_word[index:end_index]
                if segment in available_segments:
                    available = True
                    candidates = available_segments[segment]
                    segment_level_stroke_in, segment_level_stroke_out, split_ids = candidates[np.random.randint(len(candidates))]
                    out = net.inf_state_fc1(torch.FloatTensor(segment_level_stroke_in).to(device).unsqueeze(0))
                    out = net.inf_state_relu(out)
                    seg_W_c, _ = net.inf_state_lstm(out)
                    tmp = []
                    for id in split_ids[0]:
                        tmp.append(seg_W_c[0, id].squeeze())
                    all_W_c.append(tmp)
                    # FIX: the character positions covered by this segment were
                    # previously recorded *after* `index = end_index`, making the
                    # range empty; record them before advancing.
                    for i in range(index, end_index):
                        know_chars.append(i)
                    index = end_index
                    # Replay the reference strokes directly as commands.
                    commands = []
                    px, py = 0, 100
                    for i, [dx, dy, t] in enumerate(segment_level_stroke_out):
                        x = px + dx * 5
                        y = py + dy * 5
                        commands.append([x, y, t])
                        px, py = x, y
                    commands = np.asarray(commands)
                    commands[:, 0] -= np.min(commands[:, 0])
                    all_commands.append(commands)
                    if index == len(target_word):
                        break
            if not available:
                # Fallback: synthesize the single character from the mean style.
                character = target_word[index]
                char_vector = torch.eye(len(CHARACTERS))[CHARACTERS.index(character)].to(device).unsqueeze(0)
                out = net.char_vec_fc_1(char_vector)
                out = net.char_vec_relu_1(out)
                out, _ = net.char_lstm_1(out.unsqueeze(0))
                out = out.squeeze(0)
                out = net.char_vec_fc2_1(out)
                char_matrix = out.view([-1, 256, 256])
                TYPE_A_WC = torch.bmm(char_matrix, mean_global_W.repeat(char_matrix.size(0), 1).unsqueeze(2)).squeeze().unsqueeze(0)
                index += 1
                temp_commands = net.sample_from_w(TYPE_A_WC, character)
                commands = []
                px, py = 0, 100
                for i, [dx, dy, t] in enumerate(temp_commands):
                    x = px + dx * 5
                    y = py + dy * 5
                    commands.append([x, y, t])
                    px, py = x, y
                commands = np.asarray(commands)
                commands[:, 0] -= np.min(commands[:, 0])
                all_commands.append(commands)
        return all_commands

    def sample2(target_sentence):
        """Render a whole sentence with sample_word2() into results/hello2.png."""
        words = target_sentence.split(' ')
        im = Image.fromarray(np.zeros([160, 750]))
        dr = ImageDraw.Draw(im)
        width = 50  # running x-offset on the canvas
        for word in words:
            all_commands = sample_word2(word)
            for c in all_commands:
                # FIX: px/py were read before assignment; start each chunk's
                # pen at its first point.
                px, py = c[0][0], c[0][1]
                for [x, y, t] in c:
                    if t == 0:  # pen down
                        dr.line((px + width, py, x + width, y), 255, 1)
                    px, py = x, y
                width += np.max(c[:, 0]) + 5
            width += 25
        im.convert("RGB").save(f'results/hello2.png')

    # Interactive loop: one image per typed sentence.
    while True:
        target_word = input("Type a sentence to generate : ")
        if len(target_word) > 0:
            if params.direct_use == 0:
                sample(target_word)
            else:
                sample2(target_word)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Arguments for generating samples with the handwriting synthesis model.')
    parser.add_argument('--writer_id', type=int, default=80)
    parser.add_argument('--num_samples', type=int, default=10)
    parser.add_argument('--generating_default', type=int, default=0)
    parser.add_argument('--direct_use', type=int, default=0)
    main(parser.parse_args())