Spaces:
Build error
Build error
import os | |
import cv2 | |
import numpy as np | |
import tensorflow as tf | |
import wbc.network as network | |
import wbc.guided_filter as guided_filter | |
from tqdm import tqdm | |
def resize_crop(image, max_short_side=720, multiple=8):
    """Downscale a large image and crop it so both sides are multiples of ``multiple``.

    If the shorter side of ``image`` is larger than ``max_short_side``, the
    image is resized with ``cv2.INTER_AREA`` so that its shorter side equals
    ``max_short_side`` while keeping the aspect ratio (the longer side is
    truncated to an int). The result is then cropped from the top-left corner
    so that height and width are rounded down to a multiple of ``multiple``.

    Args:
        image: Array of shape (H, W) or (H, W, C), e.g. as returned by
            ``cv2.imread``.
        max_short_side: Shorter-side threshold and target for downscaling
            (default 720, the previously hard-coded value).
        multiple: Height and width of the result are rounded down to a
            multiple of this value (default 8, the previously hard-coded value).

    Returns:
        The (possibly resized) image cropped to (H', W'[, C]).

    Raises:
        ValueError: If ``image`` has fewer than two dimensions (for example
            ``None`` from a failed ``cv2.imread``).
    """
    shape = np.shape(image)
    if len(shape) < 2:
        raise ValueError(
            'expected an image array of shape (H, W[, C]), got shape {}'.format(shape))
    h, w = shape[:2]
    if min(h, w) > max_short_side:
        if h > w:
            h, w = int(max_short_side * h / w), max_short_side
        else:
            h, w = max_short_side, int(max_short_side * w / h)
        # cv2.resize takes the target size as (width, height).
        image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)
    # NOTE(review): presumably the generator needs spatial dims divisible by 8
    # (downsampling stages) — confirm against wbc.network.
    h, w = (h // multiple) * multiple, (w // multiple) * multiple
    return image[:h, :w]
def cartoonize(load_folder, save_folder, model_path):
    """Cartoonize every image in ``load_folder`` and write the results to ``save_folder``.

    Builds the graph ``network.unet_generator`` followed by
    ``guided_filter.guided_filter`` and restores the trainable variables whose
    name contains ``'generator'`` from the latest checkpoint in ``model_path``.
    It then processes each entry of ``load_folder`` independently. Output
    files keep the input file names. Entries that cannot be read, processed
    or written are reported and skipped.

    Args:
        load_folder: Directory containing the input images.
        save_folder: Existing directory that receives the cartoonized images.
        model_path: Directory containing the TensorFlow checkpoint.

    Raises:
        ValueError: If no checkpoint is found in ``model_path``.
    """
    print(model_path)
    input_photo = tf.placeholder(tf.float32, [1, None, None, 3])
    network_out = network.unet_generator(input_photo)
    final_out = guided_filter.guided_filter(input_photo, network_out, r=1, eps=5e-3)

    # Only the generator's weights are restored from the checkpoint.
    all_vars = tf.trainable_variables()
    gene_vars = [var for var in all_vars if 'generator' in var.name]
    saver = tf.train.Saver(var_list=gene_vars)

    # latest_checkpoint returns None when nothing is found; fail clearly
    # before allocating a session.
    checkpoint = tf.train.latest_checkpoint(model_path)
    if checkpoint is None:
        raise ValueError('no checkpoint found in {}'.format(model_path))

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    # The context manager closes the session (and frees its resources) on exit.
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, checkpoint)
        for name in tqdm(os.listdir(load_folder)):
            load_path = os.path.join(load_folder, name)
            save_path = os.path.join(save_folder, name)
            try:
                image = cv2.imread(load_path)
                if image is None:
                    # cv2.imread signals failure by returning None, not by raising.
                    raise ValueError('could not read image')
                image = resize_crop(image)
                # Scale pixel values from [0, 255] to [-1, 1] and add a batch axis.
                batch_image = image.astype(np.float32) / 127.5 - 1
                batch_image = np.expand_dims(batch_image, axis=0)
                output = sess.run(final_out, feed_dict={input_photo: batch_image})
                # Map the network output back from [-1, 1] to uint8 pixels.
                output = (np.squeeze(output) + 1) * 127.5
                output = np.clip(output, 0, 255).astype(np.uint8)
                if not cv2.imwrite(save_path, output):
                    raise IOError('could not write {}'.format(save_path))
            except Exception as err:
                # Best-effort batch: report and continue, but let
                # KeyboardInterrupt/SystemExit propagate.
                print('cartoonize {} failed: {}'.format(load_path, err))
