File size: 2,653 Bytes
9ff1108
 
 
 
 
 
 
 
caf2f8e
 
732466b
9ff1108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
'''PaintTransformer Demo

- 2021-12-21 first created
    - See: https://github.com/wzmsltw/PaintTransformer

'''

import os
# Install the system OpenGL package before the remaining imports run.
# NOTE(review): presumably required for `import cv2` just below — confirm.
for _apt_command in (
    'sudo apt-get update',
    'sudo apt-get -y install libgl1-mesa-glx',
):
    # Exit status is intentionally not inspected (best-effort install).
    os.system(_apt_command)

import cv2
import network
from time import time
from glob import glob
from loguru import logger
import gradio as gr

import paddle
import render_utils
import render_parallel
import render_serial

# ---------- Settings ----------
# GPU selection: '-1' runs on CPU; any other value is exported through
# CUDA_VISIBLE_DEVICES so that only that card is visible to the process.
GPU_ID = '-1'
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID
# paddle.set_device() expects 'cpu' or 'gpu:<index>' (not 'cuda:<index>').
# With CUDA_VISIBLE_DEVICES restricting visibility, the selected card is
# re-indexed from 0 inside this process, hence 'gpu:0'.
DEVICE = 'cpu' if GPU_ID == '-1' else 'gpu:0'

# Example images offered in the UI, in a stable (sorted) order.
examples = sorted(glob(os.path.join('input', '*.jpg')))
# Input images are read at WIDTH x HEIGHT and the video uses the same size.
WIDTH = 512
HEIGHT = 512
STROKE_NUM = 8  # passed to network.Painter when building the model
FPS = 10  # frame rate of the generated video

# ---------- Logger ----------
# Register a file sink opened in append mode, then mark this process start
# in the log.
_LOG_FILE = 'app.log'
logger.add(_LOG_FILE, mode='a')
logger.info('===== APP RESTARTED =====')

# ---------- Model ----------
MODEL_FILE = 'paint_best.pdparams'
if not os.path.exists(MODEL_FILE):
    # Fetch the pretrained weights with the gdown CLI.
    _gdown_status = os.system('gdown --id 1G0O81qSvGp0kFCgyaQHmPygbVHFi1--q')
    # os.system() never raises, so verify the outcome explicitly instead of
    # failing later inside paddle.load() with a less helpful error.
    if not os.path.exists(MODEL_FILE):
        logger.error(f'model download failed (gdown exit status {_gdown_status})')
        raise RuntimeError(
            f'could not download {MODEL_FILE} (gdown exit status {_gdown_status})'
        )
    if _gdown_status != 0:
        logger.warning(f'gdown exited with status {_gdown_status}')
    logger.info('model downloaded')
else:
    logger.info('model already exists')

# Build the painter network on the configured device and load its weights.
paddle.set_device(DEVICE)
net_g = network.Painter(5, STROKE_NUM, 256, 8, 3, 3)
net_g.set_state_dict(paddle.load(MODEL_FILE))
# Inference only: switch to eval mode and stop gradients on every parameter.
net_g.eval()
for param in net_g.parameters():
    param.stop_gradient = True

# The two brush images are read in 'L' mode and concatenated along axis 0.
# NOTE(review): 'L' presumably means single-channel grayscale — confirm in
# render_utils.read_img.
brush_large_vertical = render_utils.read_img('brush/brush_large_vertical.png', 'L')
brush_large_horizontal = render_utils.read_img('brush/brush_large_horizontal.png', 'L')
meta_brushes = paddle.concat([brush_large_vertical, brush_large_horizontal], axis=0)

def predict(image_file: str) -> str:
    '''Render the painting process of an image into an MP4 animation.

    The image is read at WIDTH x HEIGHT, passed through
    ``render_serial.render_serial`` with the module-level model and brushes,
    and every returned frame is written to ``output.mp4`` at FPS frames per
    second.

    Args:
        image_file: Path of the input image (the UI passes a file path).

    Returns:
        The path of the generated video file, ``'output.mp4'``.

    Raises:
        RuntimeError: If the video writer cannot be opened.
    '''
    original_img = render_utils.read_img(image_file, 'RGB', WIDTH, HEIGHT)
    logger.info(f'--- image loaded & resized {WIDTH}x{HEIGHT}')

    logger.info('--- doing inference...')
    t0 = time()
    final_result_list = render_serial.render_serial(original_img, net_g, meta_brushes)
    logger.info(f'--- inference took {time() - t0:.4f} sec')

    output_file = 'output.mp4'
    out = cv2.VideoWriter(output_file, cv2.VideoWriter_fourcc(*'mp4v'), FPS,
                          (WIDTH, HEIGHT))
    try:
        # A writer that failed to open silently drops every frame; fail
        # loudly instead of returning a missing or stale video.
        if not out.isOpened():
            raise RuntimeError(f'could not open video writer for {output_file}')
        for frame in final_result_list:
            out.write(frame)
    finally:
        # Always release the writer so the file handle is not leaked when
        # writing a frame raises.
        out.release()
    logger.info('--- animation generated')
    return output_file

# ---------- UI ----------
# One image input (handed to predict as a file path) and one video output.
_input_widgets = [gr.inputs.Image(label='Input image', type='filepath')]
_output_widgets = [gr.outputs.Video(label='Output animation', type='mp4')]
_credit_html = '<p style="text-align:center">Original work: <a href="https://github.com/wzmsltw/PaintTransformer">PaintTransformer</a></p>'

iface = gr.Interface(
    fn=predict,
    inputs=_input_widgets,
    outputs=_output_widgets,
    title='🎨 Paint Transformer',
    description='This demo converts an image into a sequence of painted images (animation)',
    examples=examples,
    article=_credit_html,
)

# Start the web app with debug enabled.
iface.launch(debug=True)