|
import os

import gradio as gr
import tensorflow as tf
import wget
|
|
|
# Pretrained AdaIN encoder/decoder weights hosted on the Hugging Face Hub.
# NOTE: the Hub serves raw file bytes from ``/resolve/`` URLs; ``/blob/``
# URLs return the HTML file-viewer page, which ``load_model`` cannot parse.
enc_url = 'https://huggingface.co/ariG23498/nst/resolve/main/nst-encoder.h5'
dec_url = 'https://huggingface.co/ariG23498/nst/resolve/main/nst-decoder.h5'


def _download_once(url):
    """Download ``url`` into the working directory unless already present.

    ``wget.download`` never overwrites an existing file: it saves a renamed
    copy (e.g. ``name (1).h5``) instead, so without this check every restart
    would re-download the weights and leave another duplicate behind.

    Returns the local filename.
    """
    filename = url.rsplit("/", 1)[-1]
    if not os.path.exists(filename):
        filename = wget.download(url, out=filename)
    return filename


enc_filename = _download_once(enc_url)
dec_filename = _download_once(dec_url)

# compile=False: the models are only used for inference here, so no
# optimizer/loss configuration needs to be restored.
encoder = tf.keras.models.load_model(enc_filename, compile=False)
decoder = tf.keras.models.load_model(dec_filename, compile=False)
|
|
|
def get_mean_std(tensor, epsilon=1e-5):
    """Return the mean and standard deviation of ``tensor`` over axes 1 and 2.

    Both statistics keep the reduced dimensions (``keepdims=True``), so they
    broadcast directly against ``tensor``. ``epsilon`` is added to the
    variance before the square root so the returned std is strictly positive.
    """
    # Axes 1 and 2 are presumably height/width of an NHWC feature map,
    # making these per-sample, per-channel statistics — TODO confirm layout.
    mean, variance = tf.nn.moments(tensor, axes=[1, 2], keepdims=True)
    std = tf.sqrt(variance + epsilon)
    return mean, std
|
|
|
def ada_in(style, content, epsilon=1e-5):
    """Adaptive Instance Normalization (AdaIN).

    Re-normalizes ``content`` so its statistics over axes 1 and 2 match
    those of ``style``::

        t = s_std * (content - c_mean) / c_std + s_mean

    Args:
        style: feature tensor supplying the target mean/std.
        content: feature tensor to be re-normalized.
        epsilon: added to the variance before the square root when computing
            both standard deviations (forwarded to ``get_mean_std``).

    Returns:
        ``content`` re-normalized with ``style``'s statistics; the result
        broadcasts to ``content``'s shape, assuming compatible batch and
        channel dimensions between the two inputs.
    """
    c_mean, c_std = get_mean_std(content, epsilon=epsilon)
    s_mean, s_std = get_mean_std(style, epsilon=epsilon)
    return s_std * (content - c_mean) / c_std + s_mean
|
|
|
def load_resize(image):
    """Convert ``image`` to float32 and resize it to a fixed 224x224.

    ``tf.image.convert_image_dtype`` rescales integer images into [0, 1]
    (float inputs are only cast). ``tf.image.resize`` then resamples the
    spatial dimensions to 224x224 using its default method.
    """
    as_float = tf.image.convert_image_dtype(image, dtype="float32")
    return tf.image.resize(as_float, (224, 224))
|
|
|
def infer(style, content):
    """Stylize ``content`` with the look of ``style`` using AdaIN.

    Args:
        style: style image as delivered by the Gradio ``Image`` input
            (presumably an HxWx3 uint8 array — confirm for the Gradio
            version in use).
        content: content image, same format as ``style``.

    Returns:
        The stylized image (first element of the decoder's output batch)
        as a NumPy array with values clipped to [0, 1].
    """
    # Both images get identical preprocessing plus a leading batch axis.
    style_batch = load_resize(style)[tf.newaxis, ...]
    content_batch = load_resize(content)[tf.newaxis, ...]

    style_enc = encoder(style_batch)
    content_enc = encoder(content_batch)

    t = ada_in(style=style_enc, content=content_enc)
    recons_image = decoder(t)

    # Nothing here bounds the decoder output to [0, 1]; clip so the float
    # image handed to Gradio stays in the valid range for float images.
    return tf.clip_by_value(recons_image[0], 0.0, 1.0).numpy()
|
|
|
# Web UI: two image inputs (style first, then content — matching the
# positional parameters of ``infer``) and a single image output.
# NOTE(review): ``gr.inputs.Image`` is the legacy Gradio 2.x component API;
# it is deprecated in Gradio 3 and removed in Gradio 4 (``gr.Image`` there).
iface = gr.Interface(
    fn=infer,
    inputs=[
        gr.inputs.Image(label="style"),
        gr.inputs.Image(label="content"),
    ],
    outputs="image",
)

# Launch separately so ``iface`` refers to the Interface itself rather than
# to whatever ``launch()`` returns.
iface.launch()
|
|