# AdaIN / app.py
# Author: ariG23498 (HF staff); commit: "init" (a7f8f41); size: 1.48 kB.
import gradio as gr
import tensorflow as tf
import wget
# Fetch the pretrained encoder/decoder weights and load them at import time.
#
# The URLs use the Hub's ``/resolve/main/`` path, which serves the raw file.
# The ``/blob/main/`` path is the HTML file-viewer page, so downloading it
# would hand ``load_model`` an HTML document instead of an HDF5 file.
enc_url = 'https://huggingface.co/ariG23498/nst/resolve/main/nst-encoder.h5'
enc_filename = wget.download(enc_url)

dec_url = 'https://huggingface.co/ariG23498/nst/resolve/main/nst-decoder.h5'
dec_filename = wget.download(dec_url)

# compile=False: the models are only called for inference below, so no
# optimizer/loss state is requested when loading.
encoder = tf.keras.models.load_model(enc_filename, compile=False)
decoder = tf.keras.models.load_model(dec_filename, compile=False)
def get_mean_std(tensor, epsilon=1e-5):
    """Return the mean and standard deviation of ``tensor`` over axes 1 and 2.

    Both statistics are computed with ``keepdims=True``, so they keep the
    rank of the input and broadcast against it.

    Args:
        tensor: Input tensor; presumably a (batch, height, width, channels)
            feature map -- TODO confirm against the encoder's output.
        epsilon: Small constant added to the variance before the square
            root is taken.

    Returns:
        A ``(mean, std)`` tuple of tensors.
    """
    mean, variance = tf.nn.moments(tensor, axes=[1, 2], keepdims=True)
    # Offset the variance so the square root is never taken at exactly zero.
    std = tf.sqrt(variance + epsilon)
    return mean, std
def ada_in(style, content, epsilon=1e-5):
    """Apply adaptive instance normalization to ``content`` using ``style``.

    The content features are normalized with their own mean/std and then
    rescaled and shifted with the style features' std and mean.

    Args:
        style: Style feature tensor.
        content: Content feature tensor.
        epsilon: Small constant forwarded to ``get_mean_std`` for numerical
            stability of the standard deviation.

    Returns:
        The content features re-normalized to the style statistics.
    """
    # Forward ``epsilon`` explicitly; previously the argument was accepted
    # but ignored, so callers could not actually change it.
    c_mean, c_std = get_mean_std(content, epsilon=epsilon)
    s_mean, s_std = get_mean_std(style, epsilon=epsilon)
    return s_std * (content - c_mean) / c_std + s_mean
def load_resize(image):
    """Convert ``image`` to float32 and resize it to 224x224.

    Args:
        image: Image array or tensor accepted by ``tf.image``.

    Returns:
        The float32 image resized to 224 by 224.
    """
    as_float = tf.image.convert_image_dtype(image, dtype="float32")
    return tf.image.resize(as_float, (224, 224))
def infer(style, content):
    """Produce a stylized image from a style image and a content image.

    Each image is preprocessed with ``load_resize``, given a leading batch
    axis, and encoded; the encodings are combined with ``ada_in`` and the
    result is passed through the decoder.

    Args:
        style: Style image as supplied by the UI.
        content: Content image as supplied by the UI.

    Returns:
        The first (only) decoded image as a NumPy array.
    """
    # Preprocess both inputs and add the batch axis the models are called with.
    style_batch = load_resize(style)[tf.newaxis, ...]
    content_batch = load_resize(content)[tf.newaxis, ...]

    style_features = encoder(style_batch)
    content_features = encoder(content_batch)

    stylized_features = ada_in(style=style_features, content=content_features)
    decoded = decoder(stylized_features)
    # Drop the batch axis before handing the result back.
    return decoded[0].numpy()
# Build the two-image demo UI around ``infer`` and start serving it.
# ``launch()`` is chained onto the constructor, so ``iface`` is bound to
# whatever ``launch()`` returns rather than to the ``Interface`` object.
iface = gr.Interface(
    fn=infer,
    inputs=[
        gr.inputs.Image(label="style"),
        gr.inputs.Image(label="content"),
    ],
    outputs="image",
).launch()