decoupled-style-descriptors / interpolation.py
brayden-gg
switched to SVG rendering
904e4e5
raw
history blame
5.52 kB
import torch
import argparse
import numpy as np
from helper import *
from config.GlobalVariables import *
from SynthesisNetwork import SynthesisNetwork
from DataLoader import DataLoader
import convenience
# NOTE(review): `L` is not referenced anywhere in the visible part of this
# file; confirm whether anything still depends on it before removing it.
L = 256
def _load_weights(net, path='./model/250000.pt'):
    """Load the checkpoint at *path* into *net*, mapping all tensors to the CPU.

    Two checkpoint layouts are accepted: a dict that wraps the weights under
    the ``"model_state_dict"`` key (the retrained model also contains the loss
    in that dict), or a bare state dict.
    """
    checkpoint = torch.load(path, map_location=torch.device('cpu'))
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        checkpoint = checkpoint["model_state_dict"]
    net.load_state_dict(checkpoint)


def _save_rgb(im, path):
    """Convert *im* to RGB and save it at *path*, creating the parent directory if needed.

    *im* only needs to expose ``convert`` and ``save`` (presumably a PIL image
    returned by the ``convenience`` helpers -- confirm there).
    """
    import os  # function-scope import keeps the file's import block untouched

    os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
    im.convert("RGB").save(path)


def main(params):
    """Produce an image, grid or video sample according to the parsed arguments.

    Seeds NumPy and PyTorch with 0, builds the synthesis network on the CPU,
    loads its weights from ``./model/250000.pt``, fetches one batch of training
    data per id in ``params.writer_ids`` and then dispatches on
    ``params.output`` / ``params.interpolate`` to the matching ``convenience``
    routine. Image and grid results are saved under ``results/``.

    Args:
        params: Namespace carrying the attributes defined by this script's
            argument parser (``output``, ``interpolate``, ``writer_ids``,
            ``num_samples``, ``target_word``, ``blend_weights``,
            ``blend_chars``, ``grid_chars``, ``grid_size``, ``frames_per_step``,
            ``max_randomness``, ``scale_randomness``, ``num_random_samples``).

    Raises:
        ValueError: If ``output`` is unknown, if ``interpolate`` is not valid
            for the chosen output, if ``blend_weights`` does not have one entry
            per writer id / blend character, if ``grid_chars`` does not hold
            exactly four characters, or if ``max_randomness`` is outside [0, 1].
    """
    np.random.seed(0)
    torch.manual_seed(0)

    device = 'cpu'
    net = SynthesisNetwork(weight_dim=256, num_layers=3).to(device)

    # The weights are always needed: the network runs on the CPU regardless of
    # whether CUDA is present, so loading must not depend on CUDA availability.
    _load_weights(net)

    dl = DataLoader(num_writer=1, num_samples=10, divider=5.0, datadir='./data/writers')

    # One batch per requested writer; sample ids are 0 .. num_samples - 1.
    sample_ids = list(range(params.num_samples))
    all_loaded_data = [
        dl.next_batch(TYPE='TRAIN', uid=writer_id, tids=sample_ids)
        for writer_id in params.writer_ids
    ]

    if params.output == "image":
        if params.interpolate == "writer":
            if len(params.blend_weights) != len(params.writer_ids):
                raise ValueError("blend_weights must be same length as writer_ids")
            im = convenience.sample_blended_writers(params.blend_weights, params.target_word, net, all_loaded_data, device)
            _save_rgb(im, f'results/blend_{"+".join([str(i) for i in params.writer_ids])}.png')
        elif params.interpolate == "character":
            if len(params.blend_weights) != len(params.blend_chars):
                # The check is against blend_chars, so the message names it.
                raise ValueError("blend_weights must be same length as blend_chars")
            im = convenience.sample_blended_chars(params.blend_weights, params.blend_chars, net, all_loaded_data, device)
            _save_rgb(im, f'results/blend_{"+".join(params.blend_chars)}.png')
        elif params.interpolate == "randomness":
            if not 0 <= params.max_randomness <= 1:
                raise ValueError("max_randomness must be between 0 and 1")
            im = convenience.mdn_single_sample(params.target_word, params.scale_randomness, params.max_randomness, net, all_loaded_data, device)
            _save_rgb(im, f"results/sample_{params.target_word.replace(' ', '_')}.png")
        else:
            raise ValueError("Invalid interpolation argument for outputting an image")

    elif params.output == "grid":
        if params.interpolate == "character":
            if len(params.grid_chars) != 4:
                raise ValueError("grid_chars must be given exactly four characters")
            im = convenience.sample_character_grid(params.grid_chars, params.grid_size, net, all_loaded_data, device)
            _save_rgb(im, f'results/grid_{"".join(params.grid_chars)}.png')
        else:
            raise ValueError("Invalid interpolation argument for outputting a grid")

    elif params.output == "video":
        # The video helpers' return values are not used here; where they write
        # their output is decided inside `convenience`.
        if params.interpolate == "writer":
            convenience.writer_interpolation_video(params.target_word, params.frames_per_step, net, all_loaded_data, device)
        elif params.interpolate == "character":
            convenience.char_interpolation_video(params.blend_chars, params.frames_per_step, net, all_loaded_data, device)
        elif params.interpolate == "randomness":
            if not 0 <= params.max_randomness <= 1:
                raise ValueError("max_randomness must be between 0 and 1")
            convenience.mdn_video(params.target_word, params.num_random_samples, params.scale_randomness, params.max_randomness, net, all_loaded_data, device)
        else:
            raise ValueError("Invalid interpolation argument for outputting a video")

    else:
        raise ValueError("Invalid output")
