# augan/AUGAN.py
from collections import namedtuple
import os
import time
from glob import glob

from models import generator_resnet, discriminator
from utils import *
from loss_utils import *
from ops import *

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# This model builds a tf.compat.v1 graph with placeholders, so graph mode is
# assumed; under TF2, eager execution must be disabled before construction.
tf.compat.v1.disable_eager_execution()
class AUGAN(object):
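    """CycleGAN-style image translation with an auxiliary uncertainty branch.

    Pairs two generators (A->B and B->A) with ``n_d`` discriminators per
    domain, adds a perceptual feature-matching term, and can replace the
    A-side cycle-consistency loss with a confidence-weighted variant.
    """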
def __init__(self, sess, args):
self.sess = sess
self.batch_size = args.batch_size
self.image_size = args.fine_size
self.input_c_dim = args.input_nc
self.output_c_dim = args.output_nc
self.L1_lambda = args.L1_lambda
self.conf_lambda = args.conf_lambda
self.dataset_dir = args.dataset_dir
self.n_d = args.n_d
self.n_scale = args.n_scale
self.ndf = args.ndf
self.load_size = args.load_size
self.fine_size = args.fine_size
self.generator = generator_resnet
self.discriminator = discriminator
if args.use_lsgan:
self.criterionGAN = mae_criterion
self.criterionGAN_list = mae_criterion_list
else:
self.criterionGAN = sce_criterion
self.criterionGAN_list = sce_criterion_list
self.use_uncertainty = args.use_uncertainty
OPTIONS = namedtuple(
"OPTIONS",
"batch_size image_size \
gf_dim df_dim output_c_dim is_training",
)
self.options = OPTIONS._make(
(
args.batch_size,
args.fine_size,
args.ngf,
args.ndf // args.n_d,
args.output_nc,
args.phase == "train",
)
)
self.save_conf = args.save_conf
self._build_model()
self.saver = tf.compat.v1.train.Saver()
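        # Pool of previously generated images; the discriminators are updated
        # on a mix of current and past fakes to stabilize training.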
self.pool = ImagePool(args.max_size)
def _build_model(self):
self.real_data = tf.compat.v1.placeholder(
tf.float32,
[
self.batch_size,
self.image_size,
self.image_size * 2,
self.input_c_dim + self.output_c_dim,
],
name="real_A_and_B_images",
)
self.real_A = self.real_data[:, :, :, : self.input_c_dim]
self.real_B = self.real_data[
:, :, :, self.input_c_dim : self.input_c_dim + self.output_c_dim
]
A_label = np.zeros([1, 1, 1, 2], dtype=np.float32)
B_label = np.zeros([1, 1, 1, 2], dtype=np.float32)
A_label[:, :, :, 0] = 1.0
B_label[:, :, :, 1] = 1.0
self.A_label = tf.convert_to_tensor(A_label)
self.B_label = tf.convert_to_tensor(B_label)
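        # Generator passes (each call also returns a self-reconstruction and
        # perceptual feature maps):
        #   real_A -> fake_B (+ confidence map) -> fake_A_   (A cycle)
        #   real_B -> fake_A                    -> fake_B_   (B cycle)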
(
self.fake_B,
self.rec_realA,
self.realA_percep,
self.transA_percep,
self.pred_confA,
) = self.generator(
self.real_A, self.options, transfer=True, reuse=False, name="generatorA2B"
)
self.fake_A_, self.rec_fakeB, self.fakeB_percep, _, _ = self.generator(
self.fake_B, self.options, transfer=False, reuse=False, name="generatorB2A"
)
self.fake_A, self.rec_realB, self.realB_percep, _, _ = self.generator(
self.real_B, self.options, transfer=False, reuse=True, name="generatorB2A"
)
self.fake_B_, self.rec_fakeA, self.fakeA_percep, self.trans_fakeA_percep, _ = (
self.generator(
self.fake_A,
self.options,
transfer=True,
reuse=True,
name="generatorA2B",
)
)
self.g_adv_total = 0.0
self.g_adv = 0.0
self.g_adv_rec = 0.0
self.g_adv_recfake = 0.0
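        # Perceptual loss: L1 distance between channel-averaged generator
        # features of translated images and their domain counterparts.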
self.percep_loss = tf.reduce_mean(
tf.abs(
tf.reduce_mean(self.transA_percep, axis=3)
- tf.reduce_mean(self.fakeB_percep, axis=3)
)
) + tf.reduce_mean(
tf.abs(
tf.reduce_mean(self.realB_percep, axis=3)
- tf.reduce_mean(self.fakeA_percep, axis=3)
)
)
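        # Generator adversarial loss against all n_d discriminators
        # (g_adv duplicates g_adv_total and is tracked for logging).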
for i in range(self.n_d):
self.DB_fake = self.discriminator(
self.fake_B, self.options, reuse=False, name=str(i) + "_discriminatorB"
)
self.DA_fake = self.discriminator(
self.fake_A, self.options, reuse=False, name=str(i) + "_discriminatorA"
)
self.g_adv_total += self.criterionGAN_list(
self.DA_fake, get_ones_like(self.DA_fake)
) + self.criterionGAN_list(self.DB_fake, get_ones_like(self.DB_fake))
self.g_adv += self.criterionGAN_list(
self.DA_fake, get_ones_like(self.DA_fake)
) + self.criterionGAN_list(self.DB_fake, get_ones_like(self.DB_fake))
self.g_loss_a2b = (
self.criterionGAN_list(self.DB_fake, get_ones_like(self.DB_fake))
+ self.L1_lambda * abs_criterion(self.real_A, self.fake_A_)
+ self.L1_lambda * abs_criterion(self.real_B, self.fake_B_)
)
self.g_loss_b2a = (
self.criterionGAN_list(self.DA_fake, get_ones_like(self.DA_fake))
+ self.L1_lambda * abs_criterion(self.real_A, self.fake_A_)
+ self.L1_lambda * abs_criterion(self.real_B, self.fake_B_)
)
self.g_A_recon_loss = self.L1_lambda * abs_criterion(
self.rec_realA, self.real_A
)
self.g_B_recon_loss = self.L1_lambda * abs_criterion(
self.rec_realB, self.real_B
)
if self.use_uncertainty:
self.g_A_cycle_loss = self.conf_lambda * conf_criterion(
self.real_A, self.fake_A_, self.pred_confA
)
else:
self.g_A_cycle_loss = self.L1_lambda * abs_criterion(
self.real_A, self.fake_A_
)
        self.g_B_cycle_loss = self.L1_lambda * abs_criterion(self.real_B, self.fake_B_)
self.g_loss = (
self.g_adv_total
+ self.g_A_recon_loss
+ self.g_B_recon_loss
+ self.g_A_cycle_loss
            + self.g_B_cycle_loss
+ self.percep_loss
)
self.g_rec_real = abs_criterion(self.rec_realA, self.real_A) + abs_criterion(
self.rec_realB, self.real_B
)
self.g_rec_cycle = abs_criterion(self.real_A, self.fake_A_) + abs_criterion(
self.real_B, self.fake_B_
)
self.fake_A_sample = tf.compat.v1.placeholder(
tf.float32,
[self.batch_size, self.image_size, self.image_size * 2, self.output_c_dim],
name="fake_A_sample",
)
self.fake_B_sample = tf.compat.v1.placeholder(
tf.float32,
[self.batch_size, self.image_size, self.image_size * 2, self.output_c_dim],
name="fake_B_sample",
)
self.rec_A_sample = tf.compat.v1.placeholder(
tf.float32,
[self.batch_size, self.image_size, self.image_size * 2, self.output_c_dim],
name="rec_A_sample",
)
self.rec_B_sample = tf.compat.v1.placeholder(
tf.float32,
[self.batch_size, self.image_size, self.image_size * 2, self.output_c_dim],
name="rec_B_sample",
)
self.rec_fakeA_sample = tf.compat.v1.placeholder(
tf.float32,
[self.batch_size, self.image_size, self.image_size * 2, self.output_c_dim],
name="rec_fakeA_sample",
)
self.rec_fakeB_sample = tf.compat.v1.placeholder(
tf.float32,
[self.batch_size, self.image_size, self.image_size * 2, self.output_c_dim],
name="rec_fakeB_sample",
)
self.d_loss_item = []
self.d_loss_item_rec = []
self.d_loss_item_recfake = []
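        # Discriminator losses: each of the n_d discriminators scores real
        # images against pooled fake samples, with the usual 0.5 weighting.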
for i in range(self.n_d):
self.DB_real = self.discriminator(
self.real_B, self.options, reuse=True, name=str(i) + "_discriminatorB"
)
self.DA_real = self.discriminator(
self.real_A, self.options, reuse=True, name=str(i) + "_discriminatorA"
)
self.DB_fake_sample = self.discriminator(
self.fake_B_sample,
self.options,
reuse=True,
name=str(i) + "_discriminatorB",
)
self.DA_fake_sample = self.discriminator(
self.fake_A_sample,
self.options,
reuse=True,
name=str(i) + "_discriminatorA",
)
self.db_loss_real = self.criterionGAN_list(
self.DB_real, get_ones_like(self.DB_real)
)
self.db_loss_fake = self.criterionGAN_list(
self.DB_fake_sample, get_zeros_like(self.DB_fake_sample)
)
self.db_loss = self.db_loss_real * 0.5 + self.db_loss_fake * 0.5
self.da_loss_real = self.criterionGAN_list(
self.DA_real, get_ones_like(self.DA_real)
)
self.da_loss_fake = self.criterionGAN_list(
self.DA_fake_sample, get_zeros_like(self.DA_fake_sample)
)
self.da_loss = self.da_loss_real * 0.5 + self.da_loss_fake * 0.5
self.d_loss = self.da_loss + self.db_loss
self.d_loss_item.append(self.d_loss)
self.g_loss_a2b_sum = tf.compat.v1.summary.scalar("g_loss_a2b", self.g_loss_a2b)
self.g_loss_b2a_sum = tf.compat.v1.summary.scalar("g_loss_b2a", self.g_loss_b2a)
self.g_loss_sum = tf.compat.v1.summary.scalar("g_loss", self.g_loss)
self.g_sum = tf.compat.v1.summary.merge(
[self.g_loss_a2b_sum, self.g_loss_b2a_sum, self.g_loss_sum]
)
self.db_loss_sum = tf.compat.v1.summary.scalar("db_loss", self.db_loss)
self.da_loss_sum = tf.compat.v1.summary.scalar("da_loss", self.da_loss)
self.d_loss_sum = tf.compat.v1.summary.scalar("d_loss", self.d_loss)
self.db_loss_real_sum = tf.compat.v1.summary.scalar(
"db_loss_real", self.db_loss_real
)
self.db_loss_fake_sum = tf.compat.v1.summary.scalar(
"db_loss_fake", self.db_loss_fake
)
self.da_loss_real_sum = tf.compat.v1.summary.scalar(
"da_loss_real", self.da_loss_real
)
self.da_loss_fake_sum = tf.compat.v1.summary.scalar(
"da_loss_fake", self.da_loss_fake
)
self.d_sum = tf.compat.v1.summary.merge(
[
self.da_loss_sum,
self.da_loss_real_sum,
self.da_loss_fake_sum,
self.db_loss_sum,
self.db_loss_real_sum,
self.db_loss_fake_sum,
self.d_loss_sum,
]
)
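        # Test-time graph: dedicated placeholders routed through the trained
        # generators (reuse=True) for both translation directions.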
self.test_A = tf.compat.v1.placeholder(
tf.float32,
[self.batch_size, self.image_size, self.image_size * 2, self.input_c_dim],
name="test_A",
)
self.test_B = tf.compat.v1.placeholder(
tf.float32,
[self.batch_size, self.image_size, self.image_size * 2, self.output_c_dim],
name="test_B",
)
(
self.testB,
self.rec_testA,
self.testA_percep,
self.trans_testA_percep,
self.test_pred_confA,
) = self.generator(
self.test_A, self.options, transfer=True, reuse=True, name="generatorA2B"
)
self.rec_cycle_A, self.refine_testB, self.testB_percep, _, _ = self.generator(
self.testB, self.options, transfer=False, reuse=True, name="generatorB2A"
)
self.testA, self.rec_testB, _, _, _ = self.generator(
self.test_B, self.options, transfer=False, reuse=True, name="generatorB2A"
)
        self.rec_cycle_B, self.refine_testA, _, _, _ = self.generator(
            self.testA, self.options, transfer=True, reuse=True, name="generatorA2B"
        )
t_vars = tf.compat.v1.trainable_variables()
self.g_vars = [var for var in t_vars if "generator" in var.name]
self.p_vars = [var for var in t_vars if "percep" in var.name]
self.d_vars_item = []
for i in range(self.n_d):
self.d_vars = [
var for var in t_vars if str(i) + "_discriminator" in var.name
]
self.d_vars_item.append(self.d_vars)
def train(self, args):
self.lr = tf.compat.v1.placeholder(tf.float32, None, name="learning_rate")
        ### generator optimizer
        self.g_optim = tf.compat.v1.train.AdamOptimizer(
            self.lr, beta1=args.beta1
        ).minimize(self.g_loss, var_list=self.g_vars)
        ### discriminator optimizers (one per discriminator, each minimizing
        ### its own loss over its own variables)
        self.d_optim_item = []
        for i in range(self.n_d):
            self.d_optim = tf.compat.v1.train.AdamOptimizer(
                self.lr, beta1=args.beta1
            ).minimize(self.d_loss_item[i], var_list=self.d_vars_item[i])
            self.d_optim_item.append(self.d_optim)
init_op = tf.compat.v1.global_variables_initializer()
self.sess.run(init_op)
        self.writer = tf.compat.v1.summary.FileWriter(
            os.path.join(args.checkpoint_dir, "logs"), self.sess.graph
        )
counter = 1
start_time = time.time()
if args.continue_train:
if self.load(args.checkpoint_dir):
print(" [*] Load SUCCESS")
else:
print(" [!] Load failed...")
print("Training.........................")
for epoch in range(args.epoch):
dataA = glob("./datasets/{}/*.*".format(self.dataset_dir + "/trainA"))
dataB = glob("./datasets/{}/*.*".format(self.dataset_dir + "/trainB"))
if (len(dataA) == 0) or (len(dataB) == 0):
raise Exception("No files found in the dataset")
else:
                print(
                    "Found training data. len(A):", len(dataA), " len(B):", len(dataB)
                )
np.random.shuffle(dataA)
np.random.shuffle(dataB)
batch_idxs = (
min(min(len(dataA), len(dataB)), args.train_size) // self.batch_size
)
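            # Keep the learning rate constant for the first epoch_step epochs,
            # then decay it linearly to zero over the remaining epochs.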
lr = (
args.lr
if epoch < args.epoch_step
else args.lr * (args.epoch - epoch) / (args.epoch - args.epoch_step)
)
for idx in range(0, batch_idxs):
print("Epoch: [%2d] [%4d/%4d] " % (epoch, idx, batch_idxs))
batch_files = list(
zip(
dataA[idx * self.batch_size : (idx + 1) * self.batch_size],
dataB[idx * self.batch_size : (idx + 1) * self.batch_size],
)
)
batch_images = [
load_train_data(batch_file, args.load_size, args.fine_size)
for batch_file in batch_files
]
batch_images = np.array(batch_images).astype(np.float32)
# Update G network and record fake outputs
print("Training G network----------------------")
(
fake_A,
fake_B,
rec_A,
rec_B,
rec_fake_A,
rec_fake_B,
_,
g_loss,
gan_loss,
percep,
g_adv,
g_A_recon_loss,
g_B_recon_loss,
g_A_cycle_loss,
g_B_cycle_loss,
summary_str,
) = self.sess.run(
[
self.fake_A,
self.fake_B,
self.rec_realA,
self.rec_realB,
self.rec_fakeA,
self.rec_fakeB,
self.g_optim,
self.g_loss,
self.g_adv_total,
self.percep_loss,
self.g_adv,
self.g_A_recon_loss,
self.g_B_recon_loss,
self.g_A_cycle_loss,
                        self.g_B_cycle_loss,
self.g_sum,
],
feed_dict={self.real_data: batch_images, self.lr: lr},
)
self.writer.add_summary(summary_str, counter)
[fake_A, fake_B] = self.pool([fake_A, fake_B])
# Update D network
print("Training D network----------------------")
loss_print = []
for i in range(self.n_d):
_, d_loss, d_sum = self.sess.run(
[self.d_optim_item[i], self.d_loss_item[i], self.d_sum],
feed_dict={
self.real_data: batch_images,
self.fake_A_sample: fake_A,
self.fake_B_sample: fake_B,
self.lr: lr,
},
)
loss_print.append(d_loss)
counter += 1
print(
(
"Epoch: [%2d] [%4d/%4d] time: %4.4f g_loss: %4.4f gan:%4.4f adv:%4.4f g_percep:%4.4f "
% (
epoch,
idx,
batch_idxs,
time.time() - start_time,
g_loss,
gan_loss,
g_adv,
percep,
)
)
)
if np.mod(counter, args.print_freq) == 1:
self.sample_model(args.sample_dir, epoch, idx)
if np.mod(counter, args.save_freq) == 2:
self.save(args.checkpoint_dir, counter)
def save(self, checkpoint_dir, step):
model_name = "cyclegan.model"
model_dir = "%s_%s" % (self.dataset_dir, self.image_size)
checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir)
self.saver.save(
self.sess, os.path.join(checkpoint_dir, model_name), global_step=step
)
def load(self, checkpoint_dir):
print(" [*] Reading checkpoint...")
model_dir = "%s_%s" % (self.dataset_dir, self.image_size)
checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
return True
else:
return False
def sample_model(self, sample_dir, epoch, idx):
dataA = glob("./datasets/{}/*.*".format(self.dataset_dir + "/testA"))
dataB = glob("./datasets/{}/*.*".format(self.dataset_dir + "/testB"))
if (len(dataA) == 0) or (len(dataB) == 0):
raise Exception("No files found in the test directory")
np.random.shuffle(dataA)
np.random.shuffle(dataB)
batch_files = list(zip(dataA[: self.batch_size], dataB[: self.batch_size]))
sample_images = [
load_train_data(batch_file, self.load_size, self.fine_size, is_testing=True)
for batch_file in batch_files
]
sample_images = np.array(sample_images).astype(np.float32)
fake_A, fake_B = self.sess.run(
[self.fake_A, self.fake_B], feed_dict={self.real_data: sample_images}
)
        real_A = sample_images[:, :, :, : self.input_c_dim]
        real_B = sample_images[:, :, :, self.input_c_dim :]
merge_A = np.concatenate([real_B, fake_A], axis=2)
merge_B = np.concatenate([real_A, fake_B], axis=2)
check_folder("./{}/{:02d}".format(sample_dir, epoch))
save_images(
merge_A,
[self.batch_size, 1],
"./{}/{:02d}/A_{:04d}.jpg".format(sample_dir, epoch, idx),
)
save_images(
merge_B,
[self.batch_size, 1],
"./{}/{:02d}/B_{:04d}.jpg".format(sample_dir, epoch, idx),
)
def test(self, args):
total_time = 0
init_op = tf.compat.v1.global_variables_initializer()
self.sess.run(init_op)
if args.which_direction == "AtoB":
sample_files = glob("./datasets/{}/*.*".format(self.dataset_dir + "/testA"))
elif args.which_direction == "BtoA":
sample_files = glob("./datasets/{}/*.*".format(self.dataset_dir + "/testB"))
else:
raise Exception("--which_direction must be AtoB or BtoA")
if len(sample_files) == 0:
raise Exception("No files found in the test directory")
# print(sample_files)
if self.load(args.checkpoint_dir):
print(" [*] Load SUCCESS")
else:
print(" [!] Load failed...")
out_var, refine_var, in_var, rec_var, cycle_var, percep_var, conf_var = (
(
self.testB,
self.refine_testB,
self.test_A,
self.rec_testA,
self.rec_cycle_A,
self.testA_percep,
self.test_pred_confA,
)
if args.which_direction == "AtoB"
else (
self.testA,
self.refine_testA,
self.test_B,
self.rec_testB,
self.rec_cycle_B,
self.testB_percep,
self.test_pred_confA,
)
)
for sample_file in sample_files:
# print('Processing image: ' + sample_file)
sample_image = [load_test_data(sample_file, args.fine_size)]
start_time = time.time()
sample_image = np.array(sample_image).astype(np.float32)
image_path = os.path.join(
args.test_dir,
"{0}_{1}".format(args.which_direction, os.path.basename(sample_file)),
)
ori_path = os.path.join(
args.test_dir,
"{0}_{1}".format("ori", os.path.basename(sample_file)),
)
conf_path = os.path.join(
args.conf_dir,
"{0}_{1}".format(args.which_direction, os.path.basename(sample_file)),
)
(fake_img,) = self.sess.run([out_var], feed_dict={in_var: sample_image})
end_time = time.time()
# merge = np.concatenate([sample_image, fake_img], axis=2)
save_images(fake_img[0], [1], image_path)
save_images(sample_image[0], [1], ori_path)
# save_images(merge, [1, 1], image_path)
total_time = total_time + (end_time - start_time)
            if args.save_conf:
                if args.which_direction != "AtoB":
                    raise Exception(
                        "--conf map can only be estimated in the AtoB direction"
                    )
conf_img = self.sess.run(conf_var, feed_dict={in_var: sample_image})
conf_img_sq = np.squeeze(conf_img)
plt.imshow(
conf_img_sq, cmap="plasma", interpolation="nearest", alpha=1.0
)
plt.savefig(conf_path)
print(
f"Average time taken to convert images: {total_time/len(sample_files)} seconds"
)
def convert(self, args, datadir="./inf_data"):
total_time = 0
init_op = tf.compat.v1.global_variables_initializer()
self.sess.run(init_op)
if self.load(args.checkpoint_dir):
print(" [*] Load SUCCESS")
else:
raise Exception("-- Cannot Load Model. Train or Add model first")
        if args.which_direction not in ("AtoB", "BtoA"):
            raise Exception("--which_direction must be AtoB or BtoA")
        sample_files = glob(datadir)
print(sample_files)
out_var, refine_var, in_var, rec_var, cycle_var, percep_var, conf_var = (
(
self.testB,
self.refine_testB,
self.test_A,
self.rec_testA,
self.rec_cycle_A,
self.testA_percep,
self.test_pred_confA,
)
if args.which_direction == "AtoB"
else (
self.testA,
self.refine_testA,
self.test_B,
self.rec_testB,
self.rec_cycle_B,
self.testB_percep,
self.test_pred_confA,
)
)
for sample_file in sample_files:
print("Processing image: " + sample_file)
sample_image = [load_test_data(sample_file, args.fine_size)]
start_time = time.time()
sample_image = np.array(sample_image).astype(np.float32)
image_path = os.path.join(
args.test_dir,
"{0}_{1}".format(args.which_direction, os.path.basename(sample_file)),
)
conf_path = os.path.join(
args.conf_dir,
"{0}_{1}".format(args.which_direction, os.path.basename(sample_file)),
)
(fake_img,) = self.sess.run([out_var], feed_dict={in_var: sample_image})
end_time = time.time()
merge = np.concatenate([sample_image, fake_img], axis=2)
save_images(merge, [1, 1], image_path)
total_time = total_time + (end_time - start_time)
print(f"Time taken to convert image: {end_time - start_time} seconds")
            if args.save_conf:
                if args.which_direction != "AtoB":
                    raise Exception(
                        "--conf map can only be estimated in the AtoB direction"
                    )
conf_img = self.sess.run(conf_var, feed_dict={in_var: sample_image})
conf_img_sq = np.squeeze(conf_img)
plt.imshow(
conf_img_sq, cmap="plasma", interpolation="nearest", alpha=1.0
)
plt.savefig(conf_path)
print(
f"Average time taken to convert images: {total_time/len(sample_files)} seconds"
)
    def convert_image(self, args, input_image_path, output_dir):
        init_op = tf.compat.v1.global_variables_initializer()
        self.sess.run(init_op)
        # Restore trained weights into self.sess before inference; a freshly
        # initialized session would produce random outputs.
        if self.load(args.checkpoint_dir):
            print(" [*] Load SUCCESS")
        else:
            raise Exception("-- Cannot Load Model. Train or Add model first")
        # Load the input image
        input_image = [load_test_data(input_image_path, self.fine_size)]
        input_image = np.array(input_image).astype(np.float32)
        # Pick the generator for the requested direction
        if args.which_direction == "AtoB":
            out_var = self.testB
            in_var = self.test_A
        else:
            out_var = self.testA
            in_var = self.test_B
        # Run the model to obtain the converted image
        start_time = time.time()
        converted_image = self.sess.run(out_var, feed_dict={in_var: input_image})
        end_time = time.time()
        # Save the input and converted images side by side
        output_image_path = os.path.join(
            output_dir, os.path.basename(input_image_path)
        )
        merge = np.concatenate([input_image, converted_image], axis=2)
        save_images(merge, [1, 1], output_image_path)
        # Print the time taken
        print(f"Time taken to convert image: {end_time - start_time} seconds")