|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import os
|
|
import time
|
|
from glob import glob
|
|
import tensorflow as tf
|
|
import numpy as np
|
|
from collections import namedtuple
|
|
from tqdm import tqdm
|
|
import multiprocessing
|
|
|
|
from module import *
|
|
from utils import *
|
|
import prepare_dataset
|
|
import img_augm
|
|
|
|
|
|
class Artgan(object):
|
|
def __init__(self, sess, args):
|
|
self.model_name = args.model_name
|
|
self.root_dir = './models'
|
|
self.checkpoint_dir = os.path.join(self.root_dir, self.model_name, 'checkpoint')
|
|
self.checkpoint_long_dir = os.path.join(self.root_dir, self.model_name, 'checkpoint_long')
|
|
self.sample_dir = os.path.join(self.root_dir, self.model_name, 'sample')
|
|
self.inference_dir = os.path.join(self.root_dir, self.model_name, 'inference')
|
|
self.logs_dir = os.path.join(self.root_dir, self.model_name, 'logs')
|
|
|
|
self.sess = sess
|
|
self.batch_size = args.batch_size
|
|
self.image_size = args.image_size
|
|
|
|
self.loss = sce_criterion
|
|
|
|
self.initial_step = 0
|
|
|
|
OPTIONS = namedtuple('OPTIONS',
|
|
'batch_size image_size \
|
|
total_steps save_freq lr\
|
|
gf_dim df_dim \
|
|
is_training \
|
|
path_to_content_dataset \
|
|
path_to_art_dataset \
|
|
discr_loss_weight transformer_loss_weight feature_loss_weight')
|
|
self.options = OPTIONS._make((args.batch_size, args.image_size,
|
|
args.total_steps, args.save_freq, args.lr,
|
|
args.ngf, args.ndf,
|
|
args.phase == 'train',
|
|
args.path_to_content_dataset,
|
|
args.path_to_art_dataset,
|
|
args.discr_loss_weight, args.transformer_loss_weight, args.feature_loss_weight
|
|
))
|
|
|
|
|
|
if not os.path.exists(self.root_dir):
|
|
os.makedirs(self.root_dir)
|
|
if not os.path.exists(os.path.join(self.root_dir, self.model_name)):
|
|
os.makedirs(os.path.join(self.root_dir, self.model_name))
|
|
if not os.path.exists(self.checkpoint_dir):
|
|
os.makedirs(self.checkpoint_dir)
|
|
if not os.path.exists(self.checkpoint_long_dir):
|
|
os.makedirs(self.checkpoint_long_dir)
|
|
if not os.path.exists(self.sample_dir):
|
|
os.makedirs(self.sample_dir)
|
|
if not os.path.exists(self.inference_dir):
|
|
os.makedirs(self.inference_dir)
|
|
|
|
self._build_model()
|
|
|
|
|
|
self.saver = tf.train.Saver(max_to_keep=2)
|
|
self.saver_long = tf.train.Saver(max_to_keep=None)
|
|
|
|
def _build_model(self):
|
|
if self.options.is_training:
|
|
|
|
with tf.name_scope('placeholder'):
|
|
self.input_painting = tf.placeholder(dtype=tf.float32,
|
|
shape=[self.batch_size, None, None, 3],
|
|
name='painting')
|
|
self.input_photo = tf.placeholder(dtype=tf.float32,
|
|
shape=[self.batch_size, None, None, 3],
|
|
name='photo')
|
|
self.lr = tf.placeholder(dtype=tf.float32, shape=(), name='learning_rate')
|
|
|
|
|
|
|
|
self.input_photo_features = encoder(image=self.input_photo,
|
|
options=self.options,
|
|
reuse=False)
|
|
|
|
|
|
self.output_photo = decoder(features=self.input_photo_features,
|
|
options=self.options,
|
|
reuse=False)
|
|
|
|
|
|
self.output_photo_features = encoder(image=self.output_photo,
|
|
options=self.options,
|
|
reuse=True)
|
|
|
|
|
|
|
|
|
|
self.input_painting_discr_predictions = discriminator(image=self.input_painting,
|
|
options=self.options,
|
|
reuse=False)
|
|
self.input_photo_discr_predictions = discriminator(image=self.input_photo,
|
|
options=self.options,
|
|
reuse=True)
|
|
self.output_photo_discr_predictions = discriminator(image=self.output_photo,
|
|
options=self.options,
|
|
reuse=True)
|
|
|
|
|
|
|
|
|
|
|
|
scale_weight = {"scale_0": 1.,
|
|
"scale_1": 1.,
|
|
"scale_3": 1.,
|
|
"scale_5": 1.,
|
|
"scale_6": 1.}
|
|
self.input_painting_discr_loss = {key: self.loss(pred, tf.ones_like(pred)) * scale_weight[key]
|
|
for key, pred in zip(self.input_painting_discr_predictions.keys(),
|
|
self.input_painting_discr_predictions.values())}
|
|
self.input_photo_discr_loss = {key: self.loss(pred, tf.zeros_like(pred)) * scale_weight[key]
|
|
for key, pred in zip(self.input_photo_discr_predictions.keys(),
|
|
self.input_photo_discr_predictions.values())}
|
|
self.output_photo_discr_loss = {key: self.loss(pred, tf.zeros_like(pred)) * scale_weight[key]
|
|
for key, pred in zip(self.output_photo_discr_predictions.keys(),
|
|
self.output_photo_discr_predictions.values())}
|
|
|
|
self.discr_loss = tf.add_n(list(self.input_painting_discr_loss.values())) + \
|
|
tf.add_n(list(self.input_photo_discr_loss.values())) + \
|
|
tf.add_n(list(self.output_photo_discr_loss.values()))
|
|
|
|
|
|
self.input_painting_discr_acc = {key: tf.reduce_mean(tf.cast(x=(pred > tf.zeros_like(pred)),
|
|
dtype=tf.float32)) * scale_weight[key]
|
|
for key, pred in zip(self.input_painting_discr_predictions.keys(),
|
|
self.input_painting_discr_predictions.values())}
|
|
self.input_photo_discr_acc = {key: tf.reduce_mean(tf.cast(x=(pred < tf.zeros_like(pred)),
|
|
dtype=tf.float32)) * scale_weight[key]
|
|
for key, pred in zip(self.input_photo_discr_predictions.keys(),
|
|
self.input_photo_discr_predictions.values())}
|
|
self.output_photo_discr_acc = {key: tf.reduce_mean(tf.cast(x=(pred < tf.zeros_like(pred)),
|
|
dtype=tf.float32)) * scale_weight[key]
|
|
for key, pred in zip(self.output_photo_discr_predictions.keys(),
|
|
self.output_photo_discr_predictions.values())}
|
|
self.discr_acc = (tf.add_n(list(self.input_painting_discr_acc.values())) + \
|
|
tf.add_n(list(self.input_photo_discr_acc.values())) + \
|
|
tf.add_n(list(self.output_photo_discr_acc.values()))) / float(len(scale_weight.keys())*3)
|
|
|
|
|
|
|
|
|
|
self.output_photo_gener_loss = {key: self.loss(pred, tf.ones_like(pred)) * scale_weight[key]
|
|
for key, pred in zip(self.output_photo_discr_predictions.keys(),
|
|
self.output_photo_discr_predictions.values())}
|
|
|
|
self.gener_loss = tf.add_n(list(self.output_photo_gener_loss.values()))
|
|
|
|
|
|
self.output_photo_gener_acc = {key: tf.reduce_mean(tf.cast(x=(pred > tf.zeros_like(pred)),
|
|
dtype=tf.float32)) * scale_weight[key]
|
|
for key, pred in zip(self.output_photo_discr_predictions.keys(),
|
|
self.output_photo_discr_predictions.values())}
|
|
|
|
self.gener_acc = tf.add_n(list(self.output_photo_gener_acc.values())) / float(len(scale_weight.keys()))
|
|
|
|
|
|
|
|
self.img_loss_photo = mse_criterion(transformer_block(self.output_photo),
|
|
transformer_block(self.input_photo))
|
|
self.img_loss = self.img_loss_photo
|
|
|
|
|
|
self.feature_loss_photo = abs_criterion(self.output_photo_features, self.input_photo_features)
|
|
self.feature_loss = self.feature_loss_photo
|
|
|
|
|
|
t_vars = tf.trainable_variables()
|
|
self.discr_vars = [var for var in t_vars if 'discriminator' in var.name]
|
|
self.encoder_vars = [var for var in t_vars if 'encoder' in var.name]
|
|
self.decoder_vars = [var for var in t_vars if 'decoder' in var.name]
|
|
|
|
|
|
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
|
|
|
|
with tf.control_dependencies(update_ops):
|
|
self.d_optim_step = tf.train.AdamOptimizer(self.lr).minimize(
|
|
loss=self.options.discr_loss_weight * self.discr_loss,
|
|
var_list=[self.discr_vars])
|
|
self.g_optim_step = tf.train.AdamOptimizer(self.lr).minimize(
|
|
loss=self.options.discr_loss_weight * self.gener_loss +
|
|
self.options.transformer_loss_weight * self.img_loss +
|
|
self.options.feature_loss_weight * self.feature_loss,
|
|
var_list=[self.encoder_vars + self.decoder_vars])
|
|
|
|
|
|
|
|
|
|
s_d1 = [tf.summary.scalar("discriminator/input_painting_discr_loss/"+key, val)
|
|
for key, val in zip(self.input_painting_discr_loss.keys(), self.input_painting_discr_loss.values())]
|
|
s_d2 = [tf.summary.scalar("discriminator/input_photo_discr_loss/"+key, val)
|
|
for key, val in zip(self.input_photo_discr_loss.keys(), self.input_photo_discr_loss.values())]
|
|
s_d3 = [tf.summary.scalar("discriminator/output_photo_discr_loss/" + key, val)
|
|
for key, val in zip(self.output_photo_discr_loss.keys(), self.output_photo_discr_loss.values())]
|
|
s_d = tf.summary.scalar("discriminator/discr_loss", self.discr_loss)
|
|
self.summary_discriminator_loss = tf.summary.merge(s_d1+s_d2+s_d3+[s_d])
|
|
|
|
|
|
s_d1_acc = [tf.summary.scalar("discriminator/input_painting_discr_acc/"+key, val)
|
|
for key, val in zip(self.input_painting_discr_acc.keys(), self.input_painting_discr_acc.values())]
|
|
s_d2_acc = [tf.summary.scalar("discriminator/input_photo_discr_acc/"+key, val)
|
|
for key, val in zip(self.input_photo_discr_acc.keys(), self.input_photo_discr_acc.values())]
|
|
s_d3_acc = [tf.summary.scalar("discriminator/output_photo_discr_acc/" + key, val)
|
|
for key, val in zip(self.output_photo_discr_acc.keys(), self.output_photo_discr_acc.values())]
|
|
s_d_acc = tf.summary.scalar("discriminator/discr_acc", self.discr_acc)
|
|
s_d_acc_g = tf.summary.scalar("discriminator/discr_acc", self.gener_acc)
|
|
self.summary_discriminator_acc = tf.summary.merge(s_d1_acc+s_d2_acc+s_d3_acc+[s_d_acc])
|
|
|
|
|
|
s_i1 = tf.summary.scalar("image_loss/photo", self.img_loss_photo)
|
|
s_i = tf.summary.scalar("image_loss/loss", self.img_loss)
|
|
self.summary_image_loss = tf.summary.merge([s_i1 + s_i])
|
|
|
|
|
|
s_f1 = tf.summary.scalar("feature_loss/photo", self.feature_loss_photo)
|
|
s_f = tf.summary.scalar("feature_loss/loss", self.feature_loss)
|
|
self.summary_feature_loss = tf.summary.merge([s_f1 + s_f])
|
|
|
|
self.summary_merged_all = tf.summary.merge_all()
|
|
self.writer = tf.summary.FileWriter(self.logs_dir, self.sess.graph)
|
|
else:
|
|
|
|
with tf.name_scope('placeholder'):
|
|
self.input_photo = tf.placeholder(dtype=tf.float32,
|
|
shape=[self.batch_size, None, None, 3],
|
|
name='photo')
|
|
|
|
|
|
|
|
self.input_photo_features = encoder(image=self.input_photo,
|
|
options=self.options,
|
|
reuse=False)
|
|
|
|
|
|
self.output_photo = decoder(features=self.input_photo_features,
|
|
options=self.options,
|
|
reuse=False)
|
|
|
|
def train(self, args, ckpt_nmbr=None):
|
|
|
|
augmentor = img_augm.Augmentor(crop_size=[self.options.image_size, self.options.image_size],
|
|
vertical_flip_prb=0.,
|
|
hsv_augm_prb=1.0,
|
|
hue_augm_shift=0.05,
|
|
saturation_augm_shift=0.05, saturation_augm_scale=0.05,
|
|
value_augm_shift=0.05, value_augm_scale=0.05, )
|
|
content_dataset_places = prepare_dataset.PlacesDataset(path_to_dataset=self.options.path_to_content_dataset)
|
|
art_dataset = prepare_dataset.ArtDataset(path_to_art_dataset=self.options.path_to_art_dataset)
|
|
|
|
|
|
|
|
q_art = multiprocessing.Queue(maxsize=10)
|
|
q_content = multiprocessing.Queue(maxsize=10)
|
|
jobs = []
|
|
for i in range(5):
|
|
p = multiprocessing.Process(target=content_dataset_places.initialize_batch_worker,
|
|
args=(q_content, augmentor, self.batch_size, i))
|
|
p.start()
|
|
jobs.append(p)
|
|
|
|
p = multiprocessing.Process(target=art_dataset.initialize_batch_worker,
|
|
args=(q_art, augmentor, self.batch_size, i))
|
|
p.start()
|
|
jobs.append(p)
|
|
print("Processes are started.")
|
|
time.sleep(3)
|
|
|
|
|
|
init_op = tf.global_variables_initializer()
|
|
self.sess.run(init_op)
|
|
print("Start training.")
|
|
|
|
if self.load(self.checkpoint_dir, ckpt_nmbr):
|
|
print(" [*] Load SUCCESS")
|
|
else:
|
|
if self.load(self.checkpoint_long_dir, ckpt_nmbr):
|
|
print(" [*] Load SUCCESS")
|
|
else:
|
|
print(" [!] Load failed...")
|
|
|
|
|
|
win_rate = args.discr_success_rate
|
|
discr_success = args.discr_success_rate
|
|
alpha = 0.05
|
|
|
|
for step in tqdm(range(self.initial_step, self.options.total_steps+1),
|
|
initial=self.initial_step,
|
|
total=self.options.total_steps):
|
|
|
|
while q_art.empty() or q_content.empty():
|
|
pass
|
|
batch_art = q_art.get()
|
|
batch_content = q_content.get()
|
|
|
|
if discr_success >= win_rate:
|
|
|
|
_, summary_all, gener_acc_ = self.sess.run(
|
|
[self.g_optim_step, self.summary_merged_all, self.gener_acc],
|
|
feed_dict={
|
|
self.input_painting: normalize_arr_of_imgs(batch_art['image']),
|
|
self.input_photo: normalize_arr_of_imgs(batch_content['image']),
|
|
self.lr: self.options.lr
|
|
})
|
|
discr_success = discr_success * (1. - alpha) + alpha * (1. - gener_acc_)
|
|
else:
|
|
|
|
_, summary_all, discr_acc_ = self.sess.run(
|
|
[self.d_optim_step, self.summary_merged_all, self.discr_acc],
|
|
feed_dict={
|
|
self.input_painting: normalize_arr_of_imgs(batch_art['image']),
|
|
self.input_photo: normalize_arr_of_imgs(batch_content['image']),
|
|
self.lr: self.options.lr
|
|
})
|
|
|
|
discr_success = discr_success * (1. - alpha) + alpha * discr_acc_
|
|
self.writer.add_summary(summary_all, step * self.batch_size)
|
|
|
|
if step % self.options.save_freq == 0 and step > self.initial_step:
|
|
self.save(step)
|
|
|
|
|
|
if step % 15000 == 0 and step > self.initial_step:
|
|
self.save(step, is_long=True)
|
|
|
|
if step % 500 == 0:
|
|
output_paintings_, output_photos_= self.sess.run(
|
|
[self.input_painting, self.output_photo],
|
|
feed_dict={
|
|
self.input_painting: normalize_arr_of_imgs(batch_art['image']),
|
|
self.input_photo: normalize_arr_of_imgs(batch_content['image']),
|
|
self.lr: self.options.lr
|
|
})
|
|
|
|
save_batch(input_painting_batch=batch_art['image'],
|
|
input_photo_batch=batch_content['image'],
|
|
output_painting_batch=denormalize_arr_of_imgs(output_paintings_),
|
|
output_photo_batch=denormalize_arr_of_imgs(output_photos_),
|
|
filepath='%s/step_%d.jpg' % (self.sample_dir, step))
|
|
print("Training is finished. Terminate jobs.")
|
|
for p in jobs:
|
|
p.join()
|
|
p.terminate()
|
|
|
|
print("Done.")
|
|
print("Does the sys.exit() made this process to exit ??")
|
|
sys.exit()
|
|
|
|
|
|
def inference_video(self, args, path_to_folder, to_save_dir=None, resize_to_original=True,
|
|
use_time_smooth_randomness=True, ckpt_nmbr=None,file_suffix= "_stylized"):
|
|
"""
|
|
Run inference on the video frames. Original aspect ratio will be preserved.
|
|
Args:
|
|
args:
|
|
path_to_folder: path to the folder with frames from the video
|
|
to_save_dir:
|
|
resize_to_original:
|
|
use_time_smooth_randomness: change the random vector
|
|
which is added to the bottleneck features linearly over tim
|
|
|
|
Returns:
|
|
|
|
"""
|
|
init_op = tf.global_variables_initializer()
|
|
self.sess.run(init_op)
|
|
print("Start inference.")
|
|
|
|
if self.load(self.checkpoint_dir, ckpt_nmbr):
|
|
print(" [*] Load SUCCESS")
|
|
else:
|
|
if self.load(self.checkpoint_long_dir, ckpt_nmbr):
|
|
print(" [*] Load SUCCESS")
|
|
else:
|
|
print(" [!] Load failed...")
|
|
|
|
|
|
if to_save_dir is None:
|
|
to_save_dir = os.path.join(self.root_dir, self.model_name,
|
|
'inference_ckpt%d_sz%d' % (self.initial_step, self.image_size))
|
|
|
|
if not os.path.exists(to_save_dir):
|
|
os.makedirs(to_save_dir)
|
|
|
|
image_paths = sorted(os.listdir(path_to_folder))
|
|
num_images = len(image_paths)
|
|
for img_idx, img_name in enumerate(tqdm(image_paths)):
|
|
|
|
img_path = os.path.join(path_to_folder, img_name)
|
|
img = scipy.misc.imread(img_path, mode='RGB')
|
|
img_shape = img.shape[:2]
|
|
|
|
scale_mult = self.image_size / np.min(img_shape)
|
|
new_shape = (np.array(img_shape, dtype=float) * scale_mult).astype(int)
|
|
|
|
img = scipy.misc.imresize(img, size=new_shape)
|
|
|
|
img = np.expand_dims(img, axis=0)
|
|
|
|
if use_time_smooth_randomness and img_idx == 0:
|
|
features_delta = self.sess.run(self.labels_to_concatenate_to_features,
|
|
feed_dict={
|
|
self.input_photo: normalize_arr_of_imgs(img),
|
|
})
|
|
features_delta_start = features_delta + np.random.random(size=features_delta.shape) * 0.5 - 0.25
|
|
features_delta_start = features_delta_start.clip(0, 1000)
|
|
print('features_delta_start.shape=', features_delta_start.shape)
|
|
features_delta_end = features_delta + np.random.random(size=features_delta.shape) * 0.5 - 0.25
|
|
features_delta_end = features_delta_end.clip(0, 1000)
|
|
step = (features_delta_end - features_delta_start) / (num_images - 1)
|
|
|
|
feed_dict = {
|
|
self.input_painting: normalize_arr_of_imgs(img),
|
|
self.input_photo: normalize_arr_of_imgs(img),
|
|
self.lr: self.options.lr
|
|
}
|
|
if use_time_smooth_randomness:
|
|
pass
|
|
|
|
img = self.sess.run(self.output_photo, feed_dict=feed_dict)
|
|
|
|
img = img[0]
|
|
img = denormalize_arr_of_imgs(img)
|
|
if resize_to_original:
|
|
img = scipy.misc.imresize(img, size=img_shape)
|
|
else:
|
|
pass
|
|
|
|
scipy.misc.imsave(os.path.join(to_save_dir, img_name[:-4] + file_suffix +".jpg"), img)
|
|
|
|
print("Inference is finished.")
|
|
|
|
def inference(self, args, path_to_folder, to_save_dir=None, resize_to_original=True,
|
|
ckpt_nmbr=None,file_suffix= "_stylized"):
|
|
|
|
init_op = tf.global_variables_initializer()
|
|
self.sess.run(init_op)
|
|
print("Start inference.")
|
|
|
|
if self.load(self.checkpoint_dir, ckpt_nmbr):
|
|
print(" [*] Load SUCCESS")
|
|
else:
|
|
if self.load(self.checkpoint_long_dir, ckpt_nmbr):
|
|
print(" [*] Load SUCCESS")
|
|
else:
|
|
print(" [!] Load failed...")
|
|
|
|
sys.exit()
|
|
|
|
|
|
if to_save_dir is None:
|
|
to_save_dir = os.path.join(self.root_dir, self.model_name,
|
|
'inference_ckpt%d_sz%d' % (self.initial_step, self.image_size))
|
|
|
|
if not os.path.exists(to_save_dir):
|
|
os.makedirs(to_save_dir)
|
|
|
|
names = []
|
|
for d in path_to_folder:
|
|
names += glob(os.path.join(d, '*'))
|
|
names = [x for x in names if os.path.basename(x)[0] != '.']
|
|
names.sort()
|
|
for img_idx, img_path in enumerate(tqdm(names)):
|
|
img = scipy.misc.imread(img_path, mode='RGB')
|
|
img_shape = img.shape[:2]
|
|
|
|
|
|
alpha = float(self.image_size) / float(min(img_shape))
|
|
img = scipy.misc.imresize(img, size=alpha)
|
|
img = np.expand_dims(img, axis=0)
|
|
|
|
img = self.sess.run(
|
|
self.output_photo,
|
|
feed_dict={
|
|
self.input_photo: normalize_arr_of_imgs(img),
|
|
})
|
|
|
|
img = img[0]
|
|
img = denormalize_arr_of_imgs(img)
|
|
if resize_to_original:
|
|
img = scipy.misc.imresize(img, size=img_shape)
|
|
else:
|
|
pass
|
|
img_name = os.path.basename(img_path)
|
|
|
|
scipy.misc.imsave(os.path.join(to_save_dir, img_name[:-4] + file_suffix +".jpg"), img)
|
|
|
|
print("Inference is finished.")
|
|
|
|
def save(self, step, is_long=False):
|
|
if not os.path.exists(self.checkpoint_dir):
|
|
os.makedirs(self.checkpoint_dir)
|
|
if is_long:
|
|
self.saver_long.save(self.sess,
|
|
os.path.join(self.checkpoint_long_dir, self.model_name+'_%d.ckpt' % step),
|
|
global_step=step)
|
|
else:
|
|
self.saver.save(self.sess,
|
|
os.path.join(self.checkpoint_dir, self.model_name + '_%d.ckpt' % step),
|
|
global_step=step)
|
|
|
|
def load(self, checkpoint_dir, ckpt_nmbr=None):
|
|
if ckpt_nmbr:
|
|
if len([x for x in os.listdir(checkpoint_dir) if ("ckpt-" + str(ckpt_nmbr)) in x]) > 0:
|
|
print(" [*] Reading checkpoint %d from folder %s." % (ckpt_nmbr, checkpoint_dir))
|
|
ckpt_name = [x for x in os.listdir(checkpoint_dir) if ("ckpt-" + str(ckpt_nmbr)) in x][0]
|
|
ckpt_name = '.'.join(ckpt_name.split('.')[:-1])
|
|
self.initial_step = ckpt_nmbr
|
|
print("Load checkpoint %s. Initial step: %s." % (ckpt_name, self.initial_step))
|
|
self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
|
|
return True
|
|
else:
|
|
return False
|
|
else:
|
|
print(" [*] Reading latest checkpoint from folder %s." % (checkpoint_dir))
|
|
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
|
|
if ckpt and ckpt.model_checkpoint_path:
|
|
ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
|
|
self.initial_step = int(ckpt_name.split("_")[-1].split(".")[0])
|
|
print("Load checkpoint %s. Initial step: %s." % (ckpt_name, self.initial_step))
|
|
self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
|
|
return True
|
|
else:
|
|
return False
|
|
|