import os
import time
from collections import namedtuple
from glob import glob

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from models import generator_resnet, discriminator
from utils import *
from loss_utils import *
from ops import *

# NOTE: this model builds a TF1-style graph (placeholders, feed_dict, Session),
# so the entry point is assumed to call tf.compat.v1.disable_eager_execution().
|


class AUGAN(object):
    """CycleGAN-style image translator with auxiliary self-reconstruction,
    perceptual, and optional uncertainty (confidence-weighted) losses.
    Uses two ResNet generators (A<->B) and n_d discriminators per domain."""

    def __init__(self, sess, args):
        self.sess = sess
        self.batch_size = args.batch_size
        self.image_size = args.fine_size
        self.input_c_dim = args.input_nc
        self.output_c_dim = args.output_nc
        self.L1_lambda = args.L1_lambda
        self.conf_lambda = args.conf_lambda
        self.dataset_dir = args.dataset_dir
        self.n_d = args.n_d  # number of discriminators per domain
        self.n_scale = args.n_scale
        self.ndf = args.ndf
        self.load_size = args.load_size
        self.fine_size = args.fine_size
        self.generator = generator_resnet
        self.discriminator = discriminator
        if args.use_lsgan:
            self.criterionGAN = mae_criterion
            self.criterionGAN_list = mae_criterion_list
        else:
            self.criterionGAN = sce_criterion
            self.criterionGAN_list = sce_criterion_list

        self.use_uncertainty = args.use_uncertainty

        OPTIONS = namedtuple(
            "OPTIONS",
            "batch_size image_size gf_dim df_dim output_c_dim is_training",
        )
        self.options = OPTIONS._make(
            (
                args.batch_size,
                args.fine_size,
                args.ngf,
                args.ndf // args.n_d,  # df_dim: ndf divided across the n_d discriminators
                args.output_nc,
                args.phase == "train",
            )
        )
        self.save_conf = args.save_conf
        self._build_model()
        self.saver = tf.compat.v1.train.Saver()
        self.pool = ImagePool(args.max_size)
|

    def _build_model(self):
        # Input batch holds domain A and domain B concatenated along channels.
        self.real_data = tf.compat.v1.placeholder(
            tf.float32,
            [
                self.batch_size,
                self.image_size,
                self.image_size * 2,
                self.input_c_dim + self.output_c_dim,
            ],
            name="real_A_and_B_images",
        )

        self.real_A = self.real_data[:, :, :, : self.input_c_dim]
        self.real_B = self.real_data[
            :, :, :, self.input_c_dim : self.input_c_dim + self.output_c_dim
        ]

        # One-hot domain labels (A = [1, 0], B = [0, 1]).
        A_label = np.zeros([1, 1, 1, 2], dtype=np.float32)
        B_label = np.zeros([1, 1, 1, 2], dtype=np.float32)
        A_label[:, :, :, 0] = 1.0
        B_label[:, :, :, 1] = 1.0
        self.A_label = tf.convert_to_tensor(A_label)
        self.B_label = tf.convert_to_tensor(B_label)

        # Generator passes: A -> B (with confidence map), B -> A, and the
        # two cycle passes, all sharing weights through reuse.
        (
            self.fake_B,
            self.rec_realA,
            self.realA_percep,
            self.transA_percep,
            self.pred_confA,
        ) = self.generator(
            self.real_A, self.options, transfer=True, reuse=False, name="generatorA2B"
        )
        self.fake_A_, self.rec_fakeB, self.fakeB_percep, _, _ = self.generator(
            self.fake_B, self.options, transfer=False, reuse=False, name="generatorB2A"
        )
        self.fake_A, self.rec_realB, self.realB_percep, _, _ = self.generator(
            self.real_B, self.options, transfer=False, reuse=True, name="generatorB2A"
        )
        self.fake_B_, self.rec_fakeA, self.fakeA_percep, self.trans_fakeA_percep, _ = (
            self.generator(
                self.fake_A,
                self.options,
                transfer=True,
                reuse=True,
                name="generatorA2B",
            )
        )
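
        # ------------------------------------------------------------------
        # Generator objective: adversarial terms against every discriminator
        # pair, self-reconstruction of the real inputs, cycle consistency
        # (optionally confidence-weighted on the A side), and a perceptual
        # feature-matching term.
        # ------------------------------------------------------------------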
|
        self.g_adv_total = 0.0
        self.g_adv = 0.0
        self.g_adv_rec = 0.0
        self.g_adv_recfake = 0.0

        # Perceptual consistency: L1 between channel-averaged "percep"
        # feature maps from the paired generator passes in each direction.
        self.percep_loss = tf.reduce_mean(
            tf.abs(
                tf.reduce_mean(self.transA_percep, axis=3)
                - tf.reduce_mean(self.fakeB_percep, axis=3)
            )
        ) + tf.reduce_mean(
            tf.abs(
                tf.reduce_mean(self.realB_percep, axis=3)
                - tf.reduce_mean(self.fakeA_percep, axis=3)
            )
        )

|
        for i in range(self.n_d):
            self.DB_fake = self.discriminator(
                self.fake_B, self.options, reuse=False, name=str(i) + "_discriminatorB"
            )
            self.DA_fake = self.discriminator(
                self.fake_A, self.options, reuse=False, name=str(i) + "_discriminatorA"
            )

            # g_adv_total drives the objective; g_adv accumulates the same
            # sum and is only fetched separately for logging.
            self.g_adv_total += self.criterionGAN_list(
                self.DA_fake, get_ones_like(self.DA_fake)
            ) + self.criterionGAN_list(self.DB_fake, get_ones_like(self.DB_fake))
            self.g_adv += self.criterionGAN_list(
                self.DA_fake, get_ones_like(self.DA_fake)
            ) + self.criterionGAN_list(self.DB_fake, get_ones_like(self.DB_fake))

        # Direction-wise losses (logged via summaries; not part of g_loss).
        self.g_loss_a2b = (
            self.criterionGAN_list(self.DB_fake, get_ones_like(self.DB_fake))
            + self.L1_lambda * abs_criterion(self.real_A, self.fake_A_)
            + self.L1_lambda * abs_criterion(self.real_B, self.fake_B_)
        )
        self.g_loss_b2a = (
            self.criterionGAN_list(self.DA_fake, get_ones_like(self.DA_fake))
            + self.L1_lambda * abs_criterion(self.real_A, self.fake_A_)
            + self.L1_lambda * abs_criterion(self.real_B, self.fake_B_)
        )

|
        # Self-reconstruction: each generator should reproduce its own input.
        self.g_A_recon_loss = self.L1_lambda * abs_criterion(
            self.rec_realA, self.real_A
        )
        self.g_B_recon_loss = self.L1_lambda * abs_criterion(
            self.rec_realB, self.real_B
        )
        # Cycle consistency; the A side is weighted by the predicted
        # confidence map when uncertainty is enabled.
        if self.use_uncertainty:
            self.g_A_cycle_loss = self.conf_lambda * conf_criterion(
                self.real_A, self.fake_A_, self.pred_confA
            )
        else:
            self.g_A_cycle_loss = self.L1_lambda * abs_criterion(
                self.real_A, self.fake_A_
            )
        self.g_B_cycle_loss = self.L1_lambda * abs_criterion(self.real_B, self.fake_B_)
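
        # NOTE (assumption): conf_criterion is expected to implement a
        # heteroscedastic-style weighting, roughly
        #     mean(|real - cycled| / conf + log(conf)),
        # so low-confidence pixels are down-weighted but penalized for their
        # uncertainty. See loss_utils for the actual definition.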
|

        self.g_loss = (
            self.g_adv_total
            + self.g_A_recon_loss
            + self.g_B_recon_loss
            + self.g_A_cycle_loss
            + self.g_B_cycle_loss
            + self.percep_loss
        )

        # Unweighted L1 reconstruction/cycle metrics.
        self.g_rec_real = abs_criterion(self.rec_realA, self.real_A) + abs_criterion(
            self.rec_realB, self.real_B
        )
        self.g_rec_cycle = abs_criterion(self.real_A, self.fake_A_) + abs_criterion(
            self.real_B, self.fake_B_
        )

|
        self.fake_A_sample = tf.compat.v1.placeholder(
            tf.float32,
            [self.batch_size, self.image_size, self.image_size * 2, self.output_c_dim],
            name="fake_A_sample",
        )
        self.fake_B_sample = tf.compat.v1.placeholder(
            tf.float32,
            [self.batch_size, self.image_size, self.image_size * 2, self.output_c_dim],
            name="fake_B_sample",
        )
        self.rec_A_sample = tf.compat.v1.placeholder(
            tf.float32,
            [self.batch_size, self.image_size, self.image_size * 2, self.output_c_dim],
            name="rec_A_sample",
        )
        self.rec_B_sample = tf.compat.v1.placeholder(
            tf.float32,
            [self.batch_size, self.image_size, self.image_size * 2, self.output_c_dim],
            name="rec_B_sample",
        )
        self.rec_fakeA_sample = tf.compat.v1.placeholder(
            tf.float32,
            [self.batch_size, self.image_size, self.image_size * 2, self.output_c_dim],
            name="rec_fakeA_sample",
        )
        self.rec_fakeB_sample = tf.compat.v1.placeholder(
            tf.float32,
            [self.batch_size, self.image_size, self.image_size * 2, self.output_c_dim],
            name="rec_fakeB_sample",
        )

        self.d_loss_item = []
        self.d_loss_item_rec = []
        self.d_loss_item_recfake = []
|

        for i in range(self.n_d):
            self.DB_real = self.discriminator(
                self.real_B, self.options, reuse=True, name=str(i) + "_discriminatorB"
            )
            self.DA_real = self.discriminator(
                self.real_A, self.options, reuse=True, name=str(i) + "_discriminatorA"
            )
            self.DB_fake_sample = self.discriminator(
                self.fake_B_sample,
                self.options,
                reuse=True,
                name=str(i) + "_discriminatorB",
            )
            self.DA_fake_sample = self.discriminator(
                self.fake_A_sample,
                self.options,
                reuse=True,
                name=str(i) + "_discriminatorA",
            )
            self.db_loss_real = self.criterionGAN_list(
                self.DB_real, get_ones_like(self.DB_real)
            )
            self.db_loss_fake = self.criterionGAN_list(
                self.DB_fake_sample, get_zeros_like(self.DB_fake_sample)
            )
            self.db_loss = self.db_loss_real * 0.5 + self.db_loss_fake * 0.5
            self.da_loss_real = self.criterionGAN_list(
                self.DA_real, get_ones_like(self.DA_real)
            )
            self.da_loss_fake = self.criterionGAN_list(
                self.DA_fake_sample, get_zeros_like(self.DA_fake_sample)
            )
            self.da_loss = self.da_loss_real * 0.5 + self.da_loss_fake * 0.5
            self.d_loss = self.da_loss + self.db_loss
            self.d_loss_item.append(self.d_loss)
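
        # NOTE: the scalar attributes assigned in the loop above (da_loss,
        # db_loss, d_loss, ...) end up bound to the last discriminator pair
        # (i = n_d - 1); only d_loss_item keeps one loss per pair. The
        # summaries below therefore report the last pair only.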
|
        self.g_loss_a2b_sum = tf.compat.v1.summary.scalar("g_loss_a2b", self.g_loss_a2b)
        self.g_loss_b2a_sum = tf.compat.v1.summary.scalar("g_loss_b2a", self.g_loss_b2a)
        self.g_loss_sum = tf.compat.v1.summary.scalar("g_loss", self.g_loss)
        self.g_sum = tf.compat.v1.summary.merge(
            [self.g_loss_a2b_sum, self.g_loss_b2a_sum, self.g_loss_sum]
        )
        self.db_loss_sum = tf.compat.v1.summary.scalar("db_loss", self.db_loss)
        self.da_loss_sum = tf.compat.v1.summary.scalar("da_loss", self.da_loss)
        self.d_loss_sum = tf.compat.v1.summary.scalar("d_loss", self.d_loss)
        self.db_loss_real_sum = tf.compat.v1.summary.scalar(
            "db_loss_real", self.db_loss_real
        )
        self.db_loss_fake_sum = tf.compat.v1.summary.scalar(
            "db_loss_fake", self.db_loss_fake
        )
        self.da_loss_real_sum = tf.compat.v1.summary.scalar(
            "da_loss_real", self.da_loss_real
        )
        self.da_loss_fake_sum = tf.compat.v1.summary.scalar(
            "da_loss_fake", self.da_loss_fake
        )
        self.d_sum = tf.compat.v1.summary.merge(
            [
                self.da_loss_sum,
                self.da_loss_real_sum,
                self.da_loss_fake_sum,
                self.db_loss_sum,
                self.db_loss_real_sum,
                self.db_loss_fake_sum,
                self.d_loss_sum,
            ]
        )
|

        # Test-time placeholders and graph (weights shared via reuse=True).
        self.test_A = tf.compat.v1.placeholder(
            tf.float32,
            [self.batch_size, self.image_size, self.image_size * 2, self.input_c_dim],
            name="test_A",
        )
        self.test_B = tf.compat.v1.placeholder(
            tf.float32,
            [self.batch_size, self.image_size, self.image_size * 2, self.output_c_dim],
            name="test_B",
        )

        (
            self.testB,
            self.rec_testA,
            self.testA_percep,
            self.trans_testA_percep,
            self.test_pred_confA,
        ) = self.generator(
            self.test_A, self.options, transfer=True, reuse=True, name="generatorA2B"
        )
        self.rec_cycle_A, self.refine_testB, self.testB_percep, _, _ = self.generator(
            self.testB, self.options, transfer=False, reuse=True, name="generatorB2A"
        )

        self.testA, self.rec_testB, _, _, _ = self.generator(
            self.test_B, self.options, transfer=False, reuse=True, name="generatorB2A"
        )
        self.rec_cycle_B, self.refine_testA, _, _, _ = self.generator(
            self.testA, self.options, transfer=True, reuse=True, name="generatorA2B"
        )
|

        t_vars = tf.compat.v1.trainable_variables()

        self.g_vars = [var for var in t_vars if "generator" in var.name]
        self.p_vars = [var for var in t_vars if "percep" in var.name]
        # Group each discriminator pair's variables by its name prefix.
        self.d_vars_item = []
        for i in range(self.n_d):
            self.d_vars = [
                var for var in t_vars if str(i) + "_discriminator" in var.name
            ]
            self.d_vars_item.append(self.d_vars)
|

    def train(self, args):
        self.lr = tf.compat.v1.placeholder(tf.float32, None, name="learning_rate")

        # Graph-mode Adam: the TF2 tf.optimizers.Adam cannot minimize a
        # tensor loss without a GradientTape, so the compat.v1 optimizer
        # is used throughout.
        self.g_optim = tf.compat.v1.train.AdamOptimizer(
            self.lr, beta1=args.beta1
        ).minimize(self.g_loss, var_list=self.g_vars)

        # One optimizer per discriminator pair, each minimizing its own
        # loss over its own variables.
        self.d_optim_item = []
        for i in range(self.n_d):
            self.d_optim = tf.compat.v1.train.AdamOptimizer(
                self.lr, beta1=args.beta1
            ).minimize(self.d_loss_item[i], var_list=self.d_vars_item[i])
            self.d_optim_item.append(self.d_optim)

        init_op = tf.compat.v1.global_variables_initializer()
        self.sess.run(init_op)
        self.writer = tf.compat.v1.summary.FileWriter(
            os.path.join(args.checkpoint_dir, "logs"), self.sess.graph
        )
|

        counter = 1
        start_time = time.time()

        if args.continue_train:
            if self.load(args.checkpoint_dir):
                print(" [*] Load SUCCESS")
            else:
                print(" [!] Load failed...")

        print("Training.........................")
        for epoch in range(args.epoch):
            dataA = glob("./datasets/{}/*.*".format(self.dataset_dir + "/trainA"))
            dataB = glob("./datasets/{}/*.*".format(self.dataset_dir + "/trainB"))
            if (len(dataA) == 0) or (len(dataB) == 0):
                raise Exception("No files found in the dataset")
            print("Data found in the dataset. Length of A:", len(dataA), "B:", len(dataB))
            np.random.shuffle(dataA)
            np.random.shuffle(dataB)
            batch_idxs = (
                min(min(len(dataA), len(dataB)), args.train_size) // self.batch_size
            )
            # Constant learning rate for the first epoch_step epochs, then
            # linear decay to zero over the remaining epochs.
            lr = (
                args.lr
                if epoch < args.epoch_step
                else args.lr * (args.epoch - epoch) / (args.epoch - args.epoch_step)
            )
|

            for idx in range(0, batch_idxs):
                print("Epoch: [%2d] [%4d/%4d] " % (epoch, idx, batch_idxs))
                batch_files = list(
                    zip(
                        dataA[idx * self.batch_size : (idx + 1) * self.batch_size],
                        dataB[idx * self.batch_size : (idx + 1) * self.batch_size],
                    )
                )
                batch_images = [
                    load_train_data(batch_file, args.load_size, args.fine_size)
                    for batch_file in batch_files
                ]
                batch_images = np.array(batch_images).astype(np.float32)
|
print("Training G network----------------------") |
|
( |
|
fake_A, |
|
fake_B, |
|
rec_A, |
|
rec_B, |
|
rec_fake_A, |
|
rec_fake_B, |
|
_, |
|
g_loss, |
|
gan_loss, |
|
percep, |
|
g_adv, |
|
g_A_recon_loss, |
|
g_B_recon_loss, |
|
g_A_cycle_loss, |
|
g_B_cycle_loss, |
|
summary_str, |
|
) = self.sess.run( |
|
[ |
|
self.fake_A, |
|
self.fake_B, |
|
self.rec_realA, |
|
self.rec_realB, |
|
self.rec_fakeA, |
|
self.rec_fakeB, |
|
self.g_optim, |
|
self.g_loss, |
|
self.g_adv_total, |
|
self.percep_loss, |
|
self.g_adv, |
|
self.g_A_recon_loss, |
|
self.g_B_recon_loss, |
|
self.g_A_cycle_loss, |
|
self.g_B_cylce_loss, |
|
self.g_sum, |
|
], |
|
feed_dict={self.real_data: batch_images, self.lr: lr}, |
|
) |
|
self.writer.add_summary(summary_str, counter) |
|
[fake_A, fake_B] = self.pool([fake_A, fake_B]) |
|
|
|
|
|
print("Training D network----------------------") |
|
loss_print = [] |
|
for i in range(self.n_d): |
|
_, d_loss, d_sum = self.sess.run( |
|
[self.d_optim_item[i], self.d_loss_item[i], self.d_sum], |
|
feed_dict={ |
|
self.real_data: batch_images, |
|
self.fake_A_sample: fake_A, |
|
self.fake_B_sample: fake_B, |
|
self.lr: lr, |
|
}, |
|
) |
|
|
|
loss_print.append(d_loss) |
|
|
|
counter += 1 |
|
print( |
|
( |
|
"Epoch: [%2d] [%4d/%4d] time: %4.4f g_loss: %4.4f gan:%4.4f adv:%4.4f g_percep:%4.4f " |
|
% ( |
|
epoch, |
|
idx, |
|
batch_idxs, |
|
time.time() - start_time, |
|
g_loss, |
|
gan_loss, |
|
g_adv, |
|
percep, |
|
) |
|
) |
|
) |
|
|
|
if np.mod(counter, args.print_freq) == 1: |
|
self.sample_model(args.sample_dir, epoch, idx) |
|
|
|
if np.mod(counter, args.save_freq) == 2: |
|
self.save(args.checkpoint_dir, counter) |
|
|
|

    def save(self, checkpoint_dir, step):
        model_name = "cyclegan.model"
        model_dir = "%s_%s" % (self.dataset_dir, self.image_size)
        checkpoint_dir = os.path.join(checkpoint_dir, model_dir)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(
            self.sess, os.path.join(checkpoint_dir, model_name), global_step=step
        )

    def load(self, checkpoint_dir):
        print(" [*] Reading checkpoint...")

        model_dir = "%s_%s" % (self.dataset_dir, self.image_size)
        checkpoint_dir = os.path.join(checkpoint_dir, model_dir)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            return True
        return False
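
    # Checkpoints live under <checkpoint_dir>/<dataset_dir>_<image_size>/ as
    # cyclegan.model-<step>; load() restores the latest one recorded in
    # TensorFlow's checkpoint state file.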
|

    def sample_model(self, sample_dir, epoch, idx):
        dataA = glob("./datasets/{}/*.*".format(self.dataset_dir + "/testA"))
        dataB = glob("./datasets/{}/*.*".format(self.dataset_dir + "/testB"))
        if (len(dataA) == 0) or (len(dataB) == 0):
            raise Exception("No files found in the test directory")
        np.random.shuffle(dataA)
        np.random.shuffle(dataB)
        batch_files = list(zip(dataA[: self.batch_size], dataB[: self.batch_size]))
        sample_images = [
            load_train_data(batch_file, self.load_size, self.fine_size, is_testing=True)
            for batch_file in batch_files
        ]
        sample_images = np.array(sample_images).astype(np.float32)

        fake_A, fake_B = self.sess.run(
            [self.fake_A, self.fake_B], feed_dict={self.real_data: sample_images}
        )
        real_A = sample_images[:, :, :, : self.input_c_dim]
        real_B = sample_images[:, :, :, self.input_c_dim :]

        # Save side-by-side (input | translation) panels for both directions.
        merge_A = np.concatenate([real_B, fake_A], axis=2)
        merge_B = np.concatenate([real_A, fake_B], axis=2)
        check_folder("./{}/{:02d}".format(sample_dir, epoch))
        save_images(
            merge_A,
            [self.batch_size, 1],
            "./{}/{:02d}/A_{:04d}.jpg".format(sample_dir, epoch, idx),
        )
        save_images(
            merge_B,
            [self.batch_size, 1],
            "./{}/{:02d}/B_{:04d}.jpg".format(sample_dir, epoch, idx),
        )
|

    def test(self, args):
        total_time = 0

        init_op = tf.compat.v1.global_variables_initializer()
        self.sess.run(init_op)
        if args.which_direction == "AtoB":
            sample_files = glob("./datasets/{}/*.*".format(self.dataset_dir + "/testA"))
        elif args.which_direction == "BtoA":
            sample_files = glob("./datasets/{}/*.*".format(self.dataset_dir + "/testB"))
        else:
            raise Exception("--which_direction must be AtoB or BtoA")

        if len(sample_files) == 0:
            raise Exception("No files found in the test directory")

        if self.load(args.checkpoint_dir):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        # Select the tensors matching the requested direction.
        out_var, refine_var, in_var, rec_var, cycle_var, percep_var, conf_var = (
            (
                self.testB,
                self.refine_testB,
                self.test_A,
                self.rec_testA,
                self.rec_cycle_A,
                self.testA_percep,
                self.test_pred_confA,
            )
            if args.which_direction == "AtoB"
            else (
                self.testA,
                self.refine_testA,
                self.test_B,
                self.rec_testB,
                self.rec_cycle_B,
                self.testB_percep,
                self.test_pred_confA,
            )
        )
        for sample_file in sample_files:
            sample_image = [load_test_data(sample_file, args.fine_size)]
            start_time = time.time()
            sample_image = np.array(sample_image).astype(np.float32)
            image_path = os.path.join(
                args.test_dir,
                "{0}_{1}".format(args.which_direction, os.path.basename(sample_file)),
            )
            ori_path = os.path.join(
                args.test_dir,
                "{0}_{1}".format("ori", os.path.basename(sample_file)),
            )
            conf_path = os.path.join(
                args.conf_dir,
                "{0}_{1}".format(args.which_direction, os.path.basename(sample_file)),
            )

            (fake_img,) = self.sess.run([out_var], feed_dict={in_var: sample_image})
            end_time = time.time()

            save_images(fake_img[0], [1], image_path)
            save_images(sample_image[0], [1], ori_path)

            total_time = total_time + (end_time - start_time)

            if args.save_conf:
                if args.which_direction != "AtoB":
                    raise Exception(
                        "--save_conf: the confidence map can only be estimated "
                        "in the AtoB direction"
                    )
                # Render the squeezed confidence map as a heatmap.
                conf_img = self.sess.run(conf_var, feed_dict={in_var: sample_image})
                conf_img_sq = np.squeeze(conf_img)
                plt.imshow(
                    conf_img_sq, cmap="plasma", interpolation="nearest", alpha=1.0
                )
                plt.savefig(conf_path)
        print(
            f"Average time taken to convert images: {total_time / len(sample_files)} seconds"
        )
|

    def convert(self, args, datadir="./inf_data"):
        # datadir is passed to glob() directly, so it is expected to be a
        # glob pattern such as "./inf_data/*.*".
        total_time = 0

        init_op = tf.compat.v1.global_variables_initializer()
        self.sess.run(init_op)

        if self.load(args.checkpoint_dir):
            print(" [*] Load SUCCESS")
        else:
            raise Exception("-- Cannot Load Model. Train or Add model first")

        if args.which_direction not in ("AtoB", "BtoA"):
            raise Exception("--which_direction must be AtoB or BtoA")
        sample_files = glob(datadir)

        print(sample_files)

        # Select the tensors matching the requested direction.
        out_var, refine_var, in_var, rec_var, cycle_var, percep_var, conf_var = (
            (
                self.testB,
                self.refine_testB,
                self.test_A,
                self.rec_testA,
                self.rec_cycle_A,
                self.testA_percep,
                self.test_pred_confA,
            )
            if args.which_direction == "AtoB"
            else (
                self.testA,
                self.refine_testA,
                self.test_B,
                self.rec_testB,
                self.rec_cycle_B,
                self.testB_percep,
                self.test_pred_confA,
            )
        )
        for sample_file in sample_files:
            print("Processing image: " + sample_file)
            sample_image = [load_test_data(sample_file, args.fine_size)]
            start_time = time.time()
            sample_image = np.array(sample_image).astype(np.float32)
            image_path = os.path.join(
                args.test_dir,
                "{0}_{1}".format(args.which_direction, os.path.basename(sample_file)),
            )
            conf_path = os.path.join(
                args.conf_dir,
                "{0}_{1}".format(args.which_direction, os.path.basename(sample_file)),
            )

            (fake_img,) = self.sess.run([out_var], feed_dict={in_var: sample_image})
            end_time = time.time()
            # Save the input and its translation side by side.
            merge = np.concatenate([sample_image, fake_img], axis=2)
            save_images(merge, [1, 1], image_path)
            total_time = total_time + (end_time - start_time)
            print(f"Time taken to convert image: {end_time - start_time} seconds")

            if args.save_conf:
                if args.which_direction != "AtoB":
                    raise Exception(
                        "--save_conf: the confidence map can only be estimated "
                        "in the AtoB direction"
                    )
                conf_img = self.sess.run(conf_var, feed_dict={in_var: sample_image})
                conf_img_sq = np.squeeze(conf_img)
                plt.imshow(
                    conf_img_sq, cmap="plasma", interpolation="nearest", alpha=1.0
                )
                plt.savefig(conf_path)
        print(
            f"Average time taken to convert images: {total_time / len(sample_files)} seconds"
        )
|

    def convert_image(self, args, input_image_path, output_dir):
        # Initialize and restore into the model's own session; creating a
        # fresh Session here would discard the restored weights.
        init_op = tf.compat.v1.global_variables_initializer()
        self.sess.run(init_op)
        if self.load(args.checkpoint_dir):
            print(" [*] Load SUCCESS")
        else:
            raise Exception("-- Cannot Load Model. Train or Add model first")

        input_image = [load_test_data(input_image_path, self.fine_size)]
        input_image = np.array(input_image).astype(np.float32)

        if args.which_direction == "AtoB":
            out_var = self.testB
            in_var = self.test_A
        else:
            out_var = self.testA
            in_var = self.test_B

        start_time = time.time()
        converted_image = self.sess.run(out_var, feed_dict={in_var: input_image})
        end_time = time.time()

        output_image_path = os.path.join(
            output_dir, os.path.basename(input_image_path)
        )
        # Save the input and its translation side by side.
        merge = np.concatenate([input_image, converted_image], axis=2)
        save_images(merge, [1, 1], output_image_path)

        print(f"Time taken to convert image: {end_time - start_time} seconds")
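

# ----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original pipeline). It shows how the
# class is wired together; the argument values below are assumptions, and in
# the real repo they come from the entry point's argparse flags.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    tf.compat.v1.disable_eager_execution()

    args = SimpleNamespace(
        batch_size=1, fine_size=256, load_size=286,
        input_nc=3, output_nc=3, ngf=64, ndf=64,
        n_d=2, n_scale=2,                      # assumed defaults
        L1_lambda=10.0, conf_lambda=1.0,
        use_lsgan=True, use_uncertainty=True,
        dataset_dir="night2day",               # expects ./datasets/<dataset_dir>/
        phase="test", save_conf=False, max_size=50,
        which_direction="AtoB", checkpoint_dir="./checkpoint",
    )

    with tf.compat.v1.Session() as sess:
        model = AUGAN(sess, args)
        # Hypothetical paths; point these at a real input image and an
        # existing output directory.
        model.convert_image(args, "./inf_data/example.png", "./results")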