demo-painttransformer / gradio_painttransformer.py
jaekookang
try sudo
caf2f8e
raw
history blame
2.65 kB
'''PaintTransformer Demo
- 2021-12-21 first created
- See: https://github.com/wzmsltw/PaintTransformer
'''
import os
# Refresh the apt index and install libgl1-mesa-glx before `import cv2` below
# (presumably so OpenCV can find libGL). Runs on every start; each command's
# exit status from os.system is ignored, so failures are best-effort only.
for _apt_cmd in ('sudo apt-get update', 'sudo apt-get -y install libgl1-mesa-glx'):
    os.system(_apt_cmd)
del _apt_cmd
import cv2
import network
from time import time
from glob import glob
from loguru import logger
import gradio as gr
import paddle
import render_utils
import render_parallel
import render_serial
# ---------- Settings ----------
def select_device(gpu_id):
    """Return the Paddle device string to use for *gpu_id*.

    ``'-1'`` selects the CPU. For any other id, CUDA_VISIBLE_DEVICES is set to
    that id (see below), which re-enumerates the chosen GPU as index 0 inside
    this process. Paddle spells GPUs ``gpu:<idx>``; the torch-style
    ``cuda:<idx>`` is rejected by ``paddle.set_device``.
    """
    return 'cpu' if gpu_id == '-1' else 'gpu:0'


GPU_ID = '-1'  # '-1' = run on CPU; otherwise the physical GPU id to expose
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID
DEVICE = select_device(GPU_ID)
# Example inputs shown in the UI: every .jpg under ./input, in sorted order.
examples = sorted(glob(os.path.join('input', '*.jpg')))
WIDTH = 512   # input images are resized to WIDTH x HEIGHT; also the video frame size
HEIGHT = 512
STROKE_NUM = 8  # passed to network.Painter as its second positional argument
FPS = 10  # frame rate of the generated animation
# ---------- Logger ----------
_LOG_FILE = 'app.log'
logger.add(_LOG_FILE, mode='a')  # append, so records from earlier runs are kept
logger.info('===== APP RESTARTED =====')  # marks the start of each run in the log
# ---------- Model ----------
# Ensure the pretrained weights are present, build the painter network on
# DEVICE, load the weights, and prepare the brush templates used at inference.
MODEL_FILE = 'paint_best.pdparams'
if not os.path.exists(MODEL_FILE):
    # Download from Google Drive on first start. Assumes the Drive file is
    # saved as MODEL_FILE in the working directory.
    # NOTE(review): newer gdown releases deprecate/remove the `--id` flag
    # (the id can be passed positionally); confirm against the installed gdown.
    os.system('gdown --id 1G0O81qSvGp0kFCgyaQHmPygbVHFi1--q')
    # os.system does not raise when the command fails, so check the result
    # here instead of failing later inside paddle.load with an obscure error.
    if not os.path.exists(MODEL_FILE):
        raise FileNotFoundError(
            f'{MODEL_FILE} is missing after attempting to download it with gdown'
        )
    logger.info('model downloaded')
else:
    logger.info('model already exists')

paddle.set_device(DEVICE)
# Positional hyper-parameters are defined by network.Painter; presumably they
# must match those the checkpoint was trained with, or set_state_dict will fail.
net_g = network.Painter(5, STROKE_NUM, 256, 8, 3, 3)
net_g.set_state_dict(paddle.load(MODEL_FILE))
net_g.eval()
# Inference only: exclude every parameter from gradient computation.
for param in net_g.parameters():
    param.stop_gradient = True

# Load the vertical and horizontal brush images (read with mode 'L', presumably
# grayscale) and stack them along axis 0 into one tensor for the renderer.
brush_large_vertical = render_utils.read_img('brush/brush_large_vertical.png', 'L')
brush_large_horizontal = render_utils.read_img('brush/brush_large_horizontal.png', 'L')
meta_brushes = paddle.concat([brush_large_vertical, brush_large_horizontal], axis=0)
def _write_video(frames, path):
    """Encode *frames* into an mp4v video at *path* (FPS, WIDTH x HEIGHT).

    Raises RuntimeError if OpenCV cannot open a writer for *path*. The writer
    is always released, even if writing a frame raises.
    """
    writer = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), FPS,
                             (WIDTH, HEIGHT))
    if not writer.isOpened():
        # Without this check, write() would be a no-op and the caller would
        # hand back a missing or empty video.
        raise RuntimeError(f'could not open video writer for {path}')
    try:
        for frame in frames:
            # NOTE(review): VideoWriter expects HEIGHT x WIDTH uint8 BGR frames
            # and may silently drop frames of another size; confirm
            # render_serial's output format.
            writer.write(frame)
    finally:
        writer.release()


def predict(image_file):
    """Render *image_file* as a stroke-by-stroke painting animation.

    Args:
        image_file: path to the input image (the Gradio input uses
            type='filepath').

    Returns:
        The path 'output.mp4' of the generated video.

    Raises:
        RuntimeError: if OpenCV cannot open the output video for writing.
    """
    original_img = render_utils.read_img(image_file, 'RGB', WIDTH, HEIGHT)
    logger.info(f'--- image loaded & resized {WIDTH}x{HEIGHT}')
    logger.info('--- doing inference...')
    t0 = time()
    final_result_list = render_serial.render_serial(original_img, net_g, meta_brushes)
    logger.info(f'--- inference took {time() - t0:.4f} sec')
    # NOTE(review): this fixed path is shared by every request, so concurrent
    # calls could overwrite each other's video.
    output_path = 'output.mp4'
    _write_video(final_result_list, output_path)
    logger.info('--- animation generated')
    return output_path
# ---------- Gradio UI ----------
# NOTE(review): gr.inputs / gr.outputs are the legacy component namespaces
# (deprecated in later Gradio releases); confirm against the pinned version.
_input_components = [gr.inputs.Image(label='Input image', type='filepath')]
_output_components = [gr.outputs.Video(label='Output animation', type='mp4')]

iface = gr.Interface(
    predict,
    inputs=_input_components,
    outputs=_output_components,
    title='🎨 Paint Transformer',
    description='This demo converts an image into a sequence of painted images (animation)',
    article='<p style="text-align:center">Original work: <a href="https://github.com/wzmsltw/PaintTransformer">PaintTransformer</a></p>',
    examples=examples,
)
# Start the web server with debug=True (runs at import time; no __main__ guard).
iface.launch(debug=True)