import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal
import math
import numpy as np

from helper import gaussian_2d
from config.GlobalVariables import *

class SynthesisNetwork(nn.Module):
    def __init__(self, weight_dim=512, num_layers=3, scale_sd=1, clamp_mdn=0, sentence_loss=True, word_loss=True, segment_loss=True, TYPE_A=True, TYPE_B=True, TYPE_C=True, TYPE_D=True, ORIGINAL=True, REC=True):
        super(SynthesisNetwork, self).__init__()
        self.num_mixtures = 20
        self.num_layers = num_layers
        self.weight_dim = weight_dim
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.sentence_loss = sentence_loss
        self.word_loss = word_loss
        self.segment_loss = segment_loss
        self.ORIGINAL = ORIGINAL
        self.TYPE_A = TYPE_A
        self.TYPE_B = TYPE_B
        self.TYPE_C = TYPE_C
        self.TYPE_D = TYPE_D
        self.REC = REC

        self.magic_lstm = nn.LSTM(self.weight_dim, self.weight_dim, batch_first=True, num_layers=self.num_layers)

        self.char_vec_fc_1 = nn.Linear(len(CHARACTERS), self.weight_dim)
        self.char_vec_relu_1 = nn.LeakyReLU(negative_slope=0.1)
        self.char_lstm_1 = nn.LSTM(self.weight_dim, self.weight_dim, batch_first=True, num_layers=self.num_layers)
        self.char_vec_fc2_1 = nn.Linear(self.weight_dim, self.weight_dim * self.weight_dim)

        # inference
        self.inf_state_fc1 = nn.Linear(3, self.weight_dim)
        self.inf_state_relu = nn.LeakyReLU(negative_slope=0.1)
        self.inf_state_lstm = nn.LSTM(self.weight_dim, self.weight_dim, batch_first=True, num_layers=self.num_layers)
        self.W_lstm = nn.LSTM(self.weight_dim, self.weight_dim, batch_first=True, num_layers=self.num_layers)

        # generation
        self.gen_state_fc1 = nn.Linear(3, self.weight_dim)
        self.gen_state_relu = nn.LeakyReLU(negative_slope=0.1)
        self.gen_state_lstm1 = nn.LSTM(self.weight_dim, self.weight_dim, batch_first=True, num_layers=self.num_layers)
        self.gen_state_lstm2 = nn.LSTM(self.weight_dim * 2, self.weight_dim * 2, batch_first=True, num_layers=self.num_layers)
        self.gen_state_fc2 = nn.Linear(self.weight_dim * 2, self.num_mixtures * 6 + 1)

        self.term_fc1 = nn.Linear(self.weight_dim * 2, self.weight_dim)
        self.term_relu1 = nn.LeakyReLU(negative_slope=0.1)
        self.term_fc2 = nn.Linear(self.weight_dim, self.weight_dim)
        self.term_relu2 = nn.LeakyReLU(negative_slope=0.1)
        self.term_fc3 = nn.Linear(self.weight_dim, 1)
        self.term_sigmoid = nn.Sigmoid()

        self.mdn_sigmoid = nn.Sigmoid()
        self.mdn_tanh = nn.Tanh()
        self.mdn_softmax = nn.Softmax(dim=1)

        self.scale_sd = scale_sd  # how much to scale the standard deviation of the gaussians
        self.clamp_mdn = clamp_mdn  # total percent of the distribution to allow sampling from

        self.mdn_bce_loss = nn.BCEWithLogitsLoss()
        self.term_bce_loss = nn.BCEWithLogitsLoss()
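
    # Overview of forward() below (descriptive comments only, inferred from the code itself):
    # - `inputs` carries three parallel groups of tensors for one batch of writers: sentence-level,
    #   word-level, and segment-level strokes, together with per-sample lengths, termination flags,
    #   and character ids.
    # - Each level is encoded twice: the `inf_state_*` layers read the target strokes to infer
    #   character-conditioned style vectors W_c, and the `gen_state_*` layers read the input strokes
    #   to drive the decoder.
    # - `gen_state_fc2` emits `num_mixtures * 6 + 1` values per timestep: one end-of-stroke logit
    #   followed by (mu1, mu2, sig1, sig2, rho, pi) for each of the 20 components of a bivariate
    #   Gaussian mixture over the two location channels of the stroke targets.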
    def forward(self, inputs):
        [sentence_level_stroke_in, sentence_level_stroke_out, sentence_level_stroke_length, sentence_level_term,
         sentence_level_char, sentence_level_char_length,
         word_level_stroke_in, word_level_stroke_out, word_level_stroke_length, word_level_term,
         word_level_char, word_level_char_length,
         segment_level_stroke_in, segment_level_stroke_out, segment_level_stroke_length, segment_level_term,
         segment_level_char, segment_level_char_length] = inputs

        ALL_sentence_W_consistency_loss = []
        ALL_ORIGINAL_sentence_termination_loss = []
        ALL_ORIGINAL_sentence_loc_reconstruct_loss = []
        ALL_ORIGINAL_sentence_touch_reconstruct_loss = []
        ALL_TYPE_A_sentence_termination_loss = []
        ALL_TYPE_A_sentence_loc_reconstruct_loss = []
        ALL_TYPE_A_sentence_touch_reconstruct_loss = []
        ALL_TYPE_A_sentence_WC_reconstruct_loss = []
        ALL_TYPE_B_sentence_termination_loss = []
        ALL_TYPE_B_sentence_loc_reconstruct_loss = []
        ALL_TYPE_B_sentence_touch_reconstruct_loss = []
        ALL_TYPE_B_sentence_WC_reconstruct_loss = []

        ALL_word_W_consistency_loss = []
        ALL_ORIGINAL_word_termination_loss = []
        ALL_ORIGINAL_word_loc_reconstruct_loss = []
        ALL_ORIGINAL_word_touch_reconstruct_loss = []
        ALL_TYPE_A_word_termination_loss = []
        ALL_TYPE_A_word_loc_reconstruct_loss = []
        ALL_TYPE_A_word_touch_reconstruct_loss = []
        ALL_TYPE_A_word_WC_reconstruct_loss = []
        ALL_TYPE_B_word_termination_loss = []
        ALL_TYPE_B_word_loc_reconstruct_loss = []
        ALL_TYPE_B_word_touch_reconstruct_loss = []
        ALL_TYPE_B_word_WC_reconstruct_loss = []
        ALL_TYPE_C_word_termination_loss = []
        ALL_TYPE_C_word_loc_reconstruct_loss = []
        ALL_TYPE_C_word_touch_reconstruct_loss = []
        ALL_TYPE_C_word_WC_reconstruct_loss = []
        ALL_TYPE_D_word_termination_loss = []
        ALL_TYPE_D_word_loc_reconstruct_loss = []
        ALL_TYPE_D_word_touch_reconstruct_loss = []
        ALL_TYPE_D_word_WC_reconstruct_loss = []
        ALL_word_Wcs_reconstruct_TYPE_A = []
        ALL_word_Wcs_reconstruct_TYPE_B = []
        ALL_word_Wcs_reconstruct_TYPE_C = []
        ALL_word_Wcs_reconstruct_TYPE_D = []

        SUPER_ALL_segment_W_consistency_loss = []
        SUPER_ALL_ORIGINAL_segment_termination_loss = []
        SUPER_ALL_ORIGINAL_segment_loc_reconstruct_loss = []
        SUPER_ALL_ORIGINAL_segment_touch_reconstruct_loss = []
        SUPER_ALL_TYPE_A_segment_termination_loss = []
        SUPER_ALL_TYPE_A_segment_loc_reconstruct_loss = []
        SUPER_ALL_TYPE_A_segment_touch_reconstruct_loss = []
        SUPER_ALL_TYPE_A_segment_WC_reconstruct_loss = []
        SUPER_ALL_TYPE_B_segment_termination_loss = []
        SUPER_ALL_TYPE_B_segment_loc_reconstruct_loss = []
        SUPER_ALL_TYPE_B_segment_touch_reconstruct_loss = []
        SUPER_ALL_TYPE_B_segment_WC_reconstruct_loss = []
        SUPER_ALL_segment_Wcs_reconstruct_TYPE_A = []
        SUPER_ALL_segment_Wcs_reconstruct_TYPE_B = []
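
        # One iteration of the loop below processes a single writer `uid`. At each hierarchy level
        # the target strokes are encoded, character-conditioned style vectors W_c are read off at the
        # split points marked by the term flags, and writer vectors W are recovered through the
        # inverses of the learned character matrices. The strokes are then re-decoded under several
        # conditions (names taken from the branches below):
        #   ORIGINAL - decode from the W_c inferred from the target strokes themselves
        #   TYPE_A   - rebuild W_c from the mean writer vector via the prefix-conditioned character matrices
        #   TYPE_B   - rebuild W_c from the per-character matrices, merged by magic_lstm
        #   TYPE_C/D - rebuild word-level W_c from segment-level W_c (word branch only)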
        # if self.sentence_loss:
        for uid in range(len(sentence_level_stroke_in)):
            if self.sentence_loss:
                user_sentence_level_stroke_in = sentence_level_stroke_in[uid]
                user_sentence_level_stroke_out = sentence_level_stroke_out[uid]
                user_sentence_level_stroke_length = sentence_level_stroke_length[uid]
                user_sentence_level_term = sentence_level_term[uid]
                user_sentence_level_char = sentence_level_char[uid]
                user_sentence_level_char_length = sentence_level_char_length[uid]

                sentence_batch_size = len(user_sentence_level_stroke_in)

                sentence_inf_state_out = self.inf_state_fc1(user_sentence_level_stroke_out)
                sentence_inf_state_out = self.inf_state_relu(sentence_inf_state_out)
                sentence_inf_state_out, (c, h) = self.inf_state_lstm(sentence_inf_state_out)

                sentence_gen_state_out = self.gen_state_fc1(user_sentence_level_stroke_in)
                sentence_gen_state_out = self.gen_state_relu(sentence_gen_state_out)
                sentence_gen_state_out, (c, h) = self.gen_state_lstm1(sentence_gen_state_out)

                sentence_Ws = []
                sentence_Wc_rec_TYPE_ = []
                sentence_SPLITS = []
                sentence_Cs_1 = []
                sentence_unique_char_matrices_1 = []

                for sentence_batch_id in range(sentence_batch_size):
                    curr_seq_len = user_sentence_level_stroke_length[sentence_batch_id][0]
                    curr_char_len = user_sentence_level_char_length[sentence_batch_id][0]

                    char_vector = torch.eye(len(CHARACTERS))[user_sentence_level_char[sentence_batch_id][:curr_char_len]].to(self.device)
                    current_term = user_sentence_level_term[sentence_batch_id][:curr_seq_len].unsqueeze(-1)
                    split_ids = torch.nonzero(current_term)[:, 0]

                    char_vector_1 = self.char_vec_fc_1(char_vector)
                    char_vector_1 = self.char_vec_relu_1(char_vector_1)

                    unique_char_matrices_1 = []
                    for cid in range(len(char_vector)):
                        # Tower 1
                        unique_char_vector_1 = char_vector_1[cid:cid + 1]
                        unique_char_input_1 = unique_char_vector_1.unsqueeze(0)
                        unique_char_out_1, (c, h) = self.char_lstm_1(unique_char_input_1)
                        unique_char_out_1 = unique_char_out_1.squeeze(0)
                        unique_char_out_1 = self.char_vec_fc2_1(unique_char_out_1)
                        unique_char_matrix_1 = unique_char_out_1.view([-1, 1, self.weight_dim, self.weight_dim])
                        unique_char_matrix_1 = unique_char_matrix_1.squeeze(1)
                        unique_char_matrices_1.append(unique_char_matrix_1)

                    # Tower 1
                    char_out_1 = char_vector_1.unsqueeze(0)
                    char_out_1, (c, h) = self.char_lstm_1(char_out_1)
                    char_out_1 = char_out_1.squeeze(0)
                    char_out_1 = self.char_vec_fc2_1(char_out_1)
                    char_matrix_1 = char_out_1.view([-1, 1, self.weight_dim, self.weight_dim])
                    char_matrix_1 = char_matrix_1.squeeze(1)
                    char_matrix_inv_1 = torch.inverse(char_matrix_1)

                    W_c_t = sentence_inf_state_out[sentence_batch_id][:curr_seq_len]
                    W_c = torch.stack([W_c_t[i] for i in split_ids])

                    # W = torch.bmm(char_matrix_inv, W_c.unsqueeze(2)).squeeze(-1)
                    # C1 C2 C3 W = Wc  =>  W = C3^-1 C2^-1 C1^-1 Wc
                    W = torch.bmm(char_matrix_inv_1, W_c.unsqueeze(2)).squeeze(-1)

                    sentence_Ws.append(W)
                    sentence_Wc_rec_TYPE_.append(W_c)
                    sentence_Cs_1.append(char_matrix_1)
                    sentence_SPLITS.append(split_ids)
                    sentence_unique_char_matrices_1.append(unique_char_matrices_1)

                sentence_Ws_stacked = torch.cat(sentence_Ws, 0)
                sentence_Ws_reshaped = sentence_Ws_stacked.view([-1, self.weight_dim])
                sentence_W_mean = sentence_Ws_reshaped.mean(0)
                sentence_W_mean_repeat = sentence_W_mean.repeat(sentence_Ws_reshaped.size(0), 1)
                sentence_Ws_consistency_loss = torch.mean(torch.mean(torch.mul(sentence_W_mean_repeat - sentence_Ws_reshaped, sentence_W_mean_repeat - sentence_Ws_reshaped), -1))
                ALL_sentence_W_consistency_loss.append(sentence_Ws_consistency_loss)

                ORIGINAL_sentence_termination_loss = []
                ORIGINAL_sentence_loc_reconstruct_loss = []
                ORIGINAL_sentence_touch_reconstruct_loss = []
                TYPE_A_sentence_termination_loss = []
                TYPE_A_sentence_loc_reconstruct_loss = []
                TYPE_A_sentence_touch_reconstruct_loss = []
                TYPE_B_sentence_termination_loss = []
                TYPE_B_sentence_loc_reconstruct_loss = []
                TYPE_B_sentence_touch_reconstruct_loss = []
                sentence_Wcs_reconstruct_TYPE_A = []
                sentence_Wcs_reconstruct_TYPE_B = []

                for sentence_batch_id in range(sentence_batch_size):
                    sentence_level_gen_encoded = sentence_gen_state_out[sentence_batch_id][:user_sentence_level_stroke_length[sentence_batch_id][0]]
                    sentence_level_target_eos = user_sentence_level_stroke_out[sentence_batch_id][:user_sentence_level_stroke_length[sentence_batch_id][0]][:, 2]
                    sentence_level_target_x = user_sentence_level_stroke_out[sentence_batch_id][:user_sentence_level_stroke_length[sentence_batch_id][0]][:, 0:1]
                    sentence_level_target_y = user_sentence_level_stroke_out[sentence_batch_id][:user_sentence_level_stroke_length[sentence_batch_id][0]][:, 1:2]
                    sentence_level_target_term = user_sentence_level_term[sentence_batch_id][:user_sentence_level_stroke_length[sentence_batch_id][0]]

                    # ORIGINAL
                    if self.ORIGINAL:
                        sentence_W_lstm_in_ORIGINAL = []
                        curr_id = 0
                        for i in range(user_sentence_level_stroke_length[sentence_batch_id][0]):
                            sentence_W_lstm_in_ORIGINAL.append(sentence_Wc_rec_TYPE_[sentence_batch_id][curr_id])
                            if i in sentence_SPLITS[sentence_batch_id]:
                                curr_id += 1
                        sentence_W_lstm_in_ORIGINAL = torch.stack(sentence_W_lstm_in_ORIGINAL)
                        sentence_Wc_t_ORIGINAL = sentence_W_lstm_in_ORIGINAL

                        sentence_gen_lstm2_in_ORIGINAL = torch.cat([sentence_level_gen_encoded, sentence_Wc_t_ORIGINAL], -1)
                        sentence_gen_lstm2_in_ORIGINAL = sentence_gen_lstm2_in_ORIGINAL.unsqueeze(0)
                        sentence_gen_out_ORIGINAL, (c, h) = self.gen_state_lstm2(sentence_gen_lstm2_in_ORIGINAL)
                        sentence_gen_out_ORIGINAL = sentence_gen_out_ORIGINAL.squeeze(0)

                        mdn_out_ORIGINAL = self.gen_state_fc2(sentence_gen_out_ORIGINAL)
                        eos_ORIGINAL = mdn_out_ORIGINAL[:, 0:1]
                        [mu1_ORIGINAL, mu2_ORIGINAL, sig1_ORIGINAL, sig2_ORIGINAL, rho_ORIGINAL, pi_ORIGINAL] = torch.split(mdn_out_ORIGINAL[:, 1:], self.num_mixtures, 1)
                        sig1_ORIGINAL = sig1_ORIGINAL.exp() + 1e-3
                        sig2_ORIGINAL = sig2_ORIGINAL.exp() + 1e-3
                        rho_ORIGINAL = self.mdn_tanh(rho_ORIGINAL)
                        pi_ORIGINAL = self.mdn_softmax(pi_ORIGINAL)

                        term_out_ORIGINAL = self.term_fc1(sentence_gen_out_ORIGINAL)
                        term_out_ORIGINAL = self.term_relu1(term_out_ORIGINAL)
                        term_out_ORIGINAL = self.term_fc2(term_out_ORIGINAL)
                        term_out_ORIGINAL = self.term_relu2(term_out_ORIGINAL)
                        term_out_ORIGINAL = self.term_fc3(term_out_ORIGINAL)
                        term_pred_ORIGINAL = self.term_sigmoid(term_out_ORIGINAL)

                        gaussian_ORIGINAL = gaussian_2d(sentence_level_target_x, sentence_level_target_y, mu1_ORIGINAL, mu2_ORIGINAL, sig1_ORIGINAL, sig2_ORIGINAL, rho_ORIGINAL)
                        loss_gaussian_ORIGINAL = -torch.log(torch.sum(pi_ORIGINAL * gaussian_ORIGINAL, dim=1) + 1e-5)

                        ORIGINAL_sentence_term_loss = self.term_bce_loss(term_out_ORIGINAL.squeeze(1), sentence_level_target_term)
                        ORIGINAL_sentence_loc_loss = torch.mean(loss_gaussian_ORIGINAL)
                        ORIGINAL_sentence_touch_loss = self.mdn_bce_loss(eos_ORIGINAL.squeeze(1), sentence_level_target_eos)

                        ORIGINAL_sentence_termination_loss.append(ORIGINAL_sentence_term_loss)
                        ORIGINAL_sentence_loc_reconstruct_loss.append(ORIGINAL_sentence_loc_loss)
                        ORIGINAL_sentence_touch_reconstruct_loss.append(ORIGINAL_sentence_touch_loss)

                    # TYPE A
                    if self.TYPE_A:
                        sentence_C1 = sentence_Cs_1[sentence_batch_id]
                        # sentence_Wc_rec_TYPE_A = torch.bmm(sentence_Cs[sentence_batch_id], sentence_W_mean.repeat(sentence_Cs[sentence_batch_id].size(0), 1).unsqueeze(2)).squeeze(-1)
                        sentence_Wc_rec_TYPE_A = torch.bmm(sentence_C1, sentence_W_mean.repeat(sentence_C1.size(0), 1).unsqueeze(2)).squeeze(-1)
                        sentence_Wcs_reconstruct_TYPE_A.append(sentence_Wc_rec_TYPE_A)

                        sentence_W_lstm_in_TYPE_A = []
                        curr_id = 0
                        for i in range(user_sentence_level_stroke_length[sentence_batch_id][0]):
                            sentence_W_lstm_in_TYPE_A.append(sentence_Wc_rec_TYPE_A[curr_id])
                            if i in sentence_SPLITS[sentence_batch_id]:
                                curr_id += 1
                        sentence_Wc_t_rec_TYPE_A = torch.stack(sentence_W_lstm_in_TYPE_A)

                        sentence_gen_lstm2_in_TYPE_A = torch.cat([sentence_level_gen_encoded, sentence_Wc_t_rec_TYPE_A], -1)
                        sentence_gen_lstm2_in_TYPE_A = sentence_gen_lstm2_in_TYPE_A.unsqueeze(0)
                        sentence_gen_out_TYPE_A, (c, h) = self.gen_state_lstm2(sentence_gen_lstm2_in_TYPE_A)
                        sentence_gen_out_TYPE_A = sentence_gen_out_TYPE_A.squeeze(0)

                        mdn_out_TYPE_A = self.gen_state_fc2(sentence_gen_out_TYPE_A)
                        eos_TYPE_A = mdn_out_TYPE_A[:, 0:1]
                        [mu1_TYPE_A, mu2_TYPE_A, sig1_TYPE_A, sig2_TYPE_A, rho_TYPE_A, pi_TYPE_A] = torch.split(mdn_out_TYPE_A[:, 1:], self.num_mixtures, 1)
                        sig1_TYPE_A = sig1_TYPE_A.exp() + 1e-3
                        sig2_TYPE_A = sig2_TYPE_A.exp() + 1e-3
                        rho_TYPE_A = self.mdn_tanh(rho_TYPE_A)
                        pi_TYPE_A = self.mdn_softmax(pi_TYPE_A)

                        term_out_TYPE_A = self.term_fc1(sentence_gen_out_TYPE_A)
                        term_out_TYPE_A = self.term_relu1(term_out_TYPE_A)
                        term_out_TYPE_A = self.term_fc2(term_out_TYPE_A)
                        term_out_TYPE_A = self.term_relu2(term_out_TYPE_A)
                        term_out_TYPE_A = self.term_fc3(term_out_TYPE_A)
                        term_pred_TYPE_A = self.term_sigmoid(term_out_TYPE_A)

                        gaussian_TYPE_A = gaussian_2d(sentence_level_target_x, sentence_level_target_y, mu1_TYPE_A, mu2_TYPE_A, sig1_TYPE_A, sig2_TYPE_A, rho_TYPE_A)
                        loss_gaussian_TYPE_A = -torch.log(torch.sum(pi_TYPE_A * gaussian_TYPE_A, dim=1) + 1e-5)

                        TYPE_A_sentence_term_loss = self.term_bce_loss(term_out_TYPE_A.squeeze(1), sentence_level_target_term)
                        TYPE_A_sentence_loc_loss = torch.mean(loss_gaussian_TYPE_A)
                        TYPE_A_sentence_touch_loss = self.mdn_bce_loss(eos_TYPE_A.squeeze(1), sentence_level_target_eos)

                        TYPE_A_sentence_termination_loss.append(TYPE_A_sentence_term_loss)
                        TYPE_A_sentence_loc_reconstruct_loss.append(TYPE_A_sentence_loc_loss)
                        TYPE_A_sentence_touch_reconstruct_loss.append(TYPE_A_sentence_touch_loss)

                    # TYPE B
                    if self.TYPE_B:
                        unique_char_matrix_1 = sentence_unique_char_matrices_1[sentence_batch_id]
                        unique_char_matrices_1 = torch.stack(unique_char_matrix_1)
                        unique_char_matrices_1 = unique_char_matrices_1.squeeze(1)

                        # sentence_W_c_TYPE_B_RAW = torch.bmm(unique_char_matrices, sentence_W_mean.repeat(unique_char_matrices.size(0), 1).unsqueeze(2)).squeeze(-1)
                        sentence_W_c_TYPE_B_RAW = torch.bmm(unique_char_matrices_1, sentence_W_mean.repeat(unique_char_matrices_1.size(0), 1).unsqueeze(2)).squeeze(-1)
                        sentence_W_c_TYPE_B_RAW = sentence_W_c_TYPE_B_RAW.unsqueeze(0)
                        sentence_Wc_rec_TYPE_B, (c, h) = self.magic_lstm(sentence_W_c_TYPE_B_RAW)
                        sentence_Wc_rec_TYPE_B = sentence_Wc_rec_TYPE_B.squeeze(0)
                        sentence_Wcs_reconstruct_TYPE_B.append(sentence_Wc_rec_TYPE_B)

                        sentence_W_lstm_in_TYPE_B = []
                        curr_id = 0
                        for i in range(user_sentence_level_stroke_length[sentence_batch_id][0]):
                            sentence_W_lstm_in_TYPE_B.append(sentence_Wc_rec_TYPE_B[curr_id])
                            if i in sentence_SPLITS[sentence_batch_id]:
                                curr_id += 1
                        sentence_Wc_t_rec_TYPE_B = torch.stack(sentence_W_lstm_in_TYPE_B)

                        sentence_gen_lstm2_in_TYPE_B = torch.cat([sentence_level_gen_encoded, sentence_Wc_t_rec_TYPE_B], -1)
                        sentence_gen_lstm2_in_TYPE_B = sentence_gen_lstm2_in_TYPE_B.unsqueeze(0)
                        sentence_gen_out_TYPE_B, (c, h) = self.gen_state_lstm2(sentence_gen_lstm2_in_TYPE_B)
                        sentence_gen_out_TYPE_B = sentence_gen_out_TYPE_B.squeeze(0)

                        mdn_out_TYPE_B = self.gen_state_fc2(sentence_gen_out_TYPE_B)
                        eos_TYPE_B = mdn_out_TYPE_B[:, 0:1]
                        [mu1_TYPE_B, mu2_TYPE_B, sig1_TYPE_B, sig2_TYPE_B, rho_TYPE_B, pi_TYPE_B] = torch.split(mdn_out_TYPE_B[:, 1:], self.num_mixtures, 1)
                        sig1_TYPE_B = sig1_TYPE_B.exp() + 1e-3
                        sig2_TYPE_B = sig2_TYPE_B.exp() + 1e-3
                        rho_TYPE_B = self.mdn_tanh(rho_TYPE_B)
                        pi_TYPE_B = self.mdn_softmax(pi_TYPE_B)

                        term_out_TYPE_B = self.term_fc1(sentence_gen_out_TYPE_B)
                        term_out_TYPE_B = self.term_relu1(term_out_TYPE_B)
                        term_out_TYPE_B = self.term_fc2(term_out_TYPE_B)
                        term_out_TYPE_B = self.term_relu2(term_out_TYPE_B)
                        term_out_TYPE_B = self.term_fc3(term_out_TYPE_B)
                        term_pred_TYPE_B = self.term_sigmoid(term_out_TYPE_B)

                        gaussian_TYPE_B = gaussian_2d(sentence_level_target_x, sentence_level_target_y, mu1_TYPE_B, mu2_TYPE_B, sig1_TYPE_B, sig2_TYPE_B, rho_TYPE_B)
                        loss_gaussian_TYPE_B = -torch.log(torch.sum(pi_TYPE_B * gaussian_TYPE_B, dim=1) + 1e-5)

                        TYPE_B_sentence_term_loss = self.term_bce_loss(term_out_TYPE_B.squeeze(1), sentence_level_target_term)
                        TYPE_B_sentence_loc_loss = torch.mean(loss_gaussian_TYPE_B)
                        TYPE_B_sentence_touch_loss = self.mdn_bce_loss(eos_TYPE_B.squeeze(1), sentence_level_target_eos)

                        TYPE_B_sentence_termination_loss.append(TYPE_B_sentence_term_loss)
                        TYPE_B_sentence_loc_reconstruct_loss.append(TYPE_B_sentence_loc_loss)
                        TYPE_B_sentence_touch_reconstruct_loss.append(TYPE_B_sentence_touch_loss)

                if self.ORIGINAL:
                    ALL_ORIGINAL_sentence_termination_loss.append(torch.mean(torch.stack(ORIGINAL_sentence_termination_loss)))
                    ALL_ORIGINAL_sentence_loc_reconstruct_loss.append(torch.mean(torch.stack(ORIGINAL_sentence_loc_reconstruct_loss)))
                    ALL_ORIGINAL_sentence_touch_reconstruct_loss.append(torch.mean(torch.stack(ORIGINAL_sentence_touch_reconstruct_loss)))

                if self.TYPE_A:
                    ALL_TYPE_A_sentence_termination_loss.append(torch.mean(torch.stack(TYPE_A_sentence_termination_loss)))
                    ALL_TYPE_A_sentence_loc_reconstruct_loss.append(torch.mean(torch.stack(TYPE_A_sentence_loc_reconstruct_loss)))
                    ALL_TYPE_A_sentence_touch_reconstruct_loss.append(torch.mean(torch.stack(TYPE_A_sentence_touch_reconstruct_loss)))
                    if self.REC:
                        TYPE_A_sentence_WC_reconstruct_loss = []
                        for sentence_batch_id in range(len(sentence_Wc_rec_TYPE_)):
                            sentence_Wc_ORIGINAL = sentence_Wc_rec_TYPE_[sentence_batch_id]
                            sentence_Wc_TYPE_A = sentence_Wcs_reconstruct_TYPE_A[sentence_batch_id]
                            sentence_WC_reconstruct_loss_TYPE_A = torch.mean(torch.mean(torch.mul(sentence_Wc_ORIGINAL - sentence_Wc_TYPE_A, sentence_Wc_ORIGINAL - sentence_Wc_TYPE_A), -1))
                            TYPE_A_sentence_WC_reconstruct_loss.append(sentence_WC_reconstruct_loss_TYPE_A)
                        ALL_TYPE_A_sentence_WC_reconstruct_loss.append(torch.mean(torch.stack(TYPE_A_sentence_WC_reconstruct_loss)))

                if self.TYPE_B:
                    ALL_TYPE_B_sentence_termination_loss.append(torch.mean(torch.stack(TYPE_B_sentence_termination_loss)))
                    ALL_TYPE_B_sentence_loc_reconstruct_loss.append(torch.mean(torch.stack(TYPE_B_sentence_loc_reconstruct_loss)))
                    ALL_TYPE_B_sentence_touch_reconstruct_loss.append(torch.mean(torch.stack(TYPE_B_sentence_touch_reconstruct_loss)))
                    if self.REC:
                        TYPE_B_sentence_WC_reconstruct_loss = []
                        for sentence_batch_id in range(len(sentence_Wc_rec_TYPE_)):
                            sentence_Wc_ORIGINAL = sentence_Wc_rec_TYPE_[sentence_batch_id]
                            sentence_Wc_TYPE_B = sentence_Wcs_reconstruct_TYPE_B[sentence_batch_id]
                            sentence_WC_reconstruct_loss_TYPE_B = torch.mean(torch.mean(torch.mul(sentence_Wc_ORIGINAL - sentence_Wc_TYPE_B, sentence_Wc_ORIGINAL - sentence_Wc_TYPE_B), -1))
                            TYPE_B_sentence_WC_reconstruct_loss.append(sentence_WC_reconstruct_loss_TYPE_B)
                        ALL_TYPE_B_sentence_WC_reconstruct_loss.append(torch.mean(torch.stack(TYPE_B_sentence_WC_reconstruct_loss)))
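
            # The word-level branch below mirrors the sentence-level computation above, and in
            # addition caches character-prefix style vectors (W_C_ORIGINALS / W_C_SEGMENTS) and the
            # per-segment W_c needed by the TYPE_C and TYPE_D reconstructions further down.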
            if self.word_loss:
                user_word_level_stroke_in = word_level_stroke_in[uid]
                user_word_level_stroke_out = word_level_stroke_out[uid]
                user_word_level_stroke_length = word_level_stroke_length[uid]
                user_word_level_term = word_level_term[uid]
                user_word_level_char = word_level_char[uid]
                user_word_level_char_length = word_level_char_length[uid]

                word_batch_size = len(user_word_level_stroke_in)

                word_inf_state_out = self.inf_state_fc1(user_word_level_stroke_out)
                word_inf_state_out = self.inf_state_relu(word_inf_state_out)
                word_inf_state_out, (c, h) = self.inf_state_lstm(word_inf_state_out)

                word_gen_state_out = self.gen_state_fc1(user_word_level_stroke_in)
                word_gen_state_out = self.gen_state_relu(word_gen_state_out)
                word_gen_state_out, (c, h) = self.gen_state_lstm1(word_gen_state_out)

                word_Ws = []
                word_Wc_rec_ORIGINAL = []
                word_SPLITS = []
                word_Cs_1 = []
                word_unique_char_matrices_1 = []
                W_C_ORIGINALS = []

                for word_batch_id in range(word_batch_size):
                    curr_seq_len = user_word_level_stroke_length[word_batch_id][0]
                    curr_char_len = user_word_level_char_length[word_batch_id][0]

                    char_vector = torch.eye(len(CHARACTERS))[user_word_level_char[word_batch_id][:curr_char_len]].to(self.device)
                    current_term = user_word_level_term[word_batch_id][:curr_seq_len].unsqueeze(-1)
                    split_ids = torch.nonzero(current_term)[:, 0]

                    char_vector_1 = self.char_vec_fc_1(char_vector)
                    char_vector_1 = self.char_vec_relu_1(char_vector_1)

                    unique_char_matrices_1 = []
                    for cid in range(len(char_vector)):
                        # Tower 1
                        unique_char_vector_1 = char_vector_1[cid:cid + 1]
                        unique_char_input_1 = unique_char_vector_1.unsqueeze(0)
                        unique_char_out_1, (c, h) = self.char_lstm_1(unique_char_input_1)
                        unique_char_out_1 = unique_char_out_1.squeeze(0)
                        unique_char_out_1 = self.char_vec_fc2_1(unique_char_out_1)
                        unique_char_matrix_1 = unique_char_out_1.view([-1, 1, self.weight_dim, self.weight_dim])
                        unique_char_matrix_1 = unique_char_matrix_1.squeeze(1)
                        unique_char_matrices_1.append(unique_char_matrix_1)

                    # Tower 1
                    char_out_1 = char_vector_1.unsqueeze(0)
                    char_out_1, (c, h) = self.char_lstm_1(char_out_1)
                    char_out_1 = char_out_1.squeeze(0)
                    char_out_1 = self.char_vec_fc2_1(char_out_1)
                    char_matrix_1 = char_out_1.view([-1, 1, self.weight_dim, self.weight_dim])
                    char_matrix_1 = char_matrix_1.squeeze(1)
                    char_matrix_inv_1 = torch.inverse(char_matrix_1)

                    W_c_t = word_inf_state_out[word_batch_id][:curr_seq_len]
                    W_c = torch.stack([W_c_t[i] for i in split_ids])

                    W_C_ORIGINAL = {}
                    for i in range(curr_char_len):
                        sub_s = "".join(CHARACTERS[i] for i in user_word_level_char[word_batch_id][:i + 1])
                        W_C_ORIGINAL[sub_s] = [W_c[i]]
                    W_C_ORIGINALS.append(W_C_ORIGINAL)

                    # W = torch.bmm(char_matrix_inv, W_c.unsqueeze(2)).squeeze(-1)
                    W = torch.bmm(char_matrix_inv_1, W_c.unsqueeze(2)).squeeze(-1)

                    word_Ws.append(W)
                    word_Wc_rec_ORIGINAL.append(W_c)
                    word_SPLITS.append(split_ids)
                    # word_Cs.append(char_matrix)
                    # word_unique_char_matrices.append(unique_char_matrices)
                    word_Cs_1.append(char_matrix_1)
                    word_unique_char_matrices_1.append(unique_char_matrices_1)

                word_Ws_stacked = torch.cat(word_Ws, 0)
                word_Ws_reshaped = word_Ws_stacked.view([-1, self.weight_dim])
                word_W_mean = word_Ws_reshaped.mean(0)
                word_Ws_reshaped_mean_repeat = word_W_mean.repeat(word_Ws_reshaped.size(0), 1)
                word_Ws_consistency_loss = torch.mean(torch.mean(torch.mul(word_Ws_reshaped_mean_repeat - word_Ws_reshaped, word_Ws_reshaped_mean_repeat - word_Ws_reshaped), -1))
                ALL_word_W_consistency_loss.append(word_Ws_consistency_loss)

                # word
                ORIGINAL_word_termination_loss = []
                ORIGINAL_word_loc_reconstruct_loss = []
                ORIGINAL_word_touch_reconstruct_loss = []
                TYPE_A_word_termination_loss = []
                TYPE_A_word_loc_reconstruct_loss = []
                TYPE_A_word_touch_reconstruct_loss = []
                TYPE_B_word_termination_loss = []
                TYPE_B_word_loc_reconstruct_loss = []
                TYPE_B_word_touch_reconstruct_loss = []
                TYPE_C_word_termination_loss = []
                TYPE_C_word_loc_reconstruct_loss = []
                TYPE_C_word_touch_reconstruct_loss = []
                TYPE_D_word_termination_loss = []
                TYPE_D_word_loc_reconstruct_loss = []
                TYPE_D_word_touch_reconstruct_loss = []
                word_Wcs_reconstruct_TYPE_A = []
                word_Wcs_reconstruct_TYPE_B = []
                word_Wcs_reconstruct_TYPE_C = []
                word_Wcs_reconstruct_TYPE_D = []

                # segment
                ALL_segment_W_consistency_loss = []
                ALL_ORIGINAL_segment_termination_loss = []
                ALL_ORIGINAL_segment_loc_reconstruct_loss = []
                ALL_ORIGINAL_segment_touch_reconstruct_loss = []
                ALL_TYPE_A_segment_termination_loss = []
                ALL_TYPE_A_segment_loc_reconstruct_loss = []
                ALL_TYPE_A_segment_touch_reconstruct_loss = []
                ALL_TYPE_A_segment_WC_reconstruct_loss = []
                ALL_TYPE_B_segment_termination_loss = []
                ALL_TYPE_B_segment_loc_reconstruct_loss = []
                ALL_TYPE_B_segment_touch_reconstruct_loss = []
                ALL_TYPE_B_segment_WC_reconstruct_loss = []
                ALL_segment_Wcs_reconstruct_TYPE_A = []
                ALL_segment_Wcs_reconstruct_TYPE_B = []
                W_C_SEGMENTS = []
                W_C_UNIQUES = []
                for word_batch_id in range(word_batch_size):
                    word_level_gen_encoded = word_gen_state_out[word_batch_id][:user_word_level_stroke_length[word_batch_id][0]]
                    word_level_target_eos = user_word_level_stroke_out[word_batch_id][:user_word_level_stroke_length[word_batch_id][0]][:, 2]
                    word_level_target_x = user_word_level_stroke_out[word_batch_id][:user_word_level_stroke_length[word_batch_id][0]][:, 0:1]
                    word_level_target_y = user_word_level_stroke_out[word_batch_id][:user_word_level_stroke_length[word_batch_id][0]][:, 1:2]
                    word_level_target_term = user_word_level_term[word_batch_id][:user_word_level_stroke_length[word_batch_id][0]]

                    # ORIGINAL
                    if self.ORIGINAL:
                        word_W_lstm_in_ORIGINAL = []
                        curr_id = 0
                        for i in range(user_word_level_stroke_length[word_batch_id][0]):
                            word_W_lstm_in_ORIGINAL.append(word_Wc_rec_ORIGINAL[word_batch_id][curr_id])
                            if i in word_SPLITS[word_batch_id]:
                                curr_id += 1
                        word_W_lstm_in_ORIGINAL = torch.stack(word_W_lstm_in_ORIGINAL)
                        word_Wc_t_ORIGINAL = word_W_lstm_in_ORIGINAL

                        word_gen_lstm2_in_ORIGINAL = torch.cat([word_level_gen_encoded, word_Wc_t_ORIGINAL], -1)
                        word_gen_lstm2_in_ORIGINAL = word_gen_lstm2_in_ORIGINAL.unsqueeze(0)
                        word_gen_out_ORIGINAL, (c, h) = self.gen_state_lstm2(word_gen_lstm2_in_ORIGINAL)
                        word_gen_out_ORIGINAL = word_gen_out_ORIGINAL.squeeze(0)

                        mdn_out_ORIGINAL = self.gen_state_fc2(word_gen_out_ORIGINAL)
                        eos_ORIGINAL = mdn_out_ORIGINAL[:, 0:1]
                        [mu1_ORIGINAL, mu2_ORIGINAL, sig1_ORIGINAL, sig2_ORIGINAL, rho_ORIGINAL, pi_ORIGINAL] = torch.split(mdn_out_ORIGINAL[:, 1:], self.num_mixtures, 1)
                        sig1_ORIGINAL = sig1_ORIGINAL.exp() + 1e-3
                        sig2_ORIGINAL = sig2_ORIGINAL.exp() + 1e-3
                        rho_ORIGINAL = self.mdn_tanh(rho_ORIGINAL)
                        pi_ORIGINAL = self.mdn_softmax(pi_ORIGINAL)

                        term_out_ORIGINAL = self.term_fc1(word_gen_out_ORIGINAL)
                        term_out_ORIGINAL = self.term_relu1(term_out_ORIGINAL)
                        term_out_ORIGINAL = self.term_fc2(term_out_ORIGINAL)
                        term_out_ORIGINAL = self.term_relu2(term_out_ORIGINAL)
                        term_out_ORIGINAL = self.term_fc3(term_out_ORIGINAL)
                        term_pred_ORIGINAL = self.term_sigmoid(term_out_ORIGINAL)

                        gaussian_ORIGINAL = gaussian_2d(word_level_target_x, word_level_target_y, mu1_ORIGINAL, mu2_ORIGINAL, sig1_ORIGINAL, sig2_ORIGINAL, rho_ORIGINAL)
                        loss_gaussian_ORIGINAL = -torch.log(torch.sum(pi_ORIGINAL * gaussian_ORIGINAL, dim=1) + 1e-5)

                        ORIGINAL_word_term_loss = self.term_bce_loss(term_out_ORIGINAL.squeeze(1), word_level_target_term)
                        ORIGINAL_word_loc_loss = torch.mean(loss_gaussian_ORIGINAL)
                        ORIGINAL_word_touch_loss = self.mdn_bce_loss(eos_ORIGINAL.squeeze(1), word_level_target_eos)

                        ORIGINAL_word_termination_loss.append(ORIGINAL_word_term_loss)
                        ORIGINAL_word_loc_reconstruct_loss.append(ORIGINAL_word_loc_loss)
                        ORIGINAL_word_touch_reconstruct_loss.append(ORIGINAL_word_touch_loss)

                    # TYPE A
                    if self.TYPE_A:
                        word_C1 = word_Cs_1[word_batch_id]
                        word_Wc_rec_TYPE_A = torch.bmm(word_C1, word_W_mean.repeat(word_C1.size(0), 1).unsqueeze(2)).squeeze(-1)
                        word_Wcs_reconstruct_TYPE_A.append(word_Wc_rec_TYPE_A)

                        word_W_lstm_in_TYPE_A = []
                        curr_id = 0
                        for i in range(user_word_level_stroke_length[word_batch_id][0]):
                            word_W_lstm_in_TYPE_A.append(word_Wc_rec_TYPE_A[curr_id])
                            if i in word_SPLITS[word_batch_id]:
                                curr_id += 1
                        word_Wc_t_rec_TYPE_A = torch.stack(word_W_lstm_in_TYPE_A)

                        word_gen_lstm2_in_TYPE_A = torch.cat([word_level_gen_encoded, word_Wc_t_rec_TYPE_A], -1)
                        word_gen_lstm2_in_TYPE_A = word_gen_lstm2_in_TYPE_A.unsqueeze(0)
                        word_gen_out_TYPE_A, (c, h) = self.gen_state_lstm2(word_gen_lstm2_in_TYPE_A)
                        word_gen_out_TYPE_A = word_gen_out_TYPE_A.squeeze(0)

                        mdn_out_TYPE_A = self.gen_state_fc2(word_gen_out_TYPE_A)
                        eos_TYPE_A = mdn_out_TYPE_A[:, 0:1]
                        [mu1_TYPE_A, mu2_TYPE_A, sig1_TYPE_A, sig2_TYPE_A, rho_TYPE_A, pi_TYPE_A] = torch.split(mdn_out_TYPE_A[:, 1:], self.num_mixtures, 1)
                        sig1_TYPE_A = sig1_TYPE_A.exp() + 1e-3
                        sig2_TYPE_A = sig2_TYPE_A.exp() + 1e-3
                        rho_TYPE_A = self.mdn_tanh(rho_TYPE_A)
                        pi_TYPE_A = self.mdn_softmax(pi_TYPE_A)

                        term_out_TYPE_A = self.term_fc1(word_gen_out_TYPE_A)
                        term_out_TYPE_A = self.term_relu1(term_out_TYPE_A)
                        term_out_TYPE_A = self.term_fc2(term_out_TYPE_A)
                        term_out_TYPE_A = self.term_relu2(term_out_TYPE_A)
                        term_out_TYPE_A = self.term_fc3(term_out_TYPE_A)
                        term_pred_TYPE_A = self.term_sigmoid(term_out_TYPE_A)

                        gaussian_TYPE_A = gaussian_2d(word_level_target_x, word_level_target_y, mu1_TYPE_A, mu2_TYPE_A, sig1_TYPE_A, sig2_TYPE_A, rho_TYPE_A)
                        loss_gaussian_TYPE_A = -torch.log(torch.sum(pi_TYPE_A * gaussian_TYPE_A, dim=1) + 1e-5)

                        TYPE_A_word_term_loss = self.term_bce_loss(term_out_TYPE_A.squeeze(1), word_level_target_term)
                        TYPE_A_word_loc_loss = torch.mean(loss_gaussian_TYPE_A)
                        TYPE_A_word_touch_loss = self.mdn_bce_loss(eos_TYPE_A.squeeze(1), word_level_target_eos)

                        TYPE_A_word_termination_loss.append(TYPE_A_word_term_loss)
                        TYPE_A_word_loc_reconstruct_loss.append(TYPE_A_word_loc_loss)
                        TYPE_A_word_touch_reconstruct_loss.append(TYPE_A_word_touch_loss)

                    # TYPE B
                    if self.TYPE_B:
                        unique_char_matrix_1 = word_unique_char_matrices_1[word_batch_id]
                        unique_char_matrices_1 = torch.stack(unique_char_matrix_1)
                        unique_char_matrices_1 = unique_char_matrices_1.squeeze(1)

                        # word_W_c_TYPE_B_RAW = torch.bmm(unique_char_matrices, word_W_mean.repeat(unique_char_matrices.size(0), 1).unsqueeze(2)).squeeze(-1)
                        word_W_c_TYPE_B_RAW = torch.bmm(unique_char_matrices_1, word_W_mean.repeat(unique_char_matrices_1.size(0), 1).unsqueeze(2)).squeeze(-1)
                        word_W_c_TYPE_B_RAW = word_W_c_TYPE_B_RAW.unsqueeze(0)
                        word_Wc_rec_TYPE_B, (c, h) = self.magic_lstm(word_W_c_TYPE_B_RAW)
                        word_Wc_rec_TYPE_B = word_Wc_rec_TYPE_B.squeeze(0)
                        word_Wcs_reconstruct_TYPE_B.append(word_Wc_rec_TYPE_B)

                        word_W_lstm_in_TYPE_B = []
                        curr_id = 0
                        for i in range(user_word_level_stroke_length[word_batch_id][0]):
                            word_W_lstm_in_TYPE_B.append(word_Wc_rec_TYPE_B[curr_id])
                            if i in word_SPLITS[word_batch_id]:
                                curr_id += 1
                        word_Wc_t_rec_TYPE_B = torch.stack(word_W_lstm_in_TYPE_B)

                        word_gen_lstm2_in_TYPE_B = torch.cat([word_level_gen_encoded, word_Wc_t_rec_TYPE_B], -1)
                        word_gen_lstm2_in_TYPE_B = word_gen_lstm2_in_TYPE_B.unsqueeze(0)
                        word_gen_out_TYPE_B, (c, h) = self.gen_state_lstm2(word_gen_lstm2_in_TYPE_B)
                        word_gen_out_TYPE_B = word_gen_out_TYPE_B.squeeze(0)

                        mdn_out_TYPE_B = self.gen_state_fc2(word_gen_out_TYPE_B)
                        eos_TYPE_B = mdn_out_TYPE_B[:, 0:1]
                        [mu1_TYPE_B, mu2_TYPE_B, sig1_TYPE_B, sig2_TYPE_B, rho_TYPE_B, pi_TYPE_B] = torch.split(mdn_out_TYPE_B[:, 1:], self.num_mixtures, 1)
                        sig1_TYPE_B = sig1_TYPE_B.exp() + 1e-3
                        sig2_TYPE_B = sig2_TYPE_B.exp() + 1e-3
                        rho_TYPE_B = self.mdn_tanh(rho_TYPE_B)
                        pi_TYPE_B = self.mdn_softmax(pi_TYPE_B)

                        term_out_TYPE_B = self.term_fc1(word_gen_out_TYPE_B)
                        term_out_TYPE_B = self.term_relu1(term_out_TYPE_B)
                        term_out_TYPE_B = self.term_fc2(term_out_TYPE_B)
                        term_out_TYPE_B = self.term_relu2(term_out_TYPE_B)
                        term_out_TYPE_B = self.term_fc3(term_out_TYPE_B)
                        term_pred_TYPE_B = self.term_sigmoid(term_out_TYPE_B)

                        gaussian_TYPE_B = gaussian_2d(word_level_target_x, word_level_target_y, mu1_TYPE_B, mu2_TYPE_B, sig1_TYPE_B, sig2_TYPE_B, rho_TYPE_B)
                        loss_gaussian_TYPE_B = -torch.log(torch.sum(pi_TYPE_B * gaussian_TYPE_B, dim=1) + 1e-5)

                        TYPE_B_word_term_loss = self.term_bce_loss(term_out_TYPE_B.squeeze(1), word_level_target_term)
                        TYPE_B_word_loc_loss = torch.mean(loss_gaussian_TYPE_B)
                        TYPE_B_word_touch_loss = self.mdn_bce_loss(eos_TYPE_B.squeeze(1), word_level_target_eos)

                        TYPE_B_word_termination_loss.append(TYPE_B_word_term_loss)
                        TYPE_B_word_loc_reconstruct_loss.append(TYPE_B_word_loc_loss)
                        TYPE_B_word_touch_reconstruct_loss.append(TYPE_B_word_touch_loss)
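
                    # Segment-level strokes belonging to this word are encoded next with the same
                    # inference/generation stack; their W_c vectors are stored per segment in
                    # segment_Wc_rec_ORIGINAL (used by TYPE_C / TYPE_D below) and keyed by character
                    # prefix in W_C_SEGMENT.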
                    # TYPE C
                    # if self.TYPE_C:
                    user_segment_level_stroke_in = segment_level_stroke_in[uid][word_batch_id]
                    user_segment_level_stroke_out = segment_level_stroke_out[uid][word_batch_id]
                    user_segment_level_stroke_length = segment_level_stroke_length[uid][word_batch_id]
                    user_segment_level_term = segment_level_term[uid][word_batch_id]
                    user_segment_level_char = segment_level_char[uid][word_batch_id]
                    user_segment_level_char_length = segment_level_char_length[uid][word_batch_id]

                    segment_batch_size = len(user_segment_level_stroke_in)

                    segment_inf_state_out = self.inf_state_fc1(user_segment_level_stroke_out)
                    segment_inf_state_out = self.inf_state_relu(segment_inf_state_out)
                    segment_inf_state_out, (c, h) = self.inf_state_lstm(segment_inf_state_out)

                    segment_gen_state_out = self.gen_state_fc1(user_segment_level_stroke_in)
                    segment_gen_state_out = self.gen_state_relu(segment_gen_state_out)
                    segment_gen_state_out, (c, h) = self.gen_state_lstm1(segment_gen_state_out)

                    segment_Ws = []
                    segment_Wc_rec_ORIGINAL = []
                    segment_SPLITS = []
                    segment_Cs_1 = []
                    segment_unique_char_matrices_1 = []
                    W_C_SEGMENT = {}

                    for segment_batch_id in range(segment_batch_size):
                        curr_seq_len = user_segment_level_stroke_length[segment_batch_id][0]
                        curr_char_len = user_segment_level_char_length[segment_batch_id][0]

                        char_vector = torch.eye(len(CHARACTERS))[user_segment_level_char[segment_batch_id][:curr_char_len]].to(self.device)
                        current_term = user_segment_level_term[segment_batch_id][:curr_seq_len].unsqueeze(-1)
                        split_ids = torch.nonzero(current_term)[:, 0]

                        char_vector_1 = self.char_vec_fc_1(char_vector)
                        char_vector_1 = self.char_vec_relu_1(char_vector_1)

                        unique_char_matrices_1 = []
                        for cid in range(len(char_vector)):
                            # Tower 1
                            unique_char_vector_1 = char_vector_1[cid:cid + 1]
                            unique_char_input_1 = unique_char_vector_1.unsqueeze(0)
                            unique_char_out_1, (c, h) = self.char_lstm_1(unique_char_input_1)
                            unique_char_out_1 = unique_char_out_1.squeeze(0)
                            unique_char_out_1 = self.char_vec_fc2_1(unique_char_out_1)
                            unique_char_matrix_1 = unique_char_out_1.view([-1, 1, self.weight_dim, self.weight_dim])
                            unique_char_matrix_1 = unique_char_matrix_1.squeeze(1)
                            unique_char_matrices_1.append(unique_char_matrix_1)

                        # Tower 1
                        char_out_1 = char_vector_1.unsqueeze(0)
                        char_out_1, (c, h) = self.char_lstm_1(char_out_1)
                        char_out_1 = char_out_1.squeeze(0)
                        char_out_1 = self.char_vec_fc2_1(char_out_1)
                        char_matrix_1 = char_out_1.view([-1, 1, self.weight_dim, self.weight_dim])
                        char_matrix_1 = char_matrix_1.squeeze(1)
                        char_matrix_inv_1 = torch.inverse(char_matrix_1)

                        W_c_t = segment_inf_state_out[segment_batch_id][:curr_seq_len]
                        W_c = torch.stack([W_c_t[i] for i in split_ids])

                        for i in range(curr_char_len):
                            sub_s = "".join(CHARACTERS[i] for i in user_segment_level_char[segment_batch_id][:i + 1])
                            if sub_s in W_C_SEGMENT:
                                W_C_SEGMENT[sub_s].append(W_c[i])
                            else:
                                W_C_SEGMENT[sub_s] = [W_c[i]]

                        W = torch.bmm(char_matrix_inv_1, W_c.unsqueeze(2)).squeeze(-1)

                        segment_Ws.append(W)
                        segment_Wc_rec_ORIGINAL.append(W_c)
                        segment_SPLITS.append(split_ids)
                        segment_Cs_1.append(char_matrix_1)
                        segment_unique_char_matrices_1.append(unique_char_matrices_1)

                    W_C_SEGMENTS.append(W_C_SEGMENT)
                    if self.segment_loss:
                        segment_Ws_stacked = torch.cat(segment_Ws, 0)
                        segment_Ws_reshaped = segment_Ws_stacked.view([-1, self.weight_dim])
                        segment_W_mean = segment_Ws_reshaped.mean(0)
                        segment_Ws_reshaped_mean_repeat = segment_W_mean.repeat(segment_Ws_reshaped.size(0), 1)
                        segment_Ws_consistency_loss = torch.mean(torch.mean(torch.mul(segment_Ws_reshaped_mean_repeat - segment_Ws_reshaped, segment_Ws_reshaped_mean_repeat - segment_Ws_reshaped), -1))
                        ALL_segment_W_consistency_loss.append(segment_Ws_consistency_loss)

                        ORIGINAL_segment_termination_loss = []
                        ORIGINAL_segment_loc_reconstruct_loss = []
                        ORIGINAL_segment_touch_reconstruct_loss = []
                        TYPE_A_segment_termination_loss = []
                        TYPE_A_segment_loc_reconstruct_loss = []
                        TYPE_A_segment_touch_reconstruct_loss = []
                        TYPE_B_segment_termination_loss = []
                        TYPE_B_segment_loc_reconstruct_loss = []
                        TYPE_B_segment_touch_reconstruct_loss = []
                        segment_Wcs_reconstruct_TYPE_A = []
                        segment_Wcs_reconstruct_TYPE_B = []

                        for segment_batch_id in range(segment_batch_size):
                            segment_level_gen_encoded = segment_gen_state_out[segment_batch_id][:user_segment_level_stroke_length[segment_batch_id][0]]
                            segment_level_target_eos = user_segment_level_stroke_out[segment_batch_id][:user_segment_level_stroke_length[segment_batch_id][0]][:, 2]
                            segment_level_target_x = user_segment_level_stroke_out[segment_batch_id][:user_segment_level_stroke_length[segment_batch_id][0]][:, 0:1]
                            segment_level_target_y = user_segment_level_stroke_out[segment_batch_id][:user_segment_level_stroke_length[segment_batch_id][0]][:, 1:2]
                            segment_level_target_term = user_segment_level_term[segment_batch_id][:user_segment_level_stroke_length[segment_batch_id][0]]

                            if self.ORIGINAL:
                                segment_W_lstm_in_ORIGINAL = []
                                curr_id = 0
                                for i in range(user_segment_level_stroke_length[segment_batch_id][0]):
                                    segment_W_lstm_in_ORIGINAL.append(segment_Wc_rec_ORIGINAL[segment_batch_id][curr_id])
                                    if i in segment_SPLITS[segment_batch_id]:
                                        curr_id += 1
                                segment_W_lstm_in_ORIGINAL = torch.stack(segment_W_lstm_in_ORIGINAL)
                                segment_Wc_t_ORIGINAL = segment_W_lstm_in_ORIGINAL

                                segment_gen_lstm2_in_ORIGINAL = torch.cat([segment_level_gen_encoded, segment_Wc_t_ORIGINAL], -1)
                                segment_gen_lstm2_in_ORIGINAL = segment_gen_lstm2_in_ORIGINAL.unsqueeze(0)
                                segment_gen_out_ORIGINAL, (c, h) = self.gen_state_lstm2(segment_gen_lstm2_in_ORIGINAL)
                                segment_gen_out_ORIGINAL = segment_gen_out_ORIGINAL.squeeze(0)

                                mdn_out_ORIGINAL = self.gen_state_fc2(segment_gen_out_ORIGINAL)
                                eos_ORIGINAL = mdn_out_ORIGINAL[:, 0:1]
                                [mu1_ORIGINAL, mu2_ORIGINAL, sig1_ORIGINAL, sig2_ORIGINAL, rho_ORIGINAL, pi_ORIGINAL] = torch.split(mdn_out_ORIGINAL[:, 1:], self.num_mixtures, 1)
                                sig1_ORIGINAL = sig1_ORIGINAL.exp() + 1e-3
                                sig2_ORIGINAL = sig2_ORIGINAL.exp() + 1e-3
                                rho_ORIGINAL = self.mdn_tanh(rho_ORIGINAL)
                                pi_ORIGINAL = self.mdn_softmax(pi_ORIGINAL)

                                term_out_ORIGINAL = self.term_fc1(segment_gen_out_ORIGINAL)
                                term_out_ORIGINAL = self.term_relu1(term_out_ORIGINAL)
                                term_out_ORIGINAL = self.term_fc2(term_out_ORIGINAL)
                                term_out_ORIGINAL = self.term_relu2(term_out_ORIGINAL)
                                term_out_ORIGINAL = self.term_fc3(term_out_ORIGINAL)
                                term_pred_ORIGINAL = self.term_sigmoid(term_out_ORIGINAL)

                                gaussian_ORIGINAL = gaussian_2d(segment_level_target_x, segment_level_target_y, mu1_ORIGINAL, mu2_ORIGINAL, sig1_ORIGINAL, sig2_ORIGINAL, rho_ORIGINAL)
                                loss_gaussian_ORIGINAL = -torch.log(torch.sum(pi_ORIGINAL * gaussian_ORIGINAL, dim=1) + 1e-5)

                                ORIGINAL_segment_term_loss = self.term_bce_loss(term_out_ORIGINAL.squeeze(1), segment_level_target_term)
                                ORIGINAL_segment_loc_loss = torch.mean(loss_gaussian_ORIGINAL)
                                ORIGINAL_segment_touch_loss = self.mdn_bce_loss(eos_ORIGINAL.squeeze(1), segment_level_target_eos)

                                ORIGINAL_segment_termination_loss.append(ORIGINAL_segment_term_loss)
                                ORIGINAL_segment_loc_reconstruct_loss.append(ORIGINAL_segment_loc_loss)
                                ORIGINAL_segment_touch_reconstruct_loss.append(ORIGINAL_segment_touch_loss)

                            # TYPE A
                            if self.TYPE_A:
                                segment_C1 = segment_Cs_1[segment_batch_id]
                                segment_Wc_rec_TYPE_A = torch.bmm(segment_C1, segment_W_mean.repeat(segment_C1.size(0), 1).unsqueeze(2)).squeeze(-1)
                                segment_Wcs_reconstruct_TYPE_A.append(segment_Wc_rec_TYPE_A)

                                segment_W_lstm_in_TYPE_A = []
                                curr_id = 0
                                for i in range(user_segment_level_stroke_length[segment_batch_id][0]):
                                    segment_W_lstm_in_TYPE_A.append(segment_Wc_rec_TYPE_A[curr_id])
                                    if i in segment_SPLITS[segment_batch_id]:
                                        curr_id += 1
                                segment_Wc_t_rec_TYPE_A = torch.stack(segment_W_lstm_in_TYPE_A)

                                segment_gen_lstm2_in_TYPE_A = torch.cat([segment_level_gen_encoded, segment_Wc_t_rec_TYPE_A], -1)
                                segment_gen_lstm2_in_TYPE_A = segment_gen_lstm2_in_TYPE_A.unsqueeze(0)
                                segment_gen_out_TYPE_A, (c, h) = self.gen_state_lstm2(segment_gen_lstm2_in_TYPE_A)
                                segment_gen_out_TYPE_A = segment_gen_out_TYPE_A.squeeze(0)

                                mdn_out_TYPE_A = self.gen_state_fc2(segment_gen_out_TYPE_A)
                                eos_TYPE_A = mdn_out_TYPE_A[:, 0:1]
                                [mu1_TYPE_A, mu2_TYPE_A, sig1_TYPE_A, sig2_TYPE_A, rho_TYPE_A, pi_TYPE_A] = torch.split(mdn_out_TYPE_A[:, 1:], self.num_mixtures, 1)
                                sig1_TYPE_A = sig1_TYPE_A.exp() + 1e-3
                                sig2_TYPE_A = sig2_TYPE_A.exp() + 1e-3
                                rho_TYPE_A = self.mdn_tanh(rho_TYPE_A)
                                pi_TYPE_A = self.mdn_softmax(pi_TYPE_A)

                                term_out_TYPE_A = self.term_fc1(segment_gen_out_TYPE_A)
                                term_out_TYPE_A = self.term_relu1(term_out_TYPE_A)
                                term_out_TYPE_A = self.term_fc2(term_out_TYPE_A)
                                term_out_TYPE_A = self.term_relu2(term_out_TYPE_A)
                                term_out_TYPE_A = self.term_fc3(term_out_TYPE_A)
                                term_pred_TYPE_A = self.term_sigmoid(term_out_TYPE_A)

                                gaussian_TYPE_A = gaussian_2d(segment_level_target_x, segment_level_target_y, mu1_TYPE_A, mu2_TYPE_A, sig1_TYPE_A, sig2_TYPE_A, rho_TYPE_A)
                                loss_gaussian_TYPE_A = -torch.log(torch.sum(pi_TYPE_A * gaussian_TYPE_A, dim=1) + 1e-5)

                                TYPE_A_segment_term_loss = self.term_bce_loss(term_out_TYPE_A.squeeze(1), segment_level_target_term)
                                TYPE_A_segment_loc_loss = torch.mean(loss_gaussian_TYPE_A)
                                TYPE_A_segment_touch_loss = self.mdn_bce_loss(eos_TYPE_A.squeeze(1), segment_level_target_eos)

                                TYPE_A_segment_termination_loss.append(TYPE_A_segment_term_loss)
                                TYPE_A_segment_loc_reconstruct_loss.append(TYPE_A_segment_loc_loss)
                                TYPE_A_segment_touch_reconstruct_loss.append(TYPE_A_segment_touch_loss)

                            # TYPE B
                            if self.TYPE_B:
                                unique_char_matrix_1 = segment_unique_char_matrices_1[segment_batch_id]
                                unique_char_matrices_1 = torch.stack(unique_char_matrix_1)
                                unique_char_matrices_1 = unique_char_matrices_1.squeeze(1)

                                # segment_W_c_TYPE_B_RAW = torch.bmm(unique_char_matrices, segment_W_mean.repeat(unique_char_matrices.size(0), 1).unsqueeze(2)).squeeze(-1)
                                segment_W_c_TYPE_B_RAW = torch.bmm(unique_char_matrices_1, segment_W_mean.repeat(unique_char_matrices_1.size(0), 1).unsqueeze(2)).squeeze(-1)
                                segment_W_c_TYPE_B_RAW = segment_W_c_TYPE_B_RAW.unsqueeze(0)
                                segment_Wc_rec_TYPE_B, (c, h) = self.magic_lstm(segment_W_c_TYPE_B_RAW)
                                segment_Wc_rec_TYPE_B = segment_Wc_rec_TYPE_B.squeeze(0)
                                segment_Wcs_reconstruct_TYPE_B.append(segment_Wc_rec_TYPE_B)

                                segment_W_lstm_in_TYPE_B = []
                                curr_id = 0
                                for i in range(user_segment_level_stroke_length[segment_batch_id][0]):
                                    segment_W_lstm_in_TYPE_B.append(segment_Wc_rec_TYPE_B[curr_id])
                                    if i in segment_SPLITS[segment_batch_id]:
                                        curr_id += 1
                                segment_Wc_t_rec_TYPE_B = torch.stack(segment_W_lstm_in_TYPE_B)

                                segment_gen_lstm2_in_TYPE_B = torch.cat([segment_level_gen_encoded, segment_Wc_t_rec_TYPE_B], -1)
                                segment_gen_lstm2_in_TYPE_B = segment_gen_lstm2_in_TYPE_B.unsqueeze(0)
                                segment_gen_out_TYPE_B, (c, h) = self.gen_state_lstm2(segment_gen_lstm2_in_TYPE_B)
                                segment_gen_out_TYPE_B = segment_gen_out_TYPE_B.squeeze(0)

                                mdn_out_TYPE_B = self.gen_state_fc2(segment_gen_out_TYPE_B)
                                eos_TYPE_B = mdn_out_TYPE_B[:, 0:1]
                                [mu1_TYPE_B, mu2_TYPE_B, sig1_TYPE_B, sig2_TYPE_B, rho_TYPE_B, pi_TYPE_B] = torch.split(mdn_out_TYPE_B[:, 1:], self.num_mixtures, 1)
                                sig1_TYPE_B = sig1_TYPE_B.exp() + 1e-3
                                sig2_TYPE_B = sig2_TYPE_B.exp() + 1e-3
                                rho_TYPE_B = self.mdn_tanh(rho_TYPE_B)
                                pi_TYPE_B = self.mdn_softmax(pi_TYPE_B)

                                term_out_TYPE_B = self.term_fc1(segment_gen_out_TYPE_B)
                                term_out_TYPE_B = self.term_relu1(term_out_TYPE_B)
                                term_out_TYPE_B = self.term_fc2(term_out_TYPE_B)
                                term_out_TYPE_B = self.term_relu2(term_out_TYPE_B)
                                term_out_TYPE_B = self.term_fc3(term_out_TYPE_B)
                                term_pred_TYPE_B = self.term_sigmoid(term_out_TYPE_B)

                                gaussian_TYPE_B = gaussian_2d(segment_level_target_x, segment_level_target_y, mu1_TYPE_B, mu2_TYPE_B, sig1_TYPE_B, sig2_TYPE_B, rho_TYPE_B)
                                loss_gaussian_TYPE_B = -torch.log(torch.sum(pi_TYPE_B * gaussian_TYPE_B, dim=1) + 1e-5)

                                TYPE_B_segment_term_loss = self.term_bce_loss(term_out_TYPE_B.squeeze(1), segment_level_target_term)
                                TYPE_B_segment_loc_loss = torch.mean(loss_gaussian_TYPE_B)
                                TYPE_B_segment_touch_loss = self.mdn_bce_loss(eos_TYPE_B.squeeze(1), segment_level_target_eos)

                                TYPE_B_segment_termination_loss.append(TYPE_B_segment_term_loss)
                                TYPE_B_segment_loc_reconstruct_loss.append(TYPE_B_segment_loc_loss)
                                TYPE_B_segment_touch_reconstruct_loss.append(TYPE_B_segment_touch_loss)

                        if self.ORIGINAL:
                            ALL_ORIGINAL_segment_termination_loss.append(torch.mean(torch.stack(ORIGINAL_segment_termination_loss)))
                            ALL_ORIGINAL_segment_loc_reconstruct_loss.append(torch.mean(torch.stack(ORIGINAL_segment_loc_reconstruct_loss)))
                            ALL_ORIGINAL_segment_touch_reconstruct_loss.append(torch.mean(torch.stack(ORIGINAL_segment_touch_reconstruct_loss)))

                        if self.TYPE_A:
                            ALL_TYPE_A_segment_termination_loss.append(torch.mean(torch.stack(TYPE_A_segment_termination_loss)))
                            ALL_TYPE_A_segment_loc_reconstruct_loss.append(torch.mean(torch.stack(TYPE_A_segment_loc_reconstruct_loss)))
                            ALL_TYPE_A_segment_touch_reconstruct_loss.append(torch.mean(torch.stack(TYPE_A_segment_touch_reconstruct_loss)))
                            if self.REC:
                                TYPE_A_segment_WC_reconstruct_loss = []
                                for segment_batch_id in range(len(segment_Wc_rec_ORIGINAL)):
                                    segment_Wc_ORIGINAL = segment_Wc_rec_ORIGINAL[segment_batch_id]
                                    segment_Wc_TYPE_A = segment_Wcs_reconstruct_TYPE_A[segment_batch_id]
                                    segment_WC_reconstruct_loss_TYPE_A = torch.mean(torch.mean(torch.mul(segment_Wc_ORIGINAL - segment_Wc_TYPE_A, segment_Wc_ORIGINAL - segment_Wc_TYPE_A), -1))
                                    TYPE_A_segment_WC_reconstruct_loss.append(segment_WC_reconstruct_loss_TYPE_A)
                                ALL_TYPE_A_segment_WC_reconstruct_loss.append(torch.mean(torch.stack(TYPE_A_segment_WC_reconstruct_loss)))

                        if self.TYPE_B:
                            ALL_TYPE_B_segment_termination_loss.append(torch.mean(torch.stack(TYPE_B_segment_termination_loss)))
                            ALL_TYPE_B_segment_loc_reconstruct_loss.append(torch.mean(torch.stack(TYPE_B_segment_loc_reconstruct_loss)))
                            ALL_TYPE_B_segment_touch_reconstruct_loss.append(torch.mean(torch.stack(TYPE_B_segment_touch_reconstruct_loss)))
                            if self.REC:
                                TYPE_B_segment_WC_reconstruct_loss = []
                                for segment_batch_id in range(len(segment_Wc_rec_ORIGINAL)):
                                    segment_Wc_ORIGINAL = segment_Wc_rec_ORIGINAL[segment_batch_id]
                                    segment_Wc_TYPE_B = segment_Wcs_reconstruct_TYPE_B[segment_batch_id]
                                    segment_WC_reconstruct_loss_TYPE_B = torch.mean(torch.mean(torch.mul(segment_Wc_ORIGINAL - segment_Wc_TYPE_B, segment_Wc_ORIGINAL - segment_Wc_TYPE_B), -1))
                                    TYPE_B_segment_WC_reconstruct_loss.append(segment_WC_reconstruct_loss_TYPE_B)
                                ALL_TYPE_B_segment_WC_reconstruct_loss.append(torch.mean(torch.stack(TYPE_B_segment_WC_reconstruct_loss)))
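
                    # TYPE_C / TYPE_D: rebuild the word-level W_c sequence from segment-level W_c.
                    # TYPE_C pairs the word-level W_c at the previous split with each new segment W_c
                    # and runs the pair through magic_lstm; TYPE_D instead conditions on the last W_c
                    # of every previous segment accumulated in TYPE_D_REF. Either reconstruction is
                    # scored against the word-level targets only when its length matches the number
                    # of word splits.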
                    if self.TYPE_C:
                        # target
                        original_W_c = word_Wc_rec_ORIGINAL[word_batch_id]

                        word_Wc_rec_TYPE_C = []
                        for segment_batch_id in range(len(segment_Wc_rec_ORIGINAL)):
                            if segment_batch_id == 0:
                                for each_segment_Wc in segment_Wc_rec_ORIGINAL[segment_batch_id]:
                                    word_Wc_rec_TYPE_C.append(each_segment_Wc)
                                prev_id = len(word_Wc_rec_TYPE_C) - 1
                            else:
                                prev_original_W_c = original_W_c[prev_id]
                                for each_segment_Wc in segment_Wc_rec_ORIGINAL[segment_batch_id]:
                                    magic_inp = torch.stack([prev_original_W_c, each_segment_Wc])
                                    magic_inp = magic_inp.unsqueeze(0)
                                    type_c_out, (c, h) = self.magic_lstm(magic_inp)
                                    type_c_out = type_c_out.squeeze(0)
                                    word_Wc_rec_TYPE_C.append(type_c_out[-1])
                                    prev_id = len(word_Wc_rec_TYPE_C) - 1
                        word_Wc_rec_TYPE_C = torch.stack(word_Wc_rec_TYPE_C)
                        word_Wcs_reconstruct_TYPE_C.append(word_Wc_rec_TYPE_C)

                        if len(word_Wc_rec_TYPE_C) == len(word_SPLITS[word_batch_id]):
                            word_W_lstm_in_TYPE_C = []
                            curr_id = 0
                            for i in range(user_word_level_stroke_length[word_batch_id][0]):
                                word_W_lstm_in_TYPE_C.append(word_Wc_rec_TYPE_C[curr_id])
                                if i in word_SPLITS[word_batch_id]:
                                    curr_id += 1
                            word_Wc_t_rec_TYPE_C = torch.stack(word_W_lstm_in_TYPE_C)

                            word_gen_lstm2_in_TYPE_C = torch.cat([word_level_gen_encoded, word_Wc_t_rec_TYPE_C], -1)
                            word_gen_lstm2_in_TYPE_C = word_gen_lstm2_in_TYPE_C.unsqueeze(0)
                            word_gen_out_TYPE_C, (c, h) = self.gen_state_lstm2(word_gen_lstm2_in_TYPE_C)
                            word_gen_out_TYPE_C = word_gen_out_TYPE_C.squeeze(0)

                            mdn_out_TYPE_C = self.gen_state_fc2(word_gen_out_TYPE_C)
                            eos_TYPE_C = mdn_out_TYPE_C[:, 0:1]
                            [mu1_TYPE_C, mu2_TYPE_C, sig1_TYPE_C, sig2_TYPE_C, rho_TYPE_C, pi_TYPE_C] = torch.split(mdn_out_TYPE_C[:, 1:], self.num_mixtures, 1)
                            sig1_TYPE_C = sig1_TYPE_C.exp() + 1e-3
                            sig2_TYPE_C = sig2_TYPE_C.exp() + 1e-3
                            rho_TYPE_C = self.mdn_tanh(rho_TYPE_C)
                            pi_TYPE_C = self.mdn_softmax(pi_TYPE_C)

                            term_out_TYPE_C = self.term_fc1(word_gen_out_TYPE_C)
                            term_out_TYPE_C = self.term_relu1(term_out_TYPE_C)
                            term_out_TYPE_C = self.term_fc2(term_out_TYPE_C)
                            term_out_TYPE_C = self.term_relu2(term_out_TYPE_C)
                            term_out_TYPE_C = self.term_fc3(term_out_TYPE_C)
                            term_pred_TYPE_C = self.term_sigmoid(term_out_TYPE_C)

                            gaussian_TYPE_C = gaussian_2d(word_level_target_x, word_level_target_y, mu1_TYPE_C, mu2_TYPE_C, sig1_TYPE_C, sig2_TYPE_C, rho_TYPE_C)
                            loss_gaussian_TYPE_C = -torch.log(torch.sum(pi_TYPE_C * gaussian_TYPE_C, dim=1) + 1e-5)

                            TYPE_C_word_term_loss = self.term_bce_loss(term_out_TYPE_C.squeeze(1), word_level_target_term)
                            TYPE_C_word_loc_loss = torch.mean(loss_gaussian_TYPE_C)
                            TYPE_C_word_touch_loss = self.mdn_bce_loss(eos_TYPE_C.squeeze(1), word_level_target_eos)

                            TYPE_C_word_termination_loss.append(TYPE_C_word_term_loss)
                            TYPE_C_word_loc_reconstruct_loss.append(TYPE_C_word_loc_loss)
                            TYPE_C_word_touch_reconstruct_loss.append(TYPE_C_word_touch_loss)
                        else:
                            print("not C")

                    if self.TYPE_D:
                        word_Wc_rec_TYPE_D = []
                        TYPE_D_REF = []
                        for segment_batch_id in range(len(segment_Wc_rec_ORIGINAL)):
                            if segment_batch_id == 0:
                                for each_segment_Wc in segment_Wc_rec_ORIGINAL[segment_batch_id]:
                                    word_Wc_rec_TYPE_D.append(each_segment_Wc)
                                TYPE_D_REF.append(segment_Wc_rec_ORIGINAL[segment_batch_id][-1])
                            else:
                                for each_segment_Wc in segment_Wc_rec_ORIGINAL[segment_batch_id]:
                                    magic_inp = torch.cat([torch.stack(TYPE_D_REF, 0), each_segment_Wc.unsqueeze(0)], 0)
                                    magic_inp = magic_inp.unsqueeze(0)
                                    TYPE_D_out, (c, h) = self.magic_lstm(magic_inp)
                                    TYPE_D_out = TYPE_D_out.squeeze(0)
                                    word_Wc_rec_TYPE_D.append(TYPE_D_out[-1])
                                TYPE_D_REF.append(segment_Wc_rec_ORIGINAL[segment_batch_id][-1])
                        word_Wc_rec_TYPE_D = torch.stack(word_Wc_rec_TYPE_D)
                        word_Wcs_reconstruct_TYPE_D.append(word_Wc_rec_TYPE_D)

                        if len(word_Wc_rec_TYPE_D) == len(word_SPLITS[word_batch_id]):
                            word_W_lstm_in_TYPE_D = []
                            curr_id = 0
                            for i in range(user_word_level_stroke_length[word_batch_id][0]):
                                word_W_lstm_in_TYPE_D.append(word_Wc_rec_TYPE_D[curr_id])
                                if i in word_SPLITS[word_batch_id]:
                                    curr_id += 1
                            word_Wc_t_rec_TYPE_D = torch.stack(word_W_lstm_in_TYPE_D)

                            word_gen_lstm2_in_TYPE_D = torch.cat([word_level_gen_encoded, word_Wc_t_rec_TYPE_D], -1)
                            word_gen_lstm2_in_TYPE_D = word_gen_lstm2_in_TYPE_D.unsqueeze(0)
                            word_gen_out_TYPE_D, (c, h) = self.gen_state_lstm2(word_gen_lstm2_in_TYPE_D)
                            word_gen_out_TYPE_D = word_gen_out_TYPE_D.squeeze(0)

                            mdn_out_TYPE_D = self.gen_state_fc2(word_gen_out_TYPE_D)
                            eos_TYPE_D = mdn_out_TYPE_D[:, 0:1]
                            [mu1_TYPE_D, mu2_TYPE_D, sig1_TYPE_D, sig2_TYPE_D, rho_TYPE_D, pi_TYPE_D] = torch.split(mdn_out_TYPE_D[:, 1:], self.num_mixtures, 1)
                            sig1_TYPE_D = sig1_TYPE_D.exp() + 1e-3
                            sig2_TYPE_D = sig2_TYPE_D.exp() + 1e-3
                            rho_TYPE_D = self.mdn_tanh(rho_TYPE_D)
                            pi_TYPE_D = self.mdn_softmax(pi_TYPE_D)

                            term_out_TYPE_D = self.term_fc1(word_gen_out_TYPE_D)
                            term_out_TYPE_D = self.term_relu1(term_out_TYPE_D)
                            term_out_TYPE_D = self.term_fc2(term_out_TYPE_D)
                            term_out_TYPE_D = self.term_relu2(term_out_TYPE_D)
                            term_out_TYPE_D = self.term_fc3(term_out_TYPE_D)
                            term_pred_TYPE_D = self.term_sigmoid(term_out_TYPE_D)

                            gaussian_TYPE_D = gaussian_2d(word_level_target_x, word_level_target_y, mu1_TYPE_D, mu2_TYPE_D, sig1_TYPE_D, sig2_TYPE_D, rho_TYPE_D)
                            loss_gaussian_TYPE_D = -torch.log(torch.sum(pi_TYPE_D * gaussian_TYPE_D, dim=1) + 1e-5)

                            TYPE_D_word_term_loss = self.term_bce_loss(term_out_TYPE_D.squeeze(1), word_level_target_term)
                            TYPE_D_word_loc_loss = torch.mean(loss_gaussian_TYPE_D)
                            TYPE_D_word_touch_loss = self.mdn_bce_loss(eos_TYPE_D.squeeze(1), word_level_target_eos)

                            TYPE_D_word_termination_loss.append(TYPE_D_word_term_loss)
                            TYPE_D_word_loc_reconstruct_loss.append(TYPE_D_word_loc_loss)
                            TYPE_D_word_touch_reconstruct_loss.append(TYPE_D_word_touch_loss)
                        else:
                            print("not D")
# word | |
if self.ORIGINAL: | |
ALL_ORIGINAL_word_termination_loss.append(torch.mean(torch.stack(ORIGINAL_word_termination_loss))) | |
ALL_ORIGINAL_word_loc_reconstruct_loss.append(torch.mean(torch.stack(ORIGINAL_word_loc_reconstruct_loss))) | |
ALL_ORIGINAL_word_touch_reconstruct_loss.append(torch.mean(torch.stack(ORIGINAL_word_touch_reconstruct_loss))) | |
if self.TYPE_A: | |
ALL_TYPE_A_word_termination_loss.append(torch.mean(torch.stack(TYPE_A_word_termination_loss))) | |
ALL_TYPE_A_word_loc_reconstruct_loss.append(torch.mean(torch.stack(TYPE_A_word_loc_reconstruct_loss))) | |
ALL_TYPE_A_word_touch_reconstruct_loss.append(torch.mean(torch.stack(TYPE_A_word_touch_reconstruct_loss))) | |
if self.REC: | |
TYPE_A_word_WC_reconstruct_loss = [] | |
for word_batch_id in range(len(word_Wc_rec_ORIGINAL)): | |
word_Wc_ORIGINAL = word_Wc_rec_ORIGINAL[word_batch_id] | |
word_Wc_TYPE_A = word_Wcs_reconstruct_TYPE_A[word_batch_id] | |
if len(word_Wc_ORIGINAL) == len(word_Wc_TYPE_A): | |
word_WC_reconstruct_loss_TYPE_A = torch.mean(torch.mean(torch.mul(word_Wc_ORIGINAL - word_Wc_TYPE_A, word_Wc_ORIGINAL - word_Wc_TYPE_A), -1)) | |
TYPE_A_word_WC_reconstruct_loss.append(word_WC_reconstruct_loss_TYPE_A) | |
if len(TYPE_A_word_WC_reconstruct_loss) > 0: | |
ALL_TYPE_A_word_WC_reconstruct_loss.append(torch.mean(torch.stack(TYPE_A_word_WC_reconstruct_loss))) | |
if self.TYPE_B: | |
ALL_TYPE_B_word_termination_loss.append(torch.mean(torch.stack(TYPE_B_word_termination_loss))) | |
ALL_TYPE_B_word_loc_reconstruct_loss.append(torch.mean(torch.stack(TYPE_B_word_loc_reconstruct_loss))) | |
ALL_TYPE_B_word_touch_reconstruct_loss.append(torch.mean(torch.stack(TYPE_B_word_touch_reconstruct_loss))) | |
if self.REC: | |
TYPE_B_word_WC_reconstruct_loss = [] | |
for word_batch_id in range(len(word_Wc_rec_ORIGINAL)): | |
word_Wc_ORIGINAL = word_Wc_rec_ORIGINAL[word_batch_id] | |
word_Wc_TYPE_B = word_Wcs_reconstruct_TYPE_B[word_batch_id] | |
if len(word_Wc_ORIGINAL) == len(word_Wc_TYPE_B): | |
word_WC_reconstruct_loss_TYPE_B = torch.mean(torch.mean(torch.mul(word_Wc_ORIGINAL - word_Wc_TYPE_B, word_Wc_ORIGINAL - word_Wc_TYPE_B), -1)) | |
TYPE_B_word_WC_reconstruct_loss.append(word_WC_reconstruct_loss_TYPE_B) | |
if len(TYPE_B_word_WC_reconstruct_loss) > 0: | |
ALL_TYPE_B_word_WC_reconstruct_loss.append(torch.mean(torch.stack(TYPE_B_word_WC_reconstruct_loss))) | |
if self.TYPE_C: | |
ALL_TYPE_C_word_termination_loss.append(torch.mean(torch.stack(TYPE_C_word_termination_loss))) | |
ALL_TYPE_C_word_loc_reconstruct_loss.append(torch.mean(torch.stack(TYPE_C_word_loc_reconstruct_loss))) | |
ALL_TYPE_C_word_touch_reconstruct_loss.append(torch.mean(torch.stack(TYPE_C_word_touch_reconstruct_loss))) | |
if self.REC: | |
TYPE_C_word_WC_reconstruct_loss = [] | |
for word_batch_id in range(len(word_Wc_rec_ORIGINAL)): | |
word_Wc_ORIGINAL = word_Wc_rec_ORIGINAL[word_batch_id] | |
word_Wc_TYPE_C = word_Wcs_reconstruct_TYPE_C[word_batch_id] | |
if len(word_Wc_ORIGINAL) == len(word_Wc_TYPE_C): | |
word_WC_reconstruct_loss_TYPE_C = torch.mean(torch.mean(torch.mul(word_Wc_ORIGINAL - word_Wc_TYPE_C, word_Wc_ORIGINAL - word_Wc_TYPE_C), -1)) | |
TYPE_C_word_WC_reconstruct_loss.append(word_WC_reconstruct_loss_TYPE_C) | |
if len(TYPE_C_word_WC_reconstruct_loss) > 0: | |
ALL_TYPE_C_word_WC_reconstruct_loss.append(torch.mean(torch.stack(TYPE_C_word_WC_reconstruct_loss))) | |
if self.TYPE_D: | |
ALL_TYPE_D_word_termination_loss.append(torch.mean(torch.stack(TYPE_D_word_termination_loss))) | |
ALL_TYPE_D_word_loc_reconstruct_loss.append(torch.mean(torch.stack(TYPE_D_word_loc_reconstruct_loss))) | |
ALL_TYPE_D_word_touch_reconstruct_loss.append(torch.mean(torch.stack(TYPE_D_word_touch_reconstruct_loss))) | |
if self.REC: | |
TYPE_D_word_WC_reconstruct_loss = [] | |
for word_batch_id in range(len(word_Wc_rec_ORIGINAL)): | |
word_Wc_ORIGINAL = word_Wc_rec_ORIGINAL[word_batch_id] | |
word_Wc_TYPE_D = word_Wcs_reconstruct_TYPE_D[word_batch_id] | |
if len(word_Wc_ORIGINAL) == len(word_Wc_TYPE_D): | |
word_WC_reconstruct_loss_TYPE_D = torch.mean(torch.mean(torch.mul(word_Wc_ORIGINAL - word_Wc_TYPE_D, word_Wc_ORIGINAL - word_Wc_TYPE_D), -1)) | |
TYPE_D_word_WC_reconstruct_loss.append(word_WC_reconstruct_loss_TYPE_D) | |
if len(TYPE_D_word_WC_reconstruct_loss) > 0: | |
ALL_TYPE_D_word_WC_reconstruct_loss.append(torch.mean(torch.stack(TYPE_D_word_WC_reconstruct_loss))) | |
# segment-level loss aggregation for this batch element
if self.segment_loss: | |
SUPER_ALL_segment_W_consistency_loss.append(torch.mean(torch.stack(ALL_segment_W_consistency_loss))) | |
if self.ORIGINAL: | |
SUPER_ALL_ORIGINAL_segment_termination_loss.append(torch.mean(torch.stack(ALL_ORIGINAL_segment_termination_loss))) | |
SUPER_ALL_ORIGINAL_segment_loc_reconstruct_loss.append(torch.mean(torch.stack(ALL_ORIGINAL_segment_loc_reconstruct_loss))) | |
SUPER_ALL_ORIGINAL_segment_touch_reconstruct_loss.append(torch.mean(torch.stack(ALL_ORIGINAL_segment_touch_reconstruct_loss))) | |
if self.TYPE_A: | |
SUPER_ALL_TYPE_A_segment_termination_loss.append(torch.mean(torch.stack(ALL_TYPE_A_segment_termination_loss))) | |
SUPER_ALL_TYPE_A_segment_loc_reconstruct_loss.append(torch.mean(torch.stack(ALL_TYPE_A_segment_loc_reconstruct_loss))) | |
SUPER_ALL_TYPE_A_segment_touch_reconstruct_loss.append(torch.mean(torch.stack(ALL_TYPE_A_segment_touch_reconstruct_loss))) | |
if self.REC: | |
SUPER_ALL_TYPE_A_segment_WC_reconstruct_loss.append(torch.mean(torch.stack(ALL_TYPE_A_segment_WC_reconstruct_loss))) | |
if self.TYPE_B: | |
SUPER_ALL_TYPE_B_segment_termination_loss.append(torch.mean(torch.stack(ALL_TYPE_B_segment_termination_loss))) | |
SUPER_ALL_TYPE_B_segment_loc_reconstruct_loss.append(torch.mean(torch.stack(ALL_TYPE_B_segment_loc_reconstruct_loss))) | |
SUPER_ALL_TYPE_B_segment_touch_reconstruct_loss.append(torch.mean(torch.stack(ALL_TYPE_B_segment_touch_reconstruct_loss))) | |
if self.REC: | |
SUPER_ALL_TYPE_B_segment_WC_reconstruct_loss.append(torch.mean(torch.stack(ALL_TYPE_B_segment_WC_reconstruct_loss))) | |
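# Combine the three levels of supervision. Each of the sentence-, word- and
# segment-level totals is only populated when the corresponding flag
# (sentence_loss / word_loss / segment_loss) is set; disabled terms stay at 0
# so the sums below remain well defined.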
total_sentence_loss = 0 | |
sentence_losses = [] | |
if self.sentence_loss: | |
mean_ORIGINAL_sentence_termination_loss = 0 | |
mean_ORIGINAL_sentence_loc_reconstruct_loss = 0 | |
mean_ORIGINAL_sentence_touch_reconstruct_loss = 0 | |
mean_TYPE_A_sentence_termination_loss = 0 | |
mean_TYPE_A_sentence_loc_reconstruct_loss = 0 | |
mean_TYPE_A_sentence_touch_reconstruct_loss = 0 | |
mean_TYPE_B_sentence_termination_loss = 0 | |
mean_TYPE_B_sentence_loc_reconstruct_loss = 0 | |
mean_TYPE_B_sentence_touch_reconstruct_loss = 0 | |
mean_TYPE_A_sentence_WC_reconstruct_loss = 0 | |
mean_TYPE_B_sentence_WC_reconstruct_loss = 0 | |
mean_sentence_W_consistency_loss = torch.mean(torch.stack(ALL_sentence_W_consistency_loss)) | |
if self.ORIGINAL: | |
mean_ORIGINAL_sentence_termination_loss = torch.mean(torch.stack(ALL_ORIGINAL_sentence_termination_loss)) | |
mean_ORIGINAL_sentence_loc_reconstruct_loss = torch.mean(torch.stack(ALL_ORIGINAL_sentence_loc_reconstruct_loss)) | |
mean_ORIGINAL_sentence_touch_reconstruct_loss = torch.mean(torch.stack(ALL_ORIGINAL_sentence_touch_reconstruct_loss)) | |
if self.TYPE_A: | |
mean_TYPE_A_sentence_termination_loss = torch.mean(torch.stack(ALL_TYPE_A_sentence_termination_loss)) | |
mean_TYPE_A_sentence_loc_reconstruct_loss = torch.mean(torch.stack(ALL_TYPE_A_sentence_loc_reconstruct_loss)) | |
mean_TYPE_A_sentence_touch_reconstruct_loss = torch.mean(torch.stack(ALL_TYPE_A_sentence_touch_reconstruct_loss)) | |
if self.REC: | |
mean_TYPE_A_sentence_WC_reconstruct_loss = torch.mean(torch.stack(ALL_TYPE_A_sentence_WC_reconstruct_loss)) | |
if self.TYPE_B: | |
mean_TYPE_B_sentence_termination_loss = torch.mean(torch.stack(ALL_TYPE_B_sentence_termination_loss)) | |
mean_TYPE_B_sentence_loc_reconstruct_loss = torch.mean(torch.stack(ALL_TYPE_B_sentence_loc_reconstruct_loss)) | |
mean_TYPE_B_sentence_touch_reconstruct_loss = torch.mean(torch.stack(ALL_TYPE_B_sentence_touch_reconstruct_loss)) | |
if self.REC: | |
mean_TYPE_B_sentence_WC_reconstruct_loss = torch.mean(torch.stack(ALL_TYPE_B_sentence_WC_reconstruct_loss)) | |
total_sentence_loss = mean_sentence_W_consistency_loss + mean_ORIGINAL_sentence_termination_loss + mean_ORIGINAL_sentence_loc_reconstruct_loss + mean_ORIGINAL_sentence_touch_reconstruct_loss + mean_TYPE_A_sentence_termination_loss + mean_TYPE_A_sentence_loc_reconstruct_loss + mean_TYPE_A_sentence_touch_reconstruct_loss + mean_TYPE_B_sentence_termination_loss + mean_TYPE_B_sentence_loc_reconstruct_loss + mean_TYPE_B_sentence_touch_reconstruct_loss + mean_TYPE_A_sentence_WC_reconstruct_loss + mean_TYPE_B_sentence_WC_reconstruct_loss | |
sentence_losses = [total_sentence_loss, mean_sentence_W_consistency_loss, mean_ORIGINAL_sentence_termination_loss, mean_ORIGINAL_sentence_loc_reconstruct_loss, mean_ORIGINAL_sentence_touch_reconstruct_loss, mean_TYPE_A_sentence_termination_loss, mean_TYPE_A_sentence_loc_reconstruct_loss, mean_TYPE_A_sentence_touch_reconstruct_loss, mean_TYPE_B_sentence_termination_loss, mean_TYPE_B_sentence_loc_reconstruct_loss, mean_TYPE_B_sentence_touch_reconstruct_loss, mean_TYPE_A_sentence_WC_reconstruct_loss, mean_TYPE_B_sentence_WC_reconstruct_loss] | |
total_word_loss = 0 | |
word_losses = [] | |
if self.word_loss: | |
mean_ORIGINAL_word_termination_loss = 0 | |
mean_ORIGINAL_word_loc_reconstruct_loss = 0 | |
mean_ORIGINAL_word_touch_reconstruct_loss = 0 | |
mean_TYPE_A_word_termination_loss = 0 | |
mean_TYPE_A_word_loc_reconstruct_loss = 0 | |
mean_TYPE_A_word_touch_reconstruct_loss = 0 | |
mean_TYPE_B_word_termination_loss = 0 | |
mean_TYPE_B_word_loc_reconstruct_loss = 0 | |
mean_TYPE_B_word_touch_reconstruct_loss = 0 | |
mean_TYPE_C_word_termination_loss = 0 | |
mean_TYPE_C_word_loc_reconstruct_loss = 0 | |
mean_TYPE_C_word_touch_reconstruct_loss = 0 | |
mean_TYPE_D_word_termination_loss = 0 | |
mean_TYPE_D_word_loc_reconstruct_loss = 0 | |
mean_TYPE_D_word_touch_reconstruct_loss = 0 | |
mean_TYPE_A_word_WC_reconstruct_loss = 0 | |
mean_TYPE_B_word_WC_reconstruct_loss = 0 | |
mean_TYPE_C_word_WC_reconstruct_loss = 0 | |
mean_TYPE_D_word_WC_reconstruct_loss = 0 | |
mean_word_W_consistency_loss = torch.mean(torch.stack(ALL_word_W_consistency_loss)) | |
if self.ORIGINAL: | |
mean_ORIGINAL_word_termination_loss = torch.mean(torch.stack(ALL_ORIGINAL_word_termination_loss)) | |
mean_ORIGINAL_word_loc_reconstruct_loss = torch.mean(torch.stack(ALL_ORIGINAL_word_loc_reconstruct_loss)) | |
mean_ORIGINAL_word_touch_reconstruct_loss = torch.mean(torch.stack(ALL_ORIGINAL_word_touch_reconstruct_loss)) | |
if self.TYPE_A: | |
mean_TYPE_A_word_termination_loss = torch.mean(torch.stack(ALL_TYPE_A_word_termination_loss)) | |
mean_TYPE_A_word_loc_reconstruct_loss = torch.mean(torch.stack(ALL_TYPE_A_word_loc_reconstruct_loss)) | |
mean_TYPE_A_word_touch_reconstruct_loss = torch.mean(torch.stack(ALL_TYPE_A_word_touch_reconstruct_loss)) | |
if self.REC: | |
mean_TYPE_A_word_WC_reconstruct_loss = torch.mean(torch.stack(ALL_TYPE_A_word_WC_reconstruct_loss)) | |
if self.TYPE_B: | |
mean_TYPE_B_word_termination_loss = torch.mean(torch.stack(ALL_TYPE_B_word_termination_loss)) | |
mean_TYPE_B_word_loc_reconstruct_loss = torch.mean(torch.stack(ALL_TYPE_B_word_loc_reconstruct_loss)) | |
mean_TYPE_B_word_touch_reconstruct_loss = torch.mean(torch.stack(ALL_TYPE_B_word_touch_reconstruct_loss)) | |
if self.REC: | |
mean_TYPE_B_word_WC_reconstruct_loss = torch.mean(torch.stack(ALL_TYPE_B_word_WC_reconstruct_loss)) | |
if self.TYPE_C: | |
mean_TYPE_C_word_termination_loss = torch.mean(torch.stack(ALL_TYPE_C_word_termination_loss)) | |
mean_TYPE_C_word_loc_reconstruct_loss = torch.mean(torch.stack(ALL_TYPE_C_word_loc_reconstruct_loss)) | |
mean_TYPE_C_word_touch_reconstruct_loss = torch.mean(torch.stack(ALL_TYPE_C_word_touch_reconstruct_loss)) | |
if self.REC: | |
mean_TYPE_C_word_WC_reconstruct_loss = torch.mean(torch.stack(ALL_TYPE_C_word_WC_reconstruct_loss)) | |
if self.TYPE_D: | |
mean_TYPE_D_word_termination_loss = torch.mean(torch.stack(ALL_TYPE_D_word_termination_loss)) | |
mean_TYPE_D_word_loc_reconstruct_loss = torch.mean(torch.stack(ALL_TYPE_D_word_loc_reconstruct_loss)) | |
mean_TYPE_D_word_touch_reconstruct_loss = torch.mean(torch.stack(ALL_TYPE_D_word_touch_reconstruct_loss)) | |
if self.REC: | |
mean_TYPE_D_word_WC_reconstruct_loss = torch.mean(torch.stack(ALL_TYPE_D_word_WC_reconstruct_loss)) | |
total_word_loss = mean_word_W_consistency_loss + mean_ORIGINAL_word_termination_loss + mean_ORIGINAL_word_loc_reconstruct_loss + mean_ORIGINAL_word_touch_reconstruct_loss + mean_TYPE_A_word_termination_loss + mean_TYPE_A_word_loc_reconstruct_loss + mean_TYPE_A_word_touch_reconstruct_loss + mean_TYPE_B_word_termination_loss + mean_TYPE_B_word_loc_reconstruct_loss + mean_TYPE_B_word_touch_reconstruct_loss + mean_TYPE_C_word_termination_loss + mean_TYPE_C_word_loc_reconstruct_loss + mean_TYPE_C_word_touch_reconstruct_loss + mean_TYPE_D_word_termination_loss + mean_TYPE_D_word_loc_reconstruct_loss + mean_TYPE_D_word_touch_reconstruct_loss + mean_TYPE_A_word_WC_reconstruct_loss + mean_TYPE_B_word_WC_reconstruct_loss + mean_TYPE_C_word_WC_reconstruct_loss + mean_TYPE_D_word_WC_reconstruct_loss | |
word_losses = [total_word_loss, mean_word_W_consistency_loss, mean_ORIGINAL_word_termination_loss, mean_ORIGINAL_word_loc_reconstruct_loss, mean_ORIGINAL_word_touch_reconstruct_loss, mean_TYPE_A_word_termination_loss, mean_TYPE_A_word_loc_reconstruct_loss, mean_TYPE_A_word_touch_reconstruct_loss, mean_TYPE_B_word_termination_loss, mean_TYPE_B_word_loc_reconstruct_loss, mean_TYPE_B_word_touch_reconstruct_loss, mean_TYPE_C_word_termination_loss, mean_TYPE_C_word_loc_reconstruct_loss, mean_TYPE_C_word_touch_reconstruct_loss, mean_TYPE_D_word_termination_loss, mean_TYPE_D_word_loc_reconstruct_loss, mean_TYPE_D_word_touch_reconstruct_loss, mean_TYPE_A_word_WC_reconstruct_loss, mean_TYPE_B_word_WC_reconstruct_loss, mean_TYPE_C_word_WC_reconstruct_loss, mean_TYPE_D_word_WC_reconstruct_loss] | |
total_segment_loss = 0 | |
segment_losses = [] | |
if self.segment_loss: | |
mean_segment_W_consistency_loss = torch.mean(torch.stack(SUPER_ALL_segment_W_consistency_loss)) | |
mean_ORIGINAL_segment_termination_loss = 0 | |
mean_ORIGINAL_segment_loc_reconstruct_loss = 0 | |
mean_ORIGINAL_segment_touch_reconstruct_loss = 0 | |
mean_TYPE_A_segment_termination_loss = 0 | |
mean_TYPE_A_segment_loc_reconstruct_loss = 0 | |
mean_TYPE_A_segment_touch_reconstruct_loss = 0 | |
mean_TYPE_B_segment_termination_loss = 0 | |
mean_TYPE_B_segment_loc_reconstruct_loss = 0 | |
mean_TYPE_B_segment_touch_reconstruct_loss = 0 | |
mean_TYPE_A_segment_WC_reconstruct_loss = 0 | |
mean_TYPE_B_segment_WC_reconstruct_loss = 0 | |
if self.ORIGINAL: | |
mean_ORIGINAL_segment_termination_loss = torch.mean(torch.stack(SUPER_ALL_ORIGINAL_segment_termination_loss)) | |
mean_ORIGINAL_segment_loc_reconstruct_loss = torch.mean(torch.stack(SUPER_ALL_ORIGINAL_segment_loc_reconstruct_loss)) | |
mean_ORIGINAL_segment_touch_reconstruct_loss = torch.mean(torch.stack(SUPER_ALL_ORIGINAL_segment_touch_reconstruct_loss)) | |
if self.TYPE_A: | |
mean_TYPE_A_segment_termination_loss = torch.mean(torch.stack(SUPER_ALL_TYPE_A_segment_termination_loss)) | |
mean_TYPE_A_segment_loc_reconstruct_loss = torch.mean(torch.stack(SUPER_ALL_TYPE_A_segment_loc_reconstruct_loss)) | |
mean_TYPE_A_segment_touch_reconstruct_loss = torch.mean(torch.stack(SUPER_ALL_TYPE_A_segment_touch_reconstruct_loss)) | |
if self.REC: | |
mean_TYPE_A_segment_WC_reconstruct_loss = torch.mean(torch.stack(SUPER_ALL_TYPE_A_segment_WC_reconstruct_loss)) | |
if self.TYPE_B: | |
mean_TYPE_B_segment_termination_loss = torch.mean(torch.stack(SUPER_ALL_TYPE_B_segment_termination_loss)) | |
mean_TYPE_B_segment_loc_reconstruct_loss = torch.mean(torch.stack(SUPER_ALL_TYPE_B_segment_loc_reconstruct_loss)) | |
mean_TYPE_B_segment_touch_reconstruct_loss = torch.mean(torch.stack(SUPER_ALL_TYPE_B_segment_touch_reconstruct_loss)) | |
if self.REC: | |
mean_TYPE_B_segment_WC_reconstruct_loss = torch.mean(torch.stack(SUPER_ALL_TYPE_B_segment_WC_reconstruct_loss)) | |
total_segment_loss = mean_segment_W_consistency_loss + mean_ORIGINAL_segment_termination_loss + mean_ORIGINAL_segment_loc_reconstruct_loss + mean_ORIGINAL_segment_touch_reconstruct_loss + mean_TYPE_A_segment_termination_loss + mean_TYPE_A_segment_loc_reconstruct_loss + mean_TYPE_A_segment_touch_reconstruct_loss + mean_TYPE_B_segment_termination_loss + mean_TYPE_B_segment_loc_reconstruct_loss + mean_TYPE_B_segment_touch_reconstruct_loss + mean_TYPE_A_segment_WC_reconstruct_loss + mean_TYPE_B_segment_WC_reconstruct_loss | |
segment_losses = [total_segment_loss, mean_segment_W_consistency_loss, mean_ORIGINAL_segment_termination_loss, mean_ORIGINAL_segment_loc_reconstruct_loss, mean_ORIGINAL_segment_touch_reconstruct_loss, mean_TYPE_A_segment_termination_loss, mean_TYPE_A_segment_loc_reconstruct_loss, mean_TYPE_A_segment_touch_reconstruct_loss, mean_TYPE_B_segment_termination_loss, mean_TYPE_B_segment_loc_reconstruct_loss, mean_TYPE_B_segment_touch_reconstruct_loss, mean_TYPE_A_segment_WC_reconstruct_loss, mean_TYPE_B_segment_WC_reconstruct_loss] | |
total_loss = total_sentence_loss + total_word_loss + total_segment_loss | |
return total_loss, sentence_losses, word_losses, segment_losses | |
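# sample() decodes a single handwriting sample from user-supplied reference
# strokes. It (1) encodes the reference word strokes and recovers a writer
# style vector W by inverting the character-conditioned matrices, then
# (2) rebuilds per-character codes W_c with four strategies (ORIGINAL and
# TYPE_A..TYPE_D), and (3) decodes each code sequence into pen commands with
# sample_from_w().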
def sample(self, inputs): | |
[ word_level_stroke_in, word_level_stroke_out, word_level_stroke_length, | |
word_level_term, word_level_char, word_level_char_length, segment_level_stroke_in, | |
segment_level_stroke_out, segment_level_stroke_length, segment_level_term, | |
segment_level_char, segment_level_char_length ] = inputs | |
word_inf_state_out = self.inf_state_fc1(word_level_stroke_out[0]) | |
word_inf_state_out = self.inf_state_relu(word_inf_state_out) | |
word_inf_state_out, (c,h) = self.inf_state_lstm(word_inf_state_out) | |
user_word_level_char = word_level_char[0] | |
user_word_level_term = word_level_term[0] | |
raw_Ws = [] | |
original_Wc = [] | |
word_batch_id = 0 | |
# ORIGINAL: per-character codes W_c taken directly from the stroke encoder
curr_seq_len = word_level_stroke_length[0][word_batch_id][0] | |
curr_char_len = word_level_char_length[0][word_batch_id][0] | |
char_vector = torch.eye(len(CHARACTERS))[user_word_level_char[word_batch_id][:curr_char_len]].to(self.device) | |
current_term = user_word_level_term[word_batch_id][:curr_seq_len].unsqueeze(-1) | |
split_ids = torch.nonzero(current_term)[:,0] | |
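# The term flags mark character boundaries inside the stroke sequence;
# split_ids are the time steps at which each character ends and are used
# below to gather one code W_c per character from the encoder outputs.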
# char_vector = self.char_vec_fc(char_vector) | |
# char_vector = self.char_vec_relu(char_vector) | |
char_vector_1 = self.char_vec_fc_1(char_vector) | |
char_vector_1 = self.char_vec_relu_1(char_vector_1) | |
# unique_char_matrices = [] | |
# for cid in range(len(char_vector)): | |
# unique_char_vector = char_vector[cid:cid+1] | |
# unique_char_out = unique_char_vector.unsqueeze(0) | |
# unique_char_out, (c,h) = self.char_lstm(unique_char_out) | |
# unique_char_out = unique_char_out.squeeze(0) | |
# unique_char_out = self.char_vec_fc2(unique_char_out) | |
# unique_char_matrix = unique_char_out.view([-1,1,self.weight_dim,self.weight_dim]) | |
# unique_char_matrix = unique_char_matrix.squeeze(1) | |
# unique_char_matrices.append(unique_char_matrix) | |
unique_char_matrices_1 = [] | |
for cid in range(len(char_vector)): | |
# Tower 1 applied to each character separately
unique_char_vector_1 = char_vector_1[cid:cid+1] | |
unique_char_input_1 = unique_char_vector_1.unsqueeze(0) | |
unique_char_out_1, (c,h) = self.char_lstm_1(unique_char_input_1) | |
unique_char_out_1 = unique_char_out_1.squeeze(0) | |
unique_char_out_1 = self.char_vec_fc2_1(unique_char_out_1) | |
unique_char_matrix_1 = unique_char_out_1.view([-1,1,self.weight_dim,self.weight_dim]) | |
unique_char_matrix_1 = unique_char_matrix_1.squeeze(1) | |
unique_char_matrices_1.append(unique_char_matrix_1) | |
# Tower 1 applied to the full character sequence
char_out_1 = char_vector_1.unsqueeze(0) | |
char_out_1, (c,h) = self.char_lstm_1(char_out_1) | |
char_out_1 = char_out_1.squeeze(0) | |
char_out_1 = self.char_vec_fc2_1(char_out_1) | |
char_matrix_1 = char_out_1.view([-1,1,self.weight_dim,self.weight_dim]) | |
char_matrix_1 = char_matrix_1.squeeze(1) | |
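# Recover the writer-dependent style vectors: W_c = C(char) @ W, so applying
# the inverse character matrices to the per-character codes gives back W.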
char_matrix_inv_1 = torch.inverse(char_matrix_1) | |
W_c_t = word_inf_state_out[word_batch_id][:curr_seq_len] | |
W_c = torch.stack([W_c_t[i] for i in split_ids]) | |
original_Wc.append(W_c) | |
W = torch.bmm(char_matrix_inv_1, | |
W_c.unsqueeze(2)).squeeze(-1) | |
user_segment_level_stroke_length = segment_level_stroke_length[0][word_batch_id] | |
user_segment_level_char_length = segment_level_char_length[0][word_batch_id] | |
user_segment_level_term = segment_level_term[0][word_batch_id] | |
user_segment_level_char = segment_level_char[0][word_batch_id] | |
user_segment_level_stroke_in = segment_level_stroke_in[0][word_batch_id] | |
user_segment_level_stroke_out = segment_level_stroke_out[0][word_batch_id] | |
segment_inf_state_out = self.inf_state_fc1(user_segment_level_stroke_out) | |
segment_inf_state_out = self.inf_state_relu(segment_inf_state_out) | |
segment_inf_state_out, (c,h) = self.inf_state_lstm(segment_inf_state_out) | |
segment_W_c = [] | |
for segment_batch_id in range(len(user_segment_level_char)): | |
curr_seq_len = user_segment_level_stroke_length[segment_batch_id][0] | |
curr_char_len = user_segment_level_char_length[segment_batch_id][0] | |
current_term = user_segment_level_term[segment_batch_id][:curr_seq_len].unsqueeze(-1) | |
split_ids = torch.nonzero(current_term)[:,0] | |
seg_W_c_t = segment_inf_state_out[segment_batch_id][:curr_seq_len] | |
seg_W_c = torch.stack([seg_W_c_t[i] for i in split_ids]) | |
segment_W_c.append(seg_W_c) | |
target_characters_ids = word_level_char[0][0][:word_level_char_length[0][0]] | |
target_characters = ''.join([CHARACTERS[i] for i in target_characters_ids]) | |
mean_global_W = torch.mean(W, 0) | |
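# TYPE_A: re-apply each character matrix to the mean style vector, producing
# one synthetic code per character of the target text.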
TYPE_A_WC = torch.bmm(char_matrix_1, | |
mean_global_W.repeat(char_matrix_1.size(0), 1).unsqueeze(2)).squeeze(-1) | |
unique_char_matrix_1 = torch.stack(unique_char_matrices_1) | |
unique_char_matrix_1 = unique_char_matrix_1.squeeze(1) | |
TYPE_B_WC_RAW = torch.bmm(unique_char_matrix_1, | |
mean_global_W.repeat(unique_char_matrix_1.size(0), 1).unsqueeze(2)).squeeze(-1) | |
TYPE_B_WC_RAW = TYPE_B_WC_RAW.unsqueeze(0) | |
TYPE_B_WC, (c,h) = self.magic_lstm(TYPE_B_WC_RAW) | |
TYPE_B_WC = TYPE_B_WC.squeeze(0) | |
# TYPE_C: pair the previous ORIGINAL word-level code with each segment-level
# code and fuse them through magic_lstm
TYPE_C_WC = [] | |
for segment_batch_id in range(len(segment_W_c)): | |
if segment_batch_id == 0: | |
for each_segment_Wc in segment_W_c[segment_batch_id]: | |
TYPE_C_WC.append(each_segment_Wc) | |
prev_id = len(TYPE_C_WC) - 1 | |
else: | |
prev_original_W_c = W_c[prev_id] | |
for each_segment_Wc in segment_W_c[segment_batch_id]: | |
magic_inp = torch.stack([prev_original_W_c, each_segment_Wc]) | |
magic_inp = magic_inp.unsqueeze(0) | |
type_c_out, (c,h) = self.magic_lstm(magic_inp) | |
type_c_out = type_c_out.squeeze(0) | |
TYPE_C_WC.append(type_c_out[-1]) | |
prev_id = len(TYPE_C_WC) - 1 | |
TYPE_C_WC = torch.stack(TYPE_C_WC) | |
# TYPE_D: fuse the running list of segment end codes (TYPE_D_REF) with each
# new segment-level code through magic_lstm
TYPE_D_WC = [] | |
TYPE_D_REF = [] | |
for segment_batch_id in range(len(segment_W_c)): | |
if segment_batch_id == 0: | |
for each_segment_Wc in segment_W_c[segment_batch_id]: | |
TYPE_D_WC.append(each_segment_Wc) | |
TYPE_D_REF.append(segment_W_c[segment_batch_id][-1]) | |
else: | |
for each_segment_Wc in segment_W_c[segment_batch_id]: | |
magic_inp = torch.cat([torch.stack(TYPE_D_REF, 0), each_segment_Wc.unsqueeze(0)], 0) | |
magic_inp = magic_inp.unsqueeze(0) | |
TYPE_D_out, (c,h) = self.magic_lstm(magic_inp) | |
TYPE_D_out = TYPE_D_out.squeeze(0) | |
TYPE_D_WC.append(TYPE_D_out[-1]) | |
TYPE_D_REF.append(segment_W_c[segment_batch_id][-1]) | |
TYPE_D_WC = torch.stack(TYPE_D_WC) | |
o_commands = self.sample_from_w(original_Wc[0], target_characters)
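# Each variant is only decoded when it produced exactly one code per original
# character; otherwise a dummy single-command stroke is returned for that slot.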
if len(TYPE_A_WC) == len(original_Wc[0]): | |
a_commands = self.sample_from_w(TYPE_A_WC, target_characters) | |
else: | |
a_commands = [[0,0,0]] | |
if len(TYPE_B_WC) == len(original_Wc[0]): | |
b_commands = self.sample_from_w(TYPE_B_WC, target_characters) | |
else: | |
b_commands = [[0,0,0]] | |
if len(TYPE_C_WC) == len(original_Wc[0]): | |
c_commands = self.sample_from_w(TYPE_C_WC, target_characters) | |
else: | |
c_commands = [[0,0,0]] | |
if len(TYPE_D_WC) == len(original_Wc[0]): | |
d_commands = self.sample_from_w(TYPE_D_WC, target_characters) | |
else: | |
d_commands = [[0,0,0]] | |
return [word_level_stroke_out[0][0], o_commands, a_commands, b_commands, c_commands, d_commands] | |
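# sample_from_w() greedily decodes pen commands from a sequence of per-character
# codes W_c. At every step the current code is concatenated with the encoded
# previous pen state, mixture parameters are predicted, and the pi-weighted
# mean offset is emitted (no random sampling). When the termination head fires
# (and either the current target character is a space or at least a few points
# have been drawn), the character pointer advances; decoding stops after the
# last character, after 800 steps, or when the virtual pen leaves the canvas.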
def sample_from_w(self, W_c_rec, target_sentence): | |
gen_input = torch.zeros([1, 1, 3]).to(self.device) | |
current_char_id_count = 0 | |
gc1 = torch.zeros([self.num_layers, 1, self.weight_dim]).to(self.device) | |
gh1 = torch.zeros([self.num_layers, 1, self.weight_dim]).to(self.device) | |
gc2 = torch.zeros([self.num_layers, 1, self.weight_dim * 2]).to(self.device) | |
gh2 = torch.zeros([self.num_layers, 1, self.weight_dim * 2]).to(self.device) | |
terms = [] | |
commands = [] | |
character_nums = 0 | |
cx, cy = 100, 150 | |
for zz in range(800): | |
W_c_t_now = W_c_rec[current_char_id_count:current_char_id_count + 1] | |
gen_state = self.gen_state_fc1(gen_input) | |
gen_state = self.gen_state_relu(gen_state) | |
gen_state, (gc1, gh1) = self.gen_state_lstm1(gen_state, (gc1, gh1)) | |
gen_encoded = gen_state.squeeze(0) | |
gen_lstm2_input = torch.cat([gen_encoded, W_c_t_now], -1) | |
gen_lstm2_input = gen_lstm2_input.view([1, 1, self.weight_dim * 2]) | |
gen_out, (gc2, gh2) = self.gen_state_lstm2(gen_lstm2_input, (gc2, gh2)) | |
gen_out = gen_out.squeeze(0) | |
mdn_out = self.gen_state_fc2(gen_out) | |
term_out = self.term_fc1(gen_out) | |
term_out = self.term_relu1(term_out) | |
term_out = self.term_fc2(term_out) | |
term_out = self.term_relu2(term_out) | |
term_out = self.term_fc3(term_out) | |
term = self.term_sigmoid(term_out) | |
eos = self.mdn_sigmoid(mdn_out[:, 0]) | |
[mu1, mu2, sig1, sig2, rho, pi] = torch.split(mdn_out[:, 1:], self.num_mixtures, 1) | |
sig1 = sig1.exp() + 1e-3 | |
sig2 = sig2.exp() + 1e-3 | |
rho = self.mdn_tanh(rho) | |
pi = self.mdn_softmax(pi) | |
mus = torch.stack([mu1, mu2], -1).squeeze() | |
pi = pi.cpu().detach().numpy() | |
mus = mus.cpu().detach().numpy() | |
rho = rho.cpu().detach().numpy()[0] | |
eos = eos.cpu().detach().numpy()[0] | |
term = term.cpu().detach().numpy()[0][0] | |
terms.append(term) | |
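# Deterministic decoding: take the expectation of the mixture, i.e. the
# pi-weighted average of the component means, instead of drawing a sample.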
[dx, dy] = np.sum(pi.reshape(self.num_mixtures, 1) * mus, 0)
# print (eos) | |
touch = 1 if eos > 0.5 else 0 | |
commands.append([dx, dy, touch]) | |
gen_input = torch.FloatTensor([dx, dy, touch]).view([1, 1, 3]).to(self.device) | |
character_nums += 1 | |
# print (zz, term) | |
if term > 0.3: | |
if target_sentence[current_char_id_count] == ' ': | |
current_char_id_count += 1 | |
character_nums = 0 | |
if current_char_id_count == len(W_c_rec): | |
break | |
elif character_nums > 5: | |
current_char_id_count += 1 | |
character_nums = 0 | |
if current_char_id_count == len(W_c_rec): | |
break | |
cx += dx * 2.0 * 5.0 | |
cy += dy * 2.0 * 5.0 | |
if cx > 1000 or cx < 0: | |
break | |
if cy > 350 or cy < 0: | |
break | |
return commands | |
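# sample_from_w_fix() is the stochastic counterpart of sample_from_w(): it
# draws (dx, dy) from the predicted Gaussians (scaled by scale_sd and clamped
# to the central clamp_mdn probability mass) and, as soon as a new character
# begins with the pen touching down, returns early together with the index of
# the character just started (or -1 if decoding ran to completion).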
def sample_from_w_fix(self, W_c_rec): | |
gen_input = torch.zeros([1, 1, 3]).to(self.device) | |
current_char_id_count = 0 | |
gc1 = torch.zeros([self.num_layers, 1, self.weight_dim]).to(self.device) | |
gh1 = torch.zeros([self.num_layers, 1, self.weight_dim]).to(self.device) | |
gc2 = torch.zeros([self.num_layers, 1, self.weight_dim * 2]).to(self.device) | |
gh2 = torch.zeros([self.num_layers, 1, self.weight_dim * 2]).to(self.device) | |
terms = [] | |
commands = [] | |
character_nums = 0 | |
cx, cy = 100, 150 | |
new_char = False | |
renewal = False | |
for zz in range(800): | |
# print (torch.sum(gc1)) | |
W_c_t_now = W_c_rec[current_char_id_count:current_char_id_count + 1] | |
gen_state = self.gen_state_fc1(gen_input) | |
gen_state = self.gen_state_relu(gen_state) | |
gen_state, (gc1, gh1) = self.gen_state_lstm1(gen_state, (gc1, gh1)) | |
gen_encoded = gen_state.squeeze(0) | |
gen_lstm2_input = torch.cat([gen_encoded, W_c_t_now], -1) | |
gen_lstm2_input = gen_lstm2_input.view([1, 1, self.weight_dim * 2]) | |
gen_out, (gc2, gh2) = self.gen_state_lstm2(gen_lstm2_input, (gc2, gh2)) | |
gen_out = gen_out.squeeze(0) | |
mdn_out = self.gen_state_fc2(gen_out) | |
term_out = self.term_fc1(gen_out) | |
term_out = self.term_relu1(term_out) | |
term_out = self.term_fc2(term_out) | |
term_out = self.term_relu2(term_out) | |
term_out = self.term_fc3(term_out) | |
term = self.term_sigmoid(term_out) | |
eos = self.mdn_sigmoid(mdn_out[:, 0]) | |
[mu1, mu2, sig1, sig2, rho, pi] = torch.split(mdn_out[:, 1:], self.num_mixtures, 1) | |
sig1 = sig1.exp() + 1e-3 | |
sig2 = sig2.exp() + 1e-3 | |
rho = self.mdn_tanh(rho) | |
pi = self.mdn_softmax(pi) | |
mus = torch.stack([mu1, mu2], -1).squeeze() | |
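# Stochastic sampling: scale the predicted standard deviations by scale_sd and
# clamp each sample to the central clamp_mdn fraction of its Gaussian via the
# inverse CDF, so clamp_mdn controls how much of the distribution may be
# sampled from.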
sigs = torch.stack([sig1, sig2], -1).squeeze() * self.scale_sd | |
distribution = torch.distributions.normal.Normal(loc=mus, scale=sigs) | |
sample = distribution.sample() | |
min_clamp = distribution.icdf(0.5 - torch.ones_like(mus) * self.clamp_mdn/2) | |
max_clamp = distribution.icdf(0.5 + torch.ones_like(mus) * self.clamp_mdn/2) | |
sample = sample.clamp(min=min_clamp, max=max_clamp) | |
pi = pi.cpu().detach().numpy() | |
mus = mus.cpu().detach().numpy() | |
rho = rho.cpu().detach().numpy()[0] | |
eos = eos.cpu().detach().numpy()[0] | |
term = term.cpu().detach().numpy()[0][0] | |
sample = sample.cpu().detach().numpy() | |
terms.append(term) | |
[dx, dy] = np.sum(pi.reshape(self.num_mixtures, 1) * sample, 0)
touch = 1 if eos > 0.5 else 0 | |
if new_char and touch == 1: | |
new_char = False | |
commands.append([dx, dy, touch]) | |
return commands, current_char_id_count | |
else: | |
commands.append([dx, dy, touch]) | |
gen_input = torch.FloatTensor([dx, dy, touch]).view([1, 1, 3]).to(self.device) | |
character_nums += 1 | |
# print (zz, term) | |
if term > 0.5: | |
if character_nums > 5: | |
current_char_id_count += 1 | |
character_nums = 0 | |
new_char = True | |
if current_char_id_count == len(W_c_rec): | |
break | |
cx += dx * 2.0 * 5.0 | |
cy += dy * 2.0 * 5.0 | |
if cx > 1000 or cx < 0: | |
break | |
if cy > 350 or cy < 0: | |
break | |
return commands, -1 |
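

# Minimal, hypothetical smoke test (not part of the original pipeline): it runs
# random per-character codes through sample_from_w() just to exercise the
# decoder. The shapes follow the defaults above (weight_dim=512); with
# untrained weights the resulting strokes are meaningless.
if __name__ == '__main__':
    with torch.no_grad():
        net = SynthesisNetwork(weight_dim=512, num_layers=3)
        net = net.to(net.device)
        # four fake "characters" worth of style codes, shape [num_chars, weight_dim]
        fake_W_c = torch.randn(4, 512).to(net.device)
        strokes = net.sample_from_w(fake_W_c, 'test')
        print('generated', len(strokes), 'pen commands')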