# Podtekatel's picture
# Init commit
# ad89c3c
import logging
import os

import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_url, cached_download
from huggingface_hub import hf_hub_download
from PIL import Image

from inference.face_detector import StatRetinaFaceDetector
from inference.model_pipeline import VSNetModelPipeline
from inference.onnx_model import ONNXModel
# Image size handed to StatRetinaFaceDetector when the pipeline is built
# (the article text below quotes inference "at 512x512 resolution").
MODEL_IMG_SIZE = 512

# Number of successfully processed requests; it is only reported through the
# INFO log, which is how it is read back (Hugging Face Space logs).
usage_count = 0

# Root logger: INFO and above, timestamped, level name padded to 8 columns.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)-8s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
def load_model():
    """Download the ONNX weights and build the global inference pipeline.

    Fetches ``avatar2_260_ep_181.onnx`` from the ``Podtekatel/Avatar2VSK``
    Hugging Face Hub repository, passing the ``HF_TOKEN`` environment variable
    as the auth token (``None`` when it is unset). The file is wrapped in an
    ``ONNXModel``, and a ``VSNetModelPipeline`` is built around it with a
    ``StatRetinaFaceDetector`` sized by ``MODEL_IMG_SIZE``.

    Side effects:
        Rebinds the module-level globals ``model`` and ``pipeline``;
        ``inference`` reads ``pipeline``.

    Returns:
        The loaded ``ONNXModel`` instance.
    """
    global model
    global pipeline

    repo_id = "Podtekatel/Avatar2VSK"
    filename = "avatar2_260_ep_181.onnx"

    # NOTE(review): the author labelled this checkpoint "Old model" —
    # presumably an earlier training run; confirm before swapping weights.
    # ``hf_hub_download`` replaces the deprecated (and, in recent
    # huggingface_hub releases, removed) ``cached_download(hf_hub_url(...))``
    # pair. The ``use_auth_token`` keyword is kept for compatibility with the
    # older huggingface_hub releases this file was written against.
    model_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        use_auth_token=os.getenv('HF_TOKEN'),
    )

    model = ONNXModel(model_path)
    # The resize values are passed straight through; their exact meaning is
    # defined in inference.model_pipeline.VSNetModelPipeline.
    pipeline = VSNetModelPipeline(
        model,
        StatRetinaFaceDetector(MODEL_IMG_SIZE),
        background_resize=1024,
        no_detected_resize=1024,
    )
    return model


# Build the model once at import time, before the Gradio interface below is
# created and launched.
load_model()
def inference(img):
    """Stylise one image with the global pipeline built by ``load_model``.

    Args:
        img: Image from the Gradio input component (``type="pil"``), or
            ``None`` when the user submits without uploading an image.

    Returns:
        The stylised image as a PIL image, or ``None`` when no input image
        was given.

    Side effects:
        Increments the module-level ``usage_count`` for every processed image
        and logs the new value at INFO level.
    """
    global usage_count

    # An empty Gradio image input arrives as ``None``; ``np.array(None)``
    # would produce a 0-d object array and the pipeline would fail with an
    # opaque error, so return an empty output instead.
    if img is None:
        return None

    out_img = Image.fromarray(pipeline(np.array(img)))

    usage_count += 1
    logging.info('Usage count is %d', usage_count)
    return out_img
# Texts shown above (title, description) and below (article) the demo.
title = "Avatar 2 Style Transfer"
description = (
    "Gradio Demo for Avatar: The Way of Water style transfer. To use it, simply upload your image, or click one of the examples to load them. Press ❤️ if you like this space or mention this repo on Reddit or Twitter!<br>"
    """<table>
<tr>
<th><img src="file/static/input.jpg" alt="Input"/></th>
<th><img src="file/static/output.jpg" alt="Output" width="610" height="398"/></th>
</tr>
</table>
"""
)
article = (
    "This model was trained on `Avatar: The Way of Water` movie. This model mainly focuses on faces stylization, Pay attention on this when uploads images. <br>"
    "Model pipeline which used in project is improved CartoonGAN.<br>"
    "This model was trained on RTX 2080 Ti 2 days with batch size 7.<br>"
    "Model weights 80 MB in ONNX fp32 format, infers 100 ms on GPU and 600 ms on CPU at 512x512 resolution.<br>"
    "My email contact: 'neuromancer.ai.lover@gmail.com'."
)

# Every file in ``demo/`` becomes a single-input example row, in name order.
imgs_folder = 'demo'
examples = [
    [os.path.join(imgs_folder, img_filename)]
    for img_filename in sorted(os.listdir(imgs_folder))
]

# ``gr.Image`` replaces the ``gr.inputs`` / ``gr.outputs`` namespaces, which
# are deprecated throughout Gradio 3.x and removed in Gradio 4.
demo = gr.Interface(
    fn=inference,
    inputs=[gr.Image(type="pil")],
    outputs=gr.Image(type="pil"),
    title=title,
    description=description,
    article=article,
    examples=examples,
)
# A single queue worker; presumably chosen because ``inference`` shares one
# pipeline and mutates the module-level usage counter.
demo.queue(concurrency_count=1)

if __name__ == "__main__":
    demo.launch()