Spaces:
Build error
Build error
import torch | |
import argparse | |
import numpy as np | |
from helper import * | |
from config.GlobalVariables import * | |
from SynthesisNetwork import SynthesisNetwork | |
from DataLoader import DataLoader | |
import convenience | |
L = 256 | |
def main(params): | |
np.random.seed(0) | |
torch.manual_seed(0) | |
device = 'cpu' | |
net = SynthesisNetwork(weight_dim=256, num_layers=3).to(device) | |
if not torch.cuda.is_available(): | |
try: # retrained model also contains loss in dict | |
net.load_state_dict(torch.load('./model/250000.pt', map_location=torch.device('cpu'))["model_state_dict"]) | |
except: | |
net.load_state_dict(torch.load('./model/250000.pt', map_location=torch.device('cpu'))) | |
dl = DataLoader(num_writer=1, num_samples=10, divider=5.0, datadir='./data/writers') | |
all_loaded_data = [] | |
for writer_id in params.writer_ids: | |
loaded_data = dl.next_batch(TYPE='TRAIN', uid=writer_id, tids=list(range(params.num_samples))) | |
all_loaded_data.append(loaded_data) | |
if params.output == "image": | |
if params.interpolate == "writer": | |
if len(params.blend_weights) != len(params.writer_ids): | |
raise ValueError("blend_weights must be same length as writer_ids") | |
im = convenience.sample_blended_writers(params.blend_weights, params.target_word, net, all_loaded_data, device) | |
im.convert("RGB").save(f'results/blend_{"+".join([str(i) for i in params.writer_ids])}.png') | |
elif params.interpolate == "character": | |
if len(params.blend_weights) != len(params.blend_chars): | |
raise ValueError("blend_weights must be same length as target_word") | |
im = convenience.sample_blended_chars(params.blend_weights, params.blend_chars, net, all_loaded_data, device) | |
im.convert("RGB").save(f'results/blend_{"+".join(params.blend_chars)}.png') | |
elif params.interpolate == "randomness": | |
if not 0 <= params.max_randomness <= 1: | |
raise ValueError("max_randomness must be between 0 and 1") | |
im = convenience.mdn_single_sample(params.target_word, params.scale_randomness, params.max_randomness, net, all_loaded_data, device) | |
im.convert("RGB").save(f"results/sample_{params.target_word.replace(' ', '_')}.png") | |
else: | |
raise ValueError("Invalid interpolation argument for outputting an image") | |
elif params.output == "grid": | |
if params.interpolate == "character": | |
if len(params.grid_chars) != 4: | |
raise ValueError("grid_chars must be given exactly four characters") | |
im = convenience.sample_character_grid(params.grid_chars, params.grid_size, net, all_loaded_data, device) | |
im.convert("RGB").save(f'results/grid_{"".join(params.grid_chars)}.png') | |
else: | |
raise ValueError("Invalid interpolation argument for outputting a grid") | |
elif params.output == "video": | |
if params.interpolate == "writer": | |
convenience.writer_interpolation_video(params.target_word, params.frames_per_step, net, all_loaded_data, device) | |
elif params.interpolate == "character": | |
convenience.char_interpolation_video(params.blend_chars, params.frames_per_step, net, all_loaded_data, device) | |
elif params.interpolate == "randomness": | |
if not 0 <= params.max_randomness <= 1: | |
raise ValueError("max_randomness must be between 0 and 1") | |
convenience.mdn_video(params.target_word, params.num_random_samples, params.scale_randomness, params.max_randomness, net, all_loaded_data, device) | |
else: | |
raise ValueError("Invalid interpolation argument for outputting a video") | |
else: | |
raise ValueError("Invalid output") | |
if __name__ == '__main__': | |
parser = argparse.ArgumentParser(description='Arguments for generating samples with the handwriting synthesis model.') | |
# parser.add_argument('--writer_id', type=int, default=80) | |
parser.add_argument('--num_samples', type=int, default=10) | |
parser.add_argument('--generating_default', type=int, default=0) | |
parser.add_argument('--output', type=str, default="image", choices=["image", "grid", "video"]) | |
parser.add_argument('--interpolate', type=str, default="randomness", choices=["writer", "character", "randomness"]) | |
# PARAMS FOR BOTH WRITER AND CHARACTER INTERPOLATION: | |
# IF IMAGE - weights to use for a single sample of interpolation | |
parser.add_argument('--blend_weights', type=float, nargs="+", default=[0.5, 0.5]) | |
# IF VIDEO - the number of frames for each character/writer | |
parser.add_argument('--frames_per_step', type=int, default=10) | |
# PARAMS IF WRITER INTERPOLATION: | |
parser.add_argument('--target_word', type=str, default="hello world") | |
parser.add_argument('--writer_ids', type=int, nargs="+", default=[80, 120]) | |
# PARAMS IF CHARACTER INTERPOLATION: | |
# IF VIDEO OR BLEND | |
parser.add_argument('--blend_chars', type=str, nargs="+", default=["a", "b", "c", "d", "e"]) | |
# IF GRID | |
parser.add_argument('--grid_chars', type=str, nargs="+", default=["y", "s", "u", "n"]) | |
parser.add_argument('--grid_size', type=int, default=10) | |
# PARAMS IF RANDOMNESS ITERPOLATION (--output will be ignored): | |
parser.add_argument('--max_randomness', type=float, default=1) | |
parser.add_argument('--scale_randomness', type=float, default=0.5) | |
parser.add_argument('--num_random_samples', type=int, default=10) | |
main(parser.parse_args()) | |