Spaces:
Runtime error
Runtime error
Duplicate from hylee/White-box-Cartoonization
Browse filesCo-authored-by: hylee <hylee@users.noreply.huggingface.co>
- .gitattributes +29 -0
- README.md +15 -0
- app.py +108 -0
- packages.txt +2 -0
- requirements.txt +5 -0
- wbc/cartoonize.py +112 -0
- wbc/guided_filter.py +87 -0
- wbc/network.py +62 -0
- wbc/saved_models/checkpoint +3 -0
- wbc/saved_models/model-33999.data-00000-of-00001 +3 -0
- wbc/saved_models/model-33999.index +0 -0
.gitattributes
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
19 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
26 |
+
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
python_version: 3.7
|
3 |
+
title: White Box Cartoonization
|
4 |
+
emoji: 📚
|
5 |
+
colorFrom: purple
|
6 |
+
colorTo: green
|
7 |
+
sdk: gradio
|
8 |
+
sdk_version: 2.9.4
|
9 |
+
app_file: app.py
|
10 |
+
pinned: false
|
11 |
+
license: apache-2.0
|
12 |
+
duplicated_from: hylee/White-box-Cartoonization
|
13 |
+
---
|
14 |
+
|
15 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
|
app.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
from __future__ import annotations
|
4 |
+
import argparse
|
5 |
+
import functools
|
6 |
+
import os
|
7 |
+
import pathlib
|
8 |
+
import sys
|
9 |
+
from typing import Callable
|
10 |
+
import uuid
|
11 |
+
|
12 |
+
import gradio as gr
|
13 |
+
import huggingface_hub
|
14 |
+
import numpy as np
|
15 |
+
import PIL.Image
|
16 |
+
|
17 |
+
from io import BytesIO
|
18 |
+
from wbc.cartoonize import Cartoonize
|
19 |
+
|
20 |
+
ORIGINAL_REPO_URL = 'https://github.com/SystemErrorWang/White-box-Cartoonization'
|
21 |
+
TITLE = 'SystemErrorWang/White-box-Cartoonization'
|
22 |
+
DESCRIPTION = f"""This is a demo for {ORIGINAL_REPO_URL}.
|
23 |
+
|
24 |
+
"""
|
25 |
+
ARTICLE = """
|
26 |
+
|
27 |
+
"""
|
28 |
+
|
29 |
+
SAFEHASH = [x for x in "0123456789-abcdefghijklmnopqrstuvwxyz_ABCDEFGHIJKLMNOPQRSTUVWXYZ"]
|
30 |
+
def compress_UUID():
|
31 |
+
'''
|
32 |
+
根据http://www.ietf.org/rfc/rfc1738.txt,由uuid编码扩bai大字符域生成du串
|
33 |
+
包括:[0-9a-zA-Z\-_]共64个
|
34 |
+
长度:(32-2)/3*2=20
|
35 |
+
备注:可在地球上人zhi人都用,使用100年不重复(2^120)
|
36 |
+
:return:String
|
37 |
+
'''
|
38 |
+
row = str(uuid.uuid4()).replace('-', '')
|
39 |
+
safe_code = ''
|
40 |
+
for i in range(10):
|
41 |
+
enbin = "%012d" % int(bin(int(row[i * 3] + row[i * 3 + 1] + row[i * 3 + 2], 16))[2:], 10)
|
42 |
+
safe_code += (SAFEHASH[int(enbin[0:6], 2)] + SAFEHASH[int(enbin[6:12], 2)])
|
43 |
+
safe_code = safe_code.replace('-', '')
|
44 |
+
return safe_code
|
45 |
+
|
46 |
+
|
47 |
+
def parse_args() -> argparse.Namespace:
|
48 |
+
parser = argparse.ArgumentParser()
|
49 |
+
parser.add_argument('--device', type=str, default='cpu')
|
50 |
+
parser.add_argument('--theme', type=str)
|
51 |
+
parser.add_argument('--live', action='store_true')
|
52 |
+
parser.add_argument('--share', action='store_true')
|
53 |
+
parser.add_argument('--port', type=int)
|
54 |
+
parser.add_argument('--disable-queue',
|
55 |
+
dest='enable_queue',
|
56 |
+
action='store_false')
|
57 |
+
parser.add_argument('--allow-flagging', type=str, default='never')
|
58 |
+
parser.add_argument('--allow-screenshot', action='store_true')
|
59 |
+
return parser.parse_args()
|
60 |
+
|
61 |
+
def run(
|
62 |
+
image,
|
63 |
+
cartoonize : Cartoonize
|
64 |
+
) -> tuple[PIL.Image.Image]:
|
65 |
+
|
66 |
+
out_path = compress_UUID()+'.png'
|
67 |
+
cartoonize.run_sigle(image.name, out_path)
|
68 |
+
|
69 |
+
return PIL.Image.open(out_path)
|
70 |
+
|
71 |
+
|
72 |
+
def main():
|
73 |
+
gr.close_all()
|
74 |
+
|
75 |
+
args = parse_args()
|
76 |
+
|
77 |
+
cartoonize = Cartoonize(os.path.join(os.path.dirname(os.path.abspath(__file__)),'wbc/saved_models/'))
|
78 |
+
|
79 |
+
func = functools.partial(run, cartoonize=cartoonize)
|
80 |
+
func = functools.update_wrapper(func, run)
|
81 |
+
|
82 |
+
gr.Interface(
|
83 |
+
func,
|
84 |
+
[
|
85 |
+
gr.inputs.Image(type='file', label='Input Image'),
|
86 |
+
],
|
87 |
+
[
|
88 |
+
gr.outputs.Image(
|
89 |
+
type='pil',
|
90 |
+
label='Result'),
|
91 |
+
],
|
92 |
+
# examples=examples,
|
93 |
+
theme=args.theme,
|
94 |
+
title=TITLE,
|
95 |
+
description=DESCRIPTION,
|
96 |
+
article=ARTICLE,
|
97 |
+
allow_screenshot=args.allow_screenshot,
|
98 |
+
allow_flagging=args.allow_flagging,
|
99 |
+
live=args.live,
|
100 |
+
).launch(
|
101 |
+
enable_queue=args.enable_queue,
|
102 |
+
server_port=args.port,
|
103 |
+
share=args.share,
|
104 |
+
)
|
105 |
+
|
106 |
+
|
107 |
+
if __name__ == '__main__':
|
108 |
+
main()
|
packages.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
opencv-python-headless==4.5.5.62
|
2 |
+
Pillow==9.0.1
|
3 |
+
scipy==1.7.3
|
4 |
+
tensorflow-gpu==1.14.0
|
5 |
+
scikit-image==0.14.5
|
wbc/cartoonize.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import tensorflow as tf
|
5 |
+
import wbc.network as network
|
6 |
+
import wbc.guided_filter as guided_filter
|
7 |
+
from tqdm import tqdm
|
8 |
+
|
9 |
+
|
10 |
+
def resize_crop(image):
|
11 |
+
h, w, c = np.shape(image)
|
12 |
+
if min(h, w) > 720:
|
13 |
+
if h > w:
|
14 |
+
h, w = int(720 * h / w), 720
|
15 |
+
else:
|
16 |
+
h, w = 720, int(720 * w / h)
|
17 |
+
image = cv2.resize(image, (w, h),
|
18 |
+
interpolation=cv2.INTER_AREA)
|
19 |
+
h, w = (h // 8) * 8, (w // 8) * 8
|
20 |
+
image = image[:h, :w, :]
|
21 |
+
return image
|
22 |
+
|
23 |
+
|
24 |
+
def cartoonize(load_folder, save_folder, model_path):
|
25 |
+
print(model_path)
|
26 |
+
input_photo = tf.placeholder(tf.float32, [1, None, None, 3])
|
27 |
+
network_out = network.unet_generator(input_photo)
|
28 |
+
final_out = guided_filter.guided_filter(input_photo, network_out, r=1, eps=5e-3)
|
29 |
+
|
30 |
+
all_vars = tf.trainable_variables()
|
31 |
+
gene_vars = [var for var in all_vars if 'generator' in var.name]
|
32 |
+
saver = tf.train.Saver(var_list=gene_vars)
|
33 |
+
|
34 |
+
config = tf.ConfigProto()
|
35 |
+
config.gpu_options.allow_growth = True
|
36 |
+
sess = tf.Session(config=config)
|
37 |
+
|
38 |
+
sess.run(tf.global_variables_initializer())
|
39 |
+
saver.restore(sess, tf.train.latest_checkpoint(model_path))
|
40 |
+
name_list = os.listdir(load_folder)
|
41 |
+
for name in tqdm(name_list):
|
42 |
+
try:
|
43 |
+
load_path = os.path.join(load_folder, name)
|
44 |
+
save_path = os.path.join(save_folder, name)
|
45 |
+
image = cv2.imread(load_path)
|
46 |
+
image = resize_crop(image)
|
47 |
+
batch_image = image.astype(np.float32) / 127.5 - 1
|
48 |
+
batch_image = np.expand_dims(batch_image, axis=0)
|
49 |
+
output = sess.run(final_out, feed_dict={input_photo: batch_image})
|
50 |
+
output = (np.squeeze(output) + 1) * 127.5
|
51 |
+
output = np.clip(output, 0, 255).astype(np.uint8)
|
52 |
+
cv2.imwrite(save_path, output)
|
53 |
+
except:
|
54 |
+
print('cartoonize {} failed'.format(load_path))
|
55 |
+
|
56 |
+
|
57 |
+
class Cartoonize:
|
58 |
+
def __init__(self, model_path):
|
59 |
+
print(model_path)
|
60 |
+
self.input_photo = tf.placeholder(tf.float32, [1, None, None, 3])
|
61 |
+
network_out = network.unet_generator(self.input_photo)
|
62 |
+
self.final_out = guided_filter.guided_filter(self.input_photo, network_out, r=1, eps=5e-3)
|
63 |
+
|
64 |
+
all_vars = tf.trainable_variables()
|
65 |
+
gene_vars = [var for var in all_vars if 'generator' in var.name]
|
66 |
+
saver = tf.train.Saver(var_list=gene_vars)
|
67 |
+
|
68 |
+
config = tf.ConfigProto()
|
69 |
+
config.gpu_options.allow_growth = True
|
70 |
+
self.sess = tf.Session(config=config)
|
71 |
+
|
72 |
+
self.sess.run(tf.global_variables_initializer())
|
73 |
+
saver.restore(self.sess, tf.train.latest_checkpoint(model_path))
|
74 |
+
|
75 |
+
def run(self, load_folder, save_folder):
|
76 |
+
name_list = os.listdir(load_folder)
|
77 |
+
for name in tqdm(name_list):
|
78 |
+
try:
|
79 |
+
load_path = os.path.join(load_folder, name)
|
80 |
+
save_path = os.path.join(save_folder, name)
|
81 |
+
image = cv2.imread(load_path)
|
82 |
+
image = resize_crop(image)
|
83 |
+
batch_image = image.astype(np.float32) / 127.5 - 1
|
84 |
+
batch_image = np.expand_dims(batch_image, axis=0)
|
85 |
+
output = self.sess.run(self.final_out, feed_dict={self.input_photo: batch_image})
|
86 |
+
output = (np.squeeze(output) + 1) * 127.5
|
87 |
+
output = np.clip(output, 0, 255).astype(np.uint8)
|
88 |
+
cv2.imwrite(save_path, output)
|
89 |
+
except:
|
90 |
+
print('cartoonize {} failed'.format(load_path))
|
91 |
+
|
92 |
+
def run_sigle(self, load_path, save_path):
|
93 |
+
try:
|
94 |
+
image = cv2.imread(load_path)
|
95 |
+
image = resize_crop(image)
|
96 |
+
batch_image = image.astype(np.float32) / 127.5 - 1
|
97 |
+
batch_image = np.expand_dims(batch_image, axis=0)
|
98 |
+
output = self.sess.run(self.final_out, feed_dict={self.input_photo: batch_image})
|
99 |
+
output = (np.squeeze(output) + 1) * 127.5
|
100 |
+
output = np.clip(output, 0, 255).astype(np.uint8)
|
101 |
+
cv2.imwrite(save_path, output)
|
102 |
+
except:
|
103 |
+
print('cartoonize {} failed'.format(load_path))
|
104 |
+
|
105 |
+
|
106 |
+
if __name__ == '__main__':
|
107 |
+
model_path = 'saved_models'
|
108 |
+
load_folder = 'test_images'
|
109 |
+
save_folder = 'cartoonized_images'
|
110 |
+
if not os.path.exists(save_folder):
|
111 |
+
os.mkdir(save_folder)
|
112 |
+
cartoonize(load_folder, save_folder, model_path)
|
wbc/guided_filter.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
|
5 |
+
|
6 |
+
|
7 |
+
def tf_box_filter(x, r):
|
8 |
+
k_size = int(2*r+1)
|
9 |
+
ch = x.get_shape().as_list()[-1]
|
10 |
+
weight = 1/(k_size**2)
|
11 |
+
box_kernel = weight*np.ones((k_size, k_size, ch, 1))
|
12 |
+
box_kernel = np.array(box_kernel).astype(np.float32)
|
13 |
+
output = tf.nn.depthwise_conv2d(x, box_kernel, [1, 1, 1, 1], 'SAME')
|
14 |
+
return output
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
+
def guided_filter(x, y, r, eps=1e-2):
|
19 |
+
|
20 |
+
x_shape = tf.shape(x)
|
21 |
+
#y_shape = tf.shape(y)
|
22 |
+
|
23 |
+
N = tf_box_filter(tf.ones((1, x_shape[1], x_shape[2], 1), dtype=x.dtype), r)
|
24 |
+
|
25 |
+
mean_x = tf_box_filter(x, r) / N
|
26 |
+
mean_y = tf_box_filter(y, r) / N
|
27 |
+
cov_xy = tf_box_filter(x * y, r) / N - mean_x * mean_y
|
28 |
+
var_x = tf_box_filter(x * x, r) / N - mean_x * mean_x
|
29 |
+
|
30 |
+
A = cov_xy / (var_x + eps)
|
31 |
+
b = mean_y - A * mean_x
|
32 |
+
|
33 |
+
mean_A = tf_box_filter(A, r) / N
|
34 |
+
mean_b = tf_box_filter(b, r) / N
|
35 |
+
|
36 |
+
output = mean_A * x + mean_b
|
37 |
+
|
38 |
+
return output
|
39 |
+
|
40 |
+
|
41 |
+
|
42 |
+
def fast_guided_filter(lr_x, lr_y, hr_x, r=1, eps=1e-8):
|
43 |
+
|
44 |
+
#assert lr_x.shape.ndims == 4 and lr_y.shape.ndims == 4 and hr_x.shape.ndims == 4
|
45 |
+
|
46 |
+
lr_x_shape = tf.shape(lr_x)
|
47 |
+
#lr_y_shape = tf.shape(lr_y)
|
48 |
+
hr_x_shape = tf.shape(hr_x)
|
49 |
+
|
50 |
+
N = tf_box_filter(tf.ones((1, lr_x_shape[1], lr_x_shape[2], 1), dtype=lr_x.dtype), r)
|
51 |
+
|
52 |
+
mean_x = tf_box_filter(lr_x, r) / N
|
53 |
+
mean_y = tf_box_filter(lr_y, r) / N
|
54 |
+
cov_xy = tf_box_filter(lr_x * lr_y, r) / N - mean_x * mean_y
|
55 |
+
var_x = tf_box_filter(lr_x * lr_x, r) / N - mean_x * mean_x
|
56 |
+
|
57 |
+
A = cov_xy / (var_x + eps)
|
58 |
+
b = mean_y - A * mean_x
|
59 |
+
|
60 |
+
mean_A = tf.image.resize_images(A, hr_x_shape[1: 3])
|
61 |
+
mean_b = tf.image.resize_images(b, hr_x_shape[1: 3])
|
62 |
+
|
63 |
+
output = mean_A * hr_x + mean_b
|
64 |
+
|
65 |
+
return output
|
66 |
+
|
67 |
+
|
68 |
+
if __name__ == '__main__':
|
69 |
+
import cv2
|
70 |
+
from tqdm import tqdm
|
71 |
+
|
72 |
+
input_photo = tf.placeholder(tf.float32, [1, None, None, 3])
|
73 |
+
#input_superpixel = tf.placeholder(tf.float32, [16, 256, 256, 3])
|
74 |
+
output = guided_filter(input_photo, input_photo, 5, eps=1)
|
75 |
+
image = cv2.imread('output_figure1/cartoon2.jpg')
|
76 |
+
image = image/127.5 - 1
|
77 |
+
image = np.expand_dims(image, axis=0)
|
78 |
+
|
79 |
+
config = tf.ConfigProto()
|
80 |
+
config.gpu_options.allow_growth = True
|
81 |
+
sess = tf.Session(config=config)
|
82 |
+
sess.run(tf.global_variables_initializer())
|
83 |
+
|
84 |
+
out = sess.run(output, feed_dict={input_photo: image})
|
85 |
+
out = (np.squeeze(out)+1)*127.5
|
86 |
+
out = np.clip(out, 0, 255).astype(np.uint8)
|
87 |
+
cv2.imwrite('output_figure1/cartoon2_filter.jpg', out)
|
wbc/network.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
import numpy as np
|
3 |
+
import tensorflow.contrib.slim as slim
|
4 |
+
|
5 |
+
|
6 |
+
|
7 |
+
def resblock(inputs, out_channel=32, name='resblock'):
|
8 |
+
|
9 |
+
with tf.variable_scope(name):
|
10 |
+
|
11 |
+
x = slim.convolution2d(inputs, out_channel, [3, 3],
|
12 |
+
activation_fn=None, scope='conv1')
|
13 |
+
x = tf.nn.leaky_relu(x)
|
14 |
+
x = slim.convolution2d(x, out_channel, [3, 3],
|
15 |
+
activation_fn=None, scope='conv2')
|
16 |
+
|
17 |
+
return x + inputs
|
18 |
+
|
19 |
+
|
20 |
+
|
21 |
+
|
22 |
+
def unet_generator(inputs, channel=32, num_blocks=4, name='generator', reuse=False):
|
23 |
+
with tf.variable_scope(name, reuse=reuse):
|
24 |
+
|
25 |
+
x0 = slim.convolution2d(inputs, channel, [7, 7], activation_fn=None)
|
26 |
+
x0 = tf.nn.leaky_relu(x0)
|
27 |
+
|
28 |
+
x1 = slim.convolution2d(x0, channel, [3, 3], stride=2, activation_fn=None)
|
29 |
+
x1 = tf.nn.leaky_relu(x1)
|
30 |
+
x1 = slim.convolution2d(x1, channel*2, [3, 3], activation_fn=None)
|
31 |
+
x1 = tf.nn.leaky_relu(x1)
|
32 |
+
|
33 |
+
x2 = slim.convolution2d(x1, channel*2, [3, 3], stride=2, activation_fn=None)
|
34 |
+
x2 = tf.nn.leaky_relu(x2)
|
35 |
+
x2 = slim.convolution2d(x2, channel*4, [3, 3], activation_fn=None)
|
36 |
+
x2 = tf.nn.leaky_relu(x2)
|
37 |
+
|
38 |
+
for idx in range(num_blocks):
|
39 |
+
x2 = resblock(x2, out_channel=channel*4, name='block_{}'.format(idx))
|
40 |
+
|
41 |
+
x2 = slim.convolution2d(x2, channel*2, [3, 3], activation_fn=None)
|
42 |
+
x2 = tf.nn.leaky_relu(x2)
|
43 |
+
|
44 |
+
h1, w1 = tf.shape(x2)[1], tf.shape(x2)[2]
|
45 |
+
x3 = tf.image.resize_bilinear(x2, (h1*2, w1*2))
|
46 |
+
x3 = slim.convolution2d(x3+x1, channel*2, [3, 3], activation_fn=None)
|
47 |
+
x3 = tf.nn.leaky_relu(x3)
|
48 |
+
x3 = slim.convolution2d(x3, channel, [3, 3], activation_fn=None)
|
49 |
+
x3 = tf.nn.leaky_relu(x3)
|
50 |
+
|
51 |
+
h2, w2 = tf.shape(x3)[1], tf.shape(x3)[2]
|
52 |
+
x4 = tf.image.resize_bilinear(x3, (h2*2, w2*2))
|
53 |
+
x4 = slim.convolution2d(x4+x0, channel, [3, 3], activation_fn=None)
|
54 |
+
x4 = tf.nn.leaky_relu(x4)
|
55 |
+
x4 = slim.convolution2d(x4, 3, [7, 7], activation_fn=None)
|
56 |
+
|
57 |
+
return x4
|
58 |
+
|
59 |
+
if __name__ == '__main__':
|
60 |
+
|
61 |
+
|
62 |
+
pass
|
wbc/saved_models/checkpoint
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
model_checkpoint_path: "model-33999"
|
2 |
+
all_model_checkpoint_paths: "model-33999"
|
3 |
+
all_model_checkpoint_paths: "model-37499"
|
wbc/saved_models/model-33999.data-00000-of-00001
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1e2df1a5aa86faa4f979720bfc2436f79333a480876f8d6790b7671cf50fe75b
|
3 |
+
size 5868300
|
wbc/saved_models/model-33999.index
ADDED
Binary file (1.56 kB). View file
|
|