class Cartoonize:
    """Cartoonizer that restores the generator once and reuses its session.

    The graph and TensorFlow session are built in ``__init__``; ``run`` and
    ``run_sigle`` can then be called repeatedly. Call ``close`` (or use the
    instance as a context manager) to release the session.
    """

    def __init__(self, model_path):
        """Build the generator graph and restore weights from ``model_path``.

        Args:
            model_path: Directory containing the TensorFlow checkpoint.

        Raises:
            ValueError: If no checkpoint is found in ``model_path``.
        """
        print(model_path)
        self.input_photo = tf.placeholder(tf.float32, [1, None, None, 3])
        network_out = network.unet_generator(self.input_photo)
        self.final_out = guided_filter.guided_filter(self.input_photo, network_out, r=1, eps=5e-3)

        # Only the generator's weights are restored from the checkpoint.
        all_vars = tf.trainable_variables()
        gene_vars = [var for var in all_vars if 'generator' in var.name]
        saver = tf.train.Saver(var_list=gene_vars)

        # latest_checkpoint returns None when nothing is found; fail clearly
        # before allocating a session.
        checkpoint = tf.train.latest_checkpoint(model_path)
        if checkpoint is None:
            raise ValueError('no checkpoint found in {}'.format(model_path))

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=config)
        self.sess.run(tf.global_variables_initializer())
        saver.restore(self.sess, checkpoint)

    def close(self):
        """Release the TensorFlow session held by this instance."""
        self.sess.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        # Returning False lets any exception from the ``with`` body propagate.
        return False

    def _cartoonize_file(self, load_path, save_path):
        """Cartoonize the image at ``load_path`` into ``save_path``; raise on any failure."""
        image = cv2.imread(load_path)
        if image is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise ValueError('could not read image')
        image = resize_crop(image)
        # Scale pixel values from [0, 255] to [-1, 1] and add a batch axis.
        batch_image = image.astype(np.float32) / 127.5 - 1
        batch_image = np.expand_dims(batch_image, axis=0)
        output = self.sess.run(self.final_out, feed_dict={self.input_photo: batch_image})
        # Map the network output back from [-1, 1] to uint8 pixels.
        output = (np.squeeze(output) + 1) * 127.5
        output = np.clip(output, 0, 255).astype(np.uint8)
        if not cv2.imwrite(save_path, output):
            raise IOError('could not write {}'.format(save_path))

    def run(self, load_folder, save_folder):
        """Cartoonize every entry of ``load_folder`` into ``save_folder`` (same file names).

        Entries that fail are reported and skipped.
        """
        for name in tqdm(os.listdir(load_folder)):
            load_path = os.path.join(load_folder, name)
            save_path = os.path.join(save_folder, name)
            try:
                self._cartoonize_file(load_path, save_path)
            except Exception as err:
                print('cartoonize {} failed: {}'.format(load_path, err))

    def run_sigle(self, load_path, save_path):
        """Cartoonize one image file; a failure is reported, not raised."""
        try:
            self._cartoonize_file(load_path, save_path)
        except Exception as err:
            print('cartoonize {} failed: {}'.format(load_path, err))

    # Correctly spelled alias; ``run_sigle`` is kept for existing callers.
    run_single = run_sigle
if __name__ == '__main__':
    # Default locations, relative to the current working directory.
    model_path = 'saved_models'
    load_folder = 'test_images'
    save_folder = 'cartoonized_images'
    # exist_ok avoids the check-then-create race and also creates missing parents.
    os.makedirs(save_folder, exist_ok=True)
    cartoonize(load_folder, save_folder, model_path)