import argparse
import os

import numpy as np
import torch

from helper import *
from config.GlobalVariables import *
from SynthesisNetwork import SynthesisNetwork
from DataLoader import DataLoader
import convenience

L = 256

# Checkpoint and output locations, relative to the current working directory.
_MODEL_PATH = './model/250000.pt'
_RESULTS_DIR = 'results'


def _load_network(device):
    """Build the synthesis network and load its weights from the checkpoint.

    The checkpoint is always mapped onto the CPU, which matches the device
    hard-coded in ``main``. Weights are loaded regardless of whether CUDA is
    available; skipping the load would leave the network untrained.

    Args:
        device: Device the network is moved to.

    Returns:
        The ``SynthesisNetwork`` with the checkpoint weights loaded.
    """
    net = SynthesisNetwork(weight_dim=256, num_layers=3).to(device)
    checkpoint = torch.load(_MODEL_PATH, map_location=torch.device('cpu'))
    # Retrained checkpoints store the weights under "model_state_dict"
    # (alongside the loss); otherwise the file is the bare state dict.
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        checkpoint = checkpoint["model_state_dict"]
    net.load_state_dict(checkpoint)
    return net


def _load_writer_data(writer_ids, num_samples):
    """Load one 'TRAIN' batch per writer id, in the order the ids are given."""
    dl = DataLoader(num_writer=1, num_samples=10, divider=5.0, datadir='./data/writers')
    return [
        dl.next_batch(TYPE='TRAIN', uid=writer_id, tids=list(range(num_samples)))
        for writer_id in writer_ids
    ]


def _save_image(im, filename):
    """Convert ``im`` to RGB and save it in the results directory."""
    # Create the output directory on demand so a fresh checkout does not fail.
    os.makedirs(_RESULTS_DIR, exist_ok=True)
    im.convert("RGB").save(f'{_RESULTS_DIR}/{filename}')


def _check_max_randomness(max_randomness):
    """Raise ``ValueError`` unless ``max_randomness`` lies within [0, 1]."""
    if not 0 <= max_randomness <= 1:
        raise ValueError("max_randomness must be between 0 and 1")


def _generate_image(params, net, all_loaded_data, device):
    """Render a single image for the selected interpolation mode and save it."""
    if params.interpolate == "writer":
        if len(params.blend_weights) != len(params.writer_ids):
            raise ValueError("blend_weights must be same length as writer_ids")
        im = convenience.sample_blended_writers(
            params.blend_weights, params.target_word, net, all_loaded_data, device)
        _save_image(im, f'blend_{"+".join([str(i) for i in params.writer_ids])}.png')
    elif params.interpolate == "character":
        if len(params.blend_weights) != len(params.blend_chars):
            raise ValueError("blend_weights must be same length as blend_chars")
        im = convenience.sample_blended_chars(
            params.blend_weights, params.blend_chars, net, all_loaded_data, device)
        _save_image(im, f'blend_{"+".join(params.blend_chars)}.png')
    elif params.interpolate == "randomness":
        _check_max_randomness(params.max_randomness)
        im = convenience.mdn_single_sample(
            params.target_word, params.scale_randomness, params.max_randomness,
            net, all_loaded_data, device)
        _save_image(im, f"sample_{params.target_word.replace(' ', '_')}.png")
    else:
        raise ValueError("Invalid interpolation argument for outputting an image")


def _generate_grid(params, net, all_loaded_data, device):
    """Render a character-interpolation grid and save it."""
    if params.interpolate != "character":
        raise ValueError("Invalid interpolation argument for outputting a grid")
    if len(params.grid_chars) != 4:
        raise ValueError("grid_chars must be given exactly four characters")
    im = convenience.sample_character_grid(
        params.grid_chars, params.grid_size, net, all_loaded_data, device)
    _save_image(im, f'grid_{"".join(params.grid_chars)}.png')


def _generate_video(params, net, all_loaded_data, device):
    """Delegate video generation for the selected mode to ``convenience``."""
    if params.interpolate == "writer":
        convenience.writer_interpolation_video(
            params.target_word, params.frames_per_step, net, all_loaded_data, device)
    elif params.interpolate == "character":
        convenience.char_interpolation_video(
            params.blend_chars, params.frames_per_step, net, all_loaded_data, device)
    elif params.interpolate == "randomness":
        _check_max_randomness(params.max_randomness)
        convenience.mdn_video(
            params.target_word, params.num_random_samples, params.scale_randomness,
            params.max_randomness, net, all_loaded_data, device)
    else:
        raise ValueError("Invalid interpolation argument for outputting a video")


def main(params):
    """Generate an image, grid or video sample from the synthesis model.

    Seeds NumPy and PyTorch with 0, loads the network and one batch per
    requested writer, then dispatches on ``params.output`` and
    ``params.interpolate``.

    Args:
        params: Parsed command-line namespace (see ``_build_parser``).

    Raises:
        ValueError: If the output/interpolation combination is unsupported or
            an argument fails validation (mismatched lengths, wrong number of
            grid characters, ``max_randomness`` outside [0, 1]).
    """
    np.random.seed(0)
    torch.manual_seed(0)

    device = 'cpu'
    net = _load_network(device)
    all_loaded_data = _load_writer_data(params.writer_ids, params.num_samples)

    if params.output == "image":
        _generate_image(params, net, all_loaded_data, device)
    elif params.output == "grid":
        _generate_grid(params, net, all_loaded_data, device)
    elif params.output == "video":
        _generate_video(params, net, all_loaded_data, device)
    else:
        raise ValueError("Invalid output")


def _build_parser():
    """Create the command-line argument parser for this script."""
    parser = argparse.ArgumentParser(
        description='Arguments for generating samples with the handwriting synthesis model.')

    parser.add_argument('--num_samples', type=int, default=10)
    parser.add_argument('--generating_default', type=int, default=0)

    parser.add_argument('--output', type=str, default="image",
                        choices=["image", "grid", "video"])
    parser.add_argument('--interpolate', type=str, default="randomness",
                        choices=["writer", "character", "randomness"])

    # PARAMS FOR BOTH WRITER AND CHARACTER INTERPOLATION:
    # IF IMAGE - weights to use for a single sample of interpolation
    parser.add_argument('--blend_weights', type=float, nargs="+", default=[0.5, 0.5])
    # IF VIDEO - the number of frames for each character/writer
    parser.add_argument('--frames_per_step', type=int, default=10)

    # PARAMS IF WRITER INTERPOLATION:
    parser.add_argument('--target_word', type=str, default="hello world")
    parser.add_argument('--writer_ids', type=int, nargs="+", default=[80, 120])

    # PARAMS IF CHARACTER INTERPOLATION:
    # IF VIDEO OR BLEND
    parser.add_argument('--blend_chars', type=str, nargs="+",
                        default=["a", "b", "c", "d", "e"])
    # IF GRID
    parser.add_argument('--grid_chars', type=str, nargs="+",
                        default=["y", "s", "u", "n"])
    parser.add_argument('--grid_size', type=int, default=10)

    # PARAMS IF RANDOMNESS INTERPOLATION (used with --output image or video):
    parser.add_argument('--max_randomness', type=float, default=1)
    parser.add_argument('--scale_randomness', type=float, default=0.5)
    parser.add_argument('--num_random_samples', type=int, default=10)

    return parser


if __name__ == '__main__':
    main(_build_parser().parse_args())