# decoupled-style-descriptors / convenience.py
import os
import re
from random import random
import torch
import pickle
import argparse
import numpy as np
from helper import *
from PIL import Image, ImageDraw  # ImageDraw is used by the drawing helpers below
import torch.nn as nn
import torch.optim as optim
from config.GlobalVariables import *
from tensorboardX import SummaryWriter
from SynthesisNetwork import SynthesisNetwork
from DataLoader import DataLoader
import svgwrite
import ffmpeg  # used by the video helpers below; if the import fails, uninstall ffmpeg and install ffmpeg-python instead
L = 256  # latent style vector dimensionality
def get_mean_global_W(net, loaded_data, device):
"""gets the mean global style vector for a given writer"""
[_, _, _, _, _, _, all_word_level_stroke_in, all_word_level_stroke_out, all_word_level_stroke_length, all_word_level_term, all_word_level_char, all_word_level_char_length, all_segment_level_stroke_in, all_segment_level_stroke_out,
all_segment_level_stroke_length, all_segment_level_term, all_segment_level_char, all_segment_level_char_length] = loaded_data
batch_word_level_stroke_in = [torch.FloatTensor(a).to(device) for a in all_word_level_stroke_in]
batch_word_level_stroke_out = [torch.FloatTensor(a).to(device) for a in all_word_level_stroke_out]
batch_word_level_stroke_length = [torch.LongTensor(a).to(device).unsqueeze(-1) for a in all_word_level_stroke_length]
batch_word_level_term = [torch.FloatTensor(a).to(device) for a in all_word_level_term]
batch_word_level_char = [torch.LongTensor(a).to(device) for a in all_word_level_char]
batch_word_level_char_length = [torch.LongTensor(a).to(device).unsqueeze(-1) for a in all_word_level_char_length]
batch_segment_level_stroke_in = [[torch.FloatTensor(a).to(device) for a in b] for b in all_segment_level_stroke_in]
batch_segment_level_stroke_out = [[torch.FloatTensor(a).to(device) for a in b] for b in all_segment_level_stroke_out]
batch_segment_level_stroke_length = [[torch.LongTensor(a).to(device).unsqueeze(-1) for a in b] for b in all_segment_level_stroke_length]
batch_segment_level_term = [[torch.FloatTensor(a).to(device) for a in b] for b in all_segment_level_term]
batch_segment_level_char = [[torch.LongTensor(a).to(device) for a in b] for b in all_segment_level_char]
batch_segment_level_char_length = [[torch.LongTensor(a).to(device).unsqueeze(-1) for a in b] for b in all_segment_level_char_length]
with torch.no_grad():
word_inf_state_out = net.inf_state_fc1(batch_word_level_stroke_out[0])
word_inf_state_out = net.inf_state_relu(word_inf_state_out)
word_inf_state_out, _ = net.inf_state_lstm(word_inf_state_out)
user_word_level_char = batch_word_level_char[0]
user_word_level_term = batch_word_level_term[0]
original_Wc = []
word_batch_id = 0
curr_seq_len = batch_word_level_stroke_length[0][word_batch_id][0]
curr_char_len = batch_word_level_char_length[0][word_batch_id][0]
char_vector = torch.eye(len(CHARACTERS))[user_word_level_char[word_batch_id][:curr_char_len]].to(device)
current_term = user_word_level_term[word_batch_id][:curr_seq_len].unsqueeze(-1)
split_ids = torch.nonzero(current_term)[:, 0]
char_vector_1 = net.char_vec_fc_1(char_vector)
char_vector_1 = net.char_vec_relu_1(char_vector_1)
char_out_1 = char_vector_1.unsqueeze(0)
char_out_1, (c, h) = net.char_lstm_1(char_out_1)
char_out_1 = char_out_1.squeeze(0)
char_out_1 = net.char_vec_fc2_1(char_out_1)
char_matrix_1 = char_out_1.view([-1, 1, 256, 256])
char_matrix_1 = char_matrix_1.squeeze(1)
char_matrix_inv_1 = torch.inverse(char_matrix_1)
W_c_t = word_inf_state_out[word_batch_id][:curr_seq_len]
W_c = torch.stack([W_c_t[i] for i in split_ids])
original_Wc.append(W_c)
W = torch.bmm(char_matrix_inv_1, W_c.unsqueeze(2)).squeeze(-1)
mean_global_W = torch.mean(W, 0)
return mean_global_W
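# A minimal usage sketch for get_mean_global_W, assuming a trained SynthesisNetwork
# checkpoint and a DataLoader that yields the 18-element list unpacked above.
# The constructor arguments, checkpoint path, and state-dict key are illustrative
# assumptions, not repo-verified values:
#
#   device = "cpu"
#   net = SynthesisNetwork(weight_dim=256, num_layers=3).to(device)
#   net.load_state_dict(torch.load("./model/250000.pt", map_location=device)["model_state_dict"])
#   loaded_data = DataLoader(...).next_batch(...)  # samples for one writer
#   mean_W = get_mean_global_W(net, loaded_data, device)  # -> tensor of shape [L] == [256]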
def get_DSD(net, target_word, writer_mean_Ws, all_loaded_data, device):
"""
returns a style vector and character matrix for each character/segment in target_word
n is the number of writers
M is the number of characters in the target word
L is the latent vector size (in this case 256)
input:
- target_word, a string of length M to be converted to a DSD
- writer_mean_Ws, a list of n style vectors of size L
output:
- all_writer_Ws, a tensor of size n x M x L representing the style vectors for each writer and character
- all_writer_Cs, a tensor of size n x M x L x L representing the corresponding character matrix
"""
n = len(all_loaded_data)
M = len(target_word)
all_writer_Ws = torch.zeros(n, M, L)
all_writer_Cs = torch.zeros(n, M, L, L)
for i in range(n):
        np.random.seed(0)  # fixed seed so the segment candidates chosen below are reproducible
[_, _, _, _, _, _, all_word_level_stroke_in, all_word_level_stroke_out, all_word_level_stroke_length, all_word_level_term, all_word_level_char, all_word_level_char_length, all_segment_level_stroke_in, all_segment_level_stroke_out,
all_segment_level_stroke_length, all_segment_level_term, all_segment_level_char, all_segment_level_char_length] = all_loaded_data[i]
available_segments = {}
for sid, sentence in enumerate(all_segment_level_char[0]):
for wid, word in enumerate(sentence):
segment = ''.join([CHARACTERS[i] for i in word])
split_ids = np.asarray(np.nonzero(all_segment_level_term[0][sid][wid]))
if segment in available_segments:
available_segments[segment].append([all_segment_level_stroke_out[0][sid][wid][:all_segment_level_stroke_length[0][sid][wid]], split_ids])
else:
available_segments[segment] = [[all_segment_level_stroke_out[0][sid][wid][:all_segment_level_stroke_length[0][sid][wid]], split_ids]]
index = 0
all_W = []
all_C = []
# while index <= len(target_word):
while index < len(target_word):
available = False
            # Currently each character is looked up individually instead of matching
            # the longest available multi-character segment, e.g.:
            # for end_index in range(len(target_word), index, -1):
            #     segment = target_word[index:end_index]
            segment = target_word[index]
if segment in available_segments: # method beta
# print(f'in dic - {segment}')
available = True
candidates = available_segments[segment]
segment_level_stroke_out, split_ids = candidates[np.random.randint(len(candidates))]
out = net.inf_state_fc1(torch.FloatTensor(segment_level_stroke_out).to(device).unsqueeze(0))
out = net.inf_state_relu(out)
seg_W_c, (h_n, _) = net.inf_state_lstm(out)
                character = segment[0]  # segment is a single character here, so this just takes that character
# get character matrix using same method as method beta
char_vector = torch.eye(len(CHARACTERS))[CHARACTERS.index(character)].to(device).unsqueeze(0)
out = net.char_vec_fc_1(char_vector)
out = net.char_vec_relu_1(out)
out, _ = net.char_lstm_1(out.unsqueeze(0))
out = out.squeeze(0)
out = net.char_vec_fc2_1(out)
char_matrix = out.view([-1, 256, 256])
inv_char_matrix = char_matrix.inverse()
                first_split_id = split_ids[0][0]  # renamed from `id` to avoid shadowing the builtin
                W_c_vector = seg_W_c[0, first_split_id].squeeze()
                # invert the character matrix to recover the writer-independent DSD
                W_vector = torch.bmm(inv_char_matrix, W_c_vector.repeat(inv_char_matrix.size(0), 1).unsqueeze(2))
all_W.append(W_vector)
all_C.append(char_matrix)
index += 1
if index == len(target_word):
break
if not available: # method alpha
character = target_word[index]
# print(f'no dic - {character}')
char_vector = torch.eye(len(CHARACTERS))[CHARACTERS.index(character)].to(device).unsqueeze(0)
out = net.char_vec_fc_1(char_vector)
out = net.char_vec_relu_1(out)
out, _ = net.char_lstm_1(out.unsqueeze(0))
out = out.squeeze(0)
out = net.char_vec_fc2_1(out)
char_matrix = out.view([-1, 256, 256])
W_vector = writer_mean_Ws[i].repeat(char_matrix.size(0), 1).unsqueeze(2)
# all_W.append([W_vector])
all_W.append(W_vector)
all_C.append(char_matrix)
index += 1
all_writer_Ws[i, :, :] = torch.stack(all_W).squeeze()
all_writer_Cs[i, :, :, :] = torch.stack(all_C).squeeze()
return all_writer_Ws, all_writer_Cs
def get_writer_blend_W_c(writer_weights, all_Ws, all_Cs):
"""
generates character-dependent style-dependent DSDs for each character/segement in target_word,
averaging together the styles of the handwritings using provided weights
n is the number of writers
M is the number of characters in the target word
L is the latent vector size (in this case 256)
input:
- writer_weights, a list of length n weights for each writer that sum to one
- all_writer_Ws, an n x M x L tensor representing each weiter's style vector for every character
- all_writer_Cs, an n x M x L x L tensor representing the style's correspodning character matrix
output:
- an M x 1 x L tensor of M scharacter-dependent style-dependent DSDs
"""
n, M, _ = all_Ws.shape
    weights_tensor = torch.tensor(writer_weights).repeat_interleave(M * L).reshape(n, M, L)  # repeat across remaining dimensions
    W_vectors = (weights_tensor * all_Ws).sum(axis=0).unsqueeze(-1)  # weighted sum across the writers axis
char_matrices = all_Cs[0, :, :, :] # character matrices are independent of writer
W_cs = torch.bmm(char_matrices, W_vectors)
return W_cs.reshape(M, 1, L)
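# Shape check for get_writer_blend_W_c using dummy tensors (no trained model
# needed); a sketch only — the real Ws/Cs come from get_DSD above:
#
#   n, M = 2, 3                        # two writers, three characters
#   dummy_Ws = torch.randn(n, M, L)
#   dummy_Cs = torch.randn(n, M, L, L)
#   W_cs = get_writer_blend_W_c([0.5, 0.5], dummy_Ws, dummy_Cs)
#   assert W_cs.shape == (M, 1, L)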
def get_character_blend_W_c(character_weights, all_Ws, all_Cs):
"""
generates a single character-dependent style-dependent DSD,
averaging together the characters using provided weights
M is the number of characters to blend
L is the latent vector size (in this case 256)
input:
- character_weights, a list of length M weights for each character that sum to one
- all_Ws, a 1 x M x L tensor representing the wwiter's style vector for each character
- all_Cs, 1 x M x L x L tensor representing the style's correspodning character matrix
output:
- a 1 x 1 x L tensor representing the character-dependent style-dependent DSDs
"""
M = len(character_weights)
W_vector = all_Ws[0, 0, :].unsqueeze(-1)
    weights_tensor = torch.tensor(character_weights).repeat_interleave(L * L).reshape(1, M, L, L)  # repeat across remaining dimensions
    char_matrix = (weights_tensor * all_Cs).sum(axis=1).squeeze()  # weighted sum across the characters axis
W_c = char_matrix @ W_vector
return W_c.reshape(1, 1, L)
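# Analogous shape check for get_character_blend_W_c with dummy tensors:
#
#   M = 4                              # four characters to blend
#   dummy_Ws = torch.randn(1, M, L)
#   dummy_Cs = torch.randn(1, M, L, L)
#   W_c = get_character_blend_W_c([0.25, 0.25, 0.25, 0.25], dummy_Ws, dummy_Cs)
#   assert W_c.shape == (1, 1, L)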
def get_commands(net, target_word, all_W_c):  # note: target_word is only used for its length
    """converts character-dependent style-dependent DSDs into a list of [x, y, t] drawing commands"""
all_commands = []
current_id = 0
while True:
word_Wc_rec_TYPE_D = []
TYPE_D_REF = []
cid = 0
for segment_batch_id in range(len(all_W_c)):
if len(TYPE_D_REF) == 0:
for each_segment_Wc in all_W_c[segment_batch_id]:
if cid >= current_id:
word_Wc_rec_TYPE_D.append(each_segment_Wc)
cid += 1
if len(word_Wc_rec_TYPE_D) > 0:
TYPE_D_REF.append(all_W_c[segment_batch_id][-1])
else:
for each_segment_Wc in all_W_c[segment_batch_id]:
magic_inp = torch.cat([torch.stack(TYPE_D_REF, 0), each_segment_Wc.unsqueeze(0)], 0)
magic_inp = magic_inp.unsqueeze(0)
TYPE_D_out, (c, h) = net.magic_lstm(magic_inp)
TYPE_D_out = TYPE_D_out.squeeze(0)
word_Wc_rec_TYPE_D.append(TYPE_D_out[-1])
TYPE_D_REF.append(all_W_c[segment_batch_id][-1])
WC_ = torch.stack(word_Wc_rec_TYPE_D)
tmp_commands, res = net.sample_from_w_fix(WC_)
current_id += res
if len(all_commands) == 0:
all_commands.append(tmp_commands)
else:
all_commands.append(tmp_commands[1:])
if res < 0 or current_id >= len(target_word):
break
commands = []
px, py = 0, 100
for coms in all_commands:
for i, [dx, dy, t] in enumerate(coms):
x = px + dx * 5
y = py + dy * 5
commands.append([x, y, t])
px, py = x, y
commands = np.asarray(commands)
commands[:, 0] -= np.min(commands[:, 0])
return commands
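# End-to-end sketch of the word pipeline (mean style -> per-character DSDs ->
# blended W_c -> drawing commands); `net`, `loaded_data`, and `device` are
# assumed to exist as in the get_mean_global_W sketch above:
#
#   mean_W = get_mean_global_W(net, loaded_data, device)
#   Ws, Cs = get_DSD(net, "hello", [mean_W], [loaded_data], device)
#   W_cs = get_writer_blend_W_c([1], Ws, Cs)            # single writer, weight 1
#   commands = get_commands(net, "hello", W_cs)         # -> rows of [x, y, t]
#   commands_to_image(commands, 750, 160, 0, 0).show()  # rasterize (defined below)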
def mdn_video(target_word, num_samples, scale_sd, clamp_mdn, net, all_loaded_data, device):
    '''
    Renders num_samples MDN samples of target_word and stitches the frames into a video
    num_samples: number of samples to generate
    scale_sd: scale factor applied to the MDN standard deviations while sampling
    clamp_mdn: clamp applied to the MDN while sampling
    '''
words = target_word.split(' ')
us_target_word = re.sub(r"\s+", '_', target_word)
os.makedirs(f"./results/{us_target_word}_mdn_samples", exist_ok=True)
for i in range(num_samples):
net.scale_sd = scale_sd
net.clamp_mdn = clamp_mdn
mean_global_W = get_mean_global_W(net, all_loaded_data[0], device)
word_Ws = []
word_Cs = []
for word in words:
writer_Ws, writer_Cs = get_DSD(net, word, [mean_global_W], [all_loaded_data[0]], device)
word_Ws.append(writer_Ws)
word_Cs.append(writer_Cs)
im = draw_words(words, word_Ws, word_Cs, [1], net)
im.convert("RGB").save(f'results/{us_target_word}_mdn_samples/sample_{i}.png')
    # Convert frames to video using ffmpeg
photos = ffmpeg.input(f'results/{us_target_word}_mdn_samples/sample_*.png', pattern_type='glob', framerate=10)
videos = photos.output(f'results/{us_target_word}_video.mov', vcodec="libx264", pix_fmt="yuv420p")
videos.run(overwrite_output=True)
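# Example invocation (output folders are created under ./results by the function
# itself; the sampling parameters here are illustrative, and ffmpeg-python must
# be installed for the video step):
#
#   mdn_video("hello world", num_samples=50, scale_sd=1.0, clamp_mdn=0.5,
#             net=net, all_loaded_data=all_loaded_data, device=device)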
def sample_blended_writers(writer_weights, target_sentence, net, all_loaded_data, device="cpu"):
"""Generates an image of handwritten text based on target_sentence"""
words = target_sentence.split(' ')
writer_mean_Ws = []
for loaded_data in all_loaded_data:
mean_global_W = get_mean_global_W(net, loaded_data, device)
writer_mean_Ws.append(mean_global_W)
word_Ws = []
word_Cs = []
for word in words:
writer_Ws, writer_Cs = get_DSD(net, word, writer_mean_Ws, all_loaded_data, device)
word_Ws.append(writer_Ws)
word_Cs.append(writer_Cs)
return draw_words(words, word_Ws, word_Cs, writer_weights, net)
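# Example: a 70/30 blend of two writers' styles (the weights are illustrative
# and should sum to one; all_loaded_data holds one entry per writer):
#
#   im = sample_blended_writers([0.7, 0.3], "the quick brown fox", net, all_loaded_data)
#   im.convert("RGB").save("./results/blended.png")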
def sample_character_grid(letters, grid_size, net, all_loaded_data, device="cpu"):
"""Generates an image of handwritten text based on target_sentence"""
width = 60
im = Image.fromarray(np.zeros([(grid_size + 1) * width, (grid_size + 1) * width]))
dr = ImageDraw.Draw(im)
M = len(letters)
mean_global_W = get_mean_global_W(net, all_loaded_data[0], device)
# all_Ws = torch.zeros(1, M, L)
all_Cs = torch.zeros(1, M, L, L)
for i in range(M): # get corners of grid
W_vector, char_matrix = get_DSD(net, letters[i], [mean_global_W], [all_loaded_data[0]], device)
# all_Ws[:, i, :] = W_vector
all_Cs[:, i, :, :] = char_matrix
all_Ws = mean_global_W.reshape(1, 1, L)
for i in range(grid_size):
for j in range(grid_size):
wx = i / (grid_size - 1)
wy = j / (grid_size - 1)
character_weights = [(1 - wx) * (1 - wy), # top left is 1 at (0, 0)
wx * (1 - wy), # top right is 1 at (1, 0)
(1 - wx) * wy, # bottom left is 1 at (0, 1)
wx * wy] # bottom right is 1 at (1, 1)
all_W_c = get_character_blend_W_c(character_weights, all_Ws, all_Cs)
all_commands = get_commands(net, letters[0], all_W_c)
offset_x = i * width
offset_y = j * width
            px, py = 0, 0  # pen position; the first command is pen-up, so this is set before any drawing
            for [x, y, t] in all_commands:
                if t == 0:  # pen down: draw a segment from the previous point
                    dr.line((
                        px + offset_x + width / 2,
                        py + offset_y - width / 2,  # letters are shifted down for some reason
                        x + offset_x + width / 2,
                        y + offset_y - width / 2), 255, 1)
                px, py = x, y
return im
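# Example: a 5x5 grid whose corners are the four letters being blended
# (letters and grid size are illustrative):
#
#   im = sample_character_grid(["a", "b", "c", "d"], 5, net, all_loaded_data)
#   im.convert("RGB").save("./results/char_grid.png")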
def writer_interpolation_video(target_sentence, transition_time, net, all_loaded_data, device="cpu"):
"""
Generates a video of interpolating between each provided writer
"""
n = len(all_loaded_data)
os.makedirs(f"./results/{target_sentence}_blend_frames", exist_ok=True)
words = target_sentence.split(' ')
writer_mean_Ws = []
for loaded_data in all_loaded_data:
mean_global_W = get_mean_global_W(net, loaded_data, device)
writer_mean_Ws.append(mean_global_W)
word_Ws = []
word_Cs = []
for word in words:
all_writer_Ws, all_writer_Cs = get_DSD(net, word, writer_mean_Ws, all_loaded_data, device)
word_Ws.append(all_writer_Ws)
word_Cs.append(all_writer_Cs)
for i in range(n - 1):
for j in range(transition_time):
completion = j/(transition_time)
individual_weights = [1 - completion, completion]
writer_weights = [0] * i + individual_weights + [0] * (n - 2 - i)
im = draw_words(words, word_Ws, word_Cs, writer_weights, net)
im.convert("RGB").save(f"./results/{target_sentence}_blend_frames/frame_{str(i * transition_time + j).zfill(3)}.png")
    # Convert frames to video using ffmpeg
photos = ffmpeg.input(f"./results/{target_sentence}_blend_frames/frame_*.png", pattern_type='glob', framerate=10)
videos = photos.output(f"results/{target_sentence}_blend_video.mov", vcodec="libx264", pix_fmt="yuv420p")
videos.run(overwrite_output=True)
def mdn_single_sample(target_word, scale_sd, clamp_mdn, net, all_loaded_data, device):
    '''
    Renders a single MDN sample of target_word
    scale_sd: scale factor applied to the MDN standard deviations while sampling
    clamp_mdn: clamp applied to the MDN while sampling
    '''
words = target_word.split(' ')
net.scale_sd = scale_sd
net.clamp_mdn = clamp_mdn
mean_global_W = get_mean_global_W(net, all_loaded_data[0], device)
word_Ws = []
word_Cs = []
for word in words:
writer_Ws, writer_Cs = get_DSD(net, word, [mean_global_W], [all_loaded_data[0]], device)
word_Ws.append(writer_Ws)
word_Cs.append(writer_Cs)
return draw_words(words, word_Ws, word_Cs, [1], net)
def sample_blended_chars(character_weights, letters, net, all_loaded_data, device="cpu"):
"""Generates an image of handwritten text based on target_sentence"""
M = len(letters)
mean_global_W = get_mean_global_W(net, all_loaded_data[0], device)
all_Cs = torch.zeros(1, M, L, L)
    for i in range(M):  # get a character matrix for each letter
W_vector, char_matrix = get_DSD(net, letters[i], [mean_global_W], [all_loaded_data[0]], device)
all_Cs[:, i, :, :] = char_matrix
all_Ws = mean_global_W.reshape(1, 1, L)
all_W_c = get_character_blend_W_c(character_weights, all_Ws, all_Cs)
all_commands = get_commands(net, letters[0], all_W_c)
im = commands_to_image(all_commands, 100, 100, 30, 30)
return im
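# Example: a half-'a', half-'b' character (the weights are illustrative and
# must match len(letters)):
#
#   im = sample_blended_chars([0.5, 0.5], ["a", "b"], net, all_loaded_data)
#   im.convert("RGB").save("./results/ab_blend.png")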
def char_interpolation_video(letters, transition_time, net, all_loaded_data, device="cpu"):
"""Generates an image of handwritten text based on target_sentence"""
os.makedirs(f"./results/{''.join(letters)}_frames", exist_ok=True) # make a folder for the frames
M = len(letters)
mean_global_W = get_mean_global_W(net, all_loaded_data[0], device)
all_Cs = torch.zeros(1, M, L, L)
    for i in range(M):  # get a character matrix for each letter
W_vector, char_matrix = get_DSD(net, letters[i], [mean_global_W], [all_loaded_data[0]], device)
all_Cs[:, i, :, :] = char_matrix
all_Ws = mean_global_W.reshape(1, 1, L)
for i in range(M - 1):
for j in range(transition_time):
completion = j / (transition_time - 1)
individual_weights = [1 - completion, completion]
character_weights = [0] * i + individual_weights + [0] * (M - 2 - i)
all_W_c = get_character_blend_W_c(character_weights, all_Ws, all_Cs)
all_commands = get_commands(net, letters[i], all_W_c)
im = commands_to_image(all_commands, 100, 100, 25, 25)
im.convert("RGB").save(f"results/{''.join(letters)}_frames/frames_{str(i * transition_time + j).zfill(3)}.png")
    # Convert frames to video using ffmpeg
photos = ffmpeg.input(f"results/{''.join(letters)}_frames/frames_*.png", pattern_type='glob', framerate=24)
videos = photos.output(f"results/{''.join(letters)}_video.mov", vcodec="libx264", pix_fmt="yuv420p")
videos.run(overwrite_output=True)
def draw_words(words, word_Ws, word_Cs, writer_weights, net):
    """Draws each word with its blended writer styles onto a single PIL image"""
    im = Image.fromarray(np.zeros([160, 750]))
    dr = ImageDraw.Draw(im)
    width = 50  # running x-offset for the next word
    for word, all_writer_Ws, all_writer_Cs in zip(words, word_Ws, word_Cs):
        all_W_c = get_writer_blend_W_c(writer_weights, all_writer_Ws, all_writer_Cs)
        all_commands = get_commands(net, word, all_W_c)
        px, py = 0, 0  # pen position; the first command is pen-up, so this is set before any drawing
        for [x, y, t] in all_commands:
            if t == 0:  # pen down: draw a segment from the previous point
                dr.line((px + width, py, x + width, y), 255, 1)
            px, py = x, y
        width += np.max(all_commands[:, 0]) + 25
    return im
def draw_words_svg(words, word_Ws, word_Cs, writer_weights, net):
    """Draws each word with its blended writer styles into an svgwrite Drawing"""
    dwg = svgwrite.Drawing("output.svg", size=(750, 160), style="background-color: black;")
    width = 50  # running x-offset for the next word
    for word, all_writer_Ws, all_writer_Cs in zip(words, word_Ws, word_Cs):
        all_W_c = get_writer_blend_W_c(writer_weights, all_writer_Ws, all_writer_Cs)
        all_commands = get_commands(net, word, all_W_c)
        path = None  # a pen-up command (t != 0) starts each new path
        for [x, y, t] in all_commands:
            if t == 0 and path is not None:  # pen down: extend the current path
                path.push("L", x + width, y)
            else:  # pen up: start a new path at this point
                path = svgwrite.path.Path(stroke="white", stroke_width="1")
                dwg.add(path)
                path.push("M", x + width, y)
        width += np.max(all_commands[:, 0]) + 25
    return dwg
def commands_to_image(commands, imW, imH, xoff, yoff):
    """Rasterizes a list of [x, y, t] pen commands into a PIL image"""
    im = Image.fromarray(np.zeros([imH, imW]))  # rows = height, cols = width
    dr = ImageDraw.Draw(im)
    px, py = 0, 0  # pen position; the first command is pen-up, so this is set before any drawing
    for [x, y, t] in commands:
        if t == 0:  # pen down: draw a segment from the previous point
            dr.line((
                px + xoff,
                py - yoff,  # letters are shifted down for some reason
                x + xoff,
                y - yoff), 255, 1)
        px, py = x, y
    return im
def commands_to_svg(commands, imW, imH, xoff):
    """Converts a list of [x, y, t] pen commands into an svgwrite Drawing"""
    dwg = svgwrite.Drawing("output.svg", size=(imW, imH), style="background-color:black")
    path = None  # a pen-up command (t != 0) starts each new path
    for [x, y, t] in commands:
        if t == 0 and path is not None:  # pen down: extend the current path
            path.push("L", x + xoff, y)
        else:  # pen up: start a new path at this point
            path = svgwrite.path.Path(stroke="white", stroke_width="1")
            dwg.add(path)
            path.push("M", x + xoff, y)
    return dwg
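# Example: render a word to SVG instead of a raster image; `commands` comes from
# get_commands as in the pipeline sketch above, and the output path is illustrative:
#
#   dwg = commands_to_svg(commands, 750, 160, 50)
#   dwg.saveas("./results/word.svg")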