def build_parser():
    """Build the command-line parser for the interpolation script.

    Defaults, types and choices are the script's established ones; the help
    strings describe which output / interpolation mode reads each option.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(description='Arguments for generating samples with the handwriting synthesis model.')

    parser.add_argument('--num_samples', type=int, default=10,
                        help='number of sample ids (0..N-1) requested for each writer')
    # NOTE(review): main() in this file does not read generating_default.
    parser.add_argument('--generating_default', type=int, default=0)
    parser.add_argument('--output', type=str, default="image", choices=["image", "grid", "video"],
                        help='kind of result to produce')
    parser.add_argument('--interpolate', type=str, default="randomness", choices=["writer", "character", "randomness"],
                        help='what to interpolate over; grid output only supports "character"')

    # PARAMS FOR BOTH WRITER AND CHARACTER INTERPOLATION:
    parser.add_argument('--blend_weights', type=float, nargs="+", default=[0.5, 0.5],
                        help='image output: weights for a single blended sample; '
                             'one per writer id (writer) or per blend char (character)')
    parser.add_argument('--frames_per_step', type=int, default=10,
                        help='video output: number of frames for each character/writer')

    # PARAMS FOR WRITER AND RANDOMNESS INTERPOLATION:
    parser.add_argument('--target_word', type=str, default="hello world",
                        help='text to synthesize for writer and randomness interpolation')
    parser.add_argument('--writer_ids', type=int, nargs="+", default=[80, 120],
                        help='ids of the writers whose samples are loaded')

    # PARAMS IF CHARACTER INTERPOLATION:
    parser.add_argument('--blend_chars', type=str, nargs="+", default=["a", "b", "c", "d", "e"],
                        help='image or video output: characters to blend between')
    parser.add_argument('--grid_chars', type=str, nargs="+", default=["y", "s", "u", "n"],
                        help='grid output: exactly four characters')
    parser.add_argument('--grid_size', type=int, default=10,
                        help='grid output: size of the character grid')

    # PARAMS IF RANDOMNESS INTERPOLATION (image or video output):
    parser.add_argument('--max_randomness', type=float, default=1,
                        help='must be between 0 and 1')
    parser.add_argument('--scale_randomness', type=float, default=0.5,
                        help='randomness scale passed to the sampler')
    parser.add_argument('--num_random_samples', type=int, default=10,
                        help='video output: number of random samples')

    return parser


if __name__ == '__main__':
    main(build_parser().parse_args())