Neural_painting / app.py
minhnh's picture
Fix bug import wrong dep
e0b460b
raw
history blame
9.77 kB
import gradio as gr
import os
import cv2
import torch
import numpy as np
import argparse
import torch.nn as nn
import torch.nn.functional as F
import gc
from baseline.DRL.actor import *
from baseline.Renderer.stroke_gen import *
from baseline.Renderer.model import *
# ---------------------------------------------------------------------------
# Global configuration and the default pair of networks.
# ---------------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Side length (pixels) of the square canvas the networks operate on.
width = 128

# Checkpoints of the default actor / stroke renderer.
actor_path = 'ckpts/actor.pkl'
renderer_path = 'ckpts/renderer.pkl'

# The refinement stage of paint_img splits the picture into a
# divide x divide grid of width x width patches.
divide = 4
canvas_cnt = divide * divide

# FCN and ResNet come from the star imports at the top of the file
# (presumably baseline.Renderer.model and baseline.DRL.actor -- confirm).
#
# map_location="cpu" makes deserialisation independent of the device the
# checkpoint was saved from, exactly like load_model_if_needed does for the
# other stroke styles; the modules are moved to `device` right afterwards.
Decoder = FCN()
Decoder.load_state_dict(torch.load(renderer_path, map_location="cpu"))

# 9 input channels = canvas (3) + target image (3) + step plane (1) +
# coordinate planes (2), matching the torch.cat in paint_img.
actor = ResNet(9, 18, 65) # action_bundle = 5, 65 = 5 * 13
actor.load_state_dict(torch.load(actor_path, map_location="cpu"))

actor = actor.to(device).eval()
Decoder = Decoder.to(device).eval()

# Caches of already loaded networks, keyed by the "Stroke choice" label.
decoders = {"Default": Decoder}
actors = {"Default": actor}
def decode(x, canvas, decoder = Decoder): # b * (10 + 3)
    """Render a bundle of five strokes per canvas and composite them.

    ``x`` is viewed as rows of 13 values: the first 10 are fed to
    ``decoder`` and the last 3 are used as the stroke colour.

    Returns ``(canvas, res)`` where ``canvas`` is the canvas after all five
    strokes and ``res`` lists the canvas after each individual stroke.
    """
    params = x.view(-1, 10 + 3)

    # The decoder output is inverted to obtain the stroke coverage mask.
    alpha = (1 - decoder(params[:, :10])).view(-1, width, width, 1)
    tinted = alpha * params[:, -3:].view(-1, 1, 1, 3)

    # Channels first, then group the strokes five per canvas.
    alpha = alpha.permute(0, 3, 1, 2).view(-1, 5, 1, width, width)
    tinted = tinted.permute(0, 3, 1, 2).view(-1, 5, 3, width, width)

    snapshots = []
    for k in range(5):
        # Paint stroke k over whatever is already on the canvas.
        canvas = canvas * (1 - alpha[:, k]) + tinted[:, k]
        snapshots.append(canvas)

    # Release intermediate buffers before the next step.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return canvas, snapshots
def small2large(x):
    """Stitch ``divide * divide`` patches back into one large image.

    (d * d, width, width, c) -> (d * width, d * width, c)
    """
    grid = x.reshape(divide, divide, width, width, -1)
    # Interleave patch rows with pixel rows so the final reshape is row-major.
    interleaved = np.transpose(grid, (0, 2, 1, 3, 4))
    return interleaved.reshape(divide * width, divide * width, -1)
def large2small(x):
    """Cut one large 3-channel image into ``divide * divide`` patches.

    (d * width, d * width, 3) -> (d * d, width, width, 3)
    """
    grid = x.reshape(divide, width, divide, width, 3)
    # Bring the two patch indices to the front before flattening them.
    by_patch = np.transpose(grid, (0, 2, 1, 3, 4))
    return by_patch.reshape(canvas_cnt, width, width, 3)
def smooth(img):
    """Blur the seams between neighbouring patches of a stitched image.

    For every patch, the pixels on its last row / last column and on the
    first row / column of the following patch are replaced by the mean of
    their 3x3 neighbourhood.  The array is modified in place (and returned),
    so pixels visited later see the already smoothed values.
    """
    last = divide * width - 1

    def blend(r, c):
        # Pixels on the outer border have no complete 3x3 neighbourhood.
        if r == 0 or c == 0 or r == last or c == last:
            return
        img[r, c] = (img[r, c] + img[r + 1, c] + img[r, c + 1] + img[r - 1, c] + img[r, c - 1] + img[r + 1, c - 1] + img[r - 1, c + 1] + img[r - 1, c - 1] + img[r + 1, c + 1]) / 9

    for p in range(divide):
        top = p * width
        for q in range(divide):
            left = q * width
            # Seam along the right-hand side of patch (p, q).
            for k in range(width):
                blend(top + k, left + width - 1)
                if q != divide - 1:
                    blend(top + k, left + width)
            # Seam along the bottom side of patch (p, q).
            for k in range(width):
                blend(top + width - 1, left + k)
                if p != divide - 1:
                    blend(top + width, left + k)
    return img
def save_img(res, imgid, origin_shape, output_name, divide=False):
    """Write one canvas tensor to ``<output_name>/<imgid>.jpg``.

    res          -- canvas tensor, (n, 3, width, width)
    imgid        -- number used as the file name
    origin_shape -- size passed to cv2.resize for the written image
    output_name  -- directory the image is written to
    divide       -- True when ``res`` holds the patches of a split canvas
                    that must be stitched and smoothed first; otherwise only
                    the first canvas of the batch is written
    """
    frames = np.transpose(res.detach().cpu().numpy(), (0, 2, 3, 1))
    if not divide:
        picture = frames[0]
    else:
        picture = smooth(small2large(frames))
    picture = (picture * 255).astype('uint8')
    picture = cv2.resize(picture, origin_shape)
    cv2.imwrite(output_name +"/" + str(imgid) + '.jpg', picture)
def _canvas_to_frame(batch, origin_shape, tiled):
    """Turn a canvas tensor into a displayable uint8 image of size origin_shape."""
    frame = batch.detach().cpu().numpy() # d * d, 3, width, width
    frame = np.transpose(frame, (0, 2, 3, 1))
    if tiled:
        # Patches of a split canvas: stitch them and hide the seams.
        frame = smooth(small2large(frame))
    else:
        frame = frame[0]
    frame = (frame * 255).astype('uint8')
    return cv2.resize(frame, origin_shape)


def paint_img(img, max_step = 40, model_choices = "Default"):
    """Paint ``img`` stroke by stroke, yielding the canvas after every stroke.

    img           -- H x W x 3 image array (values are divided by 255, so
                     presumably uint8 -- confirm against the caller);
                     ``None`` yields nothing
    max_step      -- total step budget; when ``divide != 1`` it is halved and
                     spent once on the whole picture and once on the patches
    model_choices -- key into ``actors`` / ``decoders`` selecting the networks

    Yields uint8 images resized to the size of the input.  The last frame is
    yielded once more at the very end.
    """
    if img is None:
        # Nothing to paint (e.g. no image supplied yet).
        return

    Decoder = decoders[model_choices]
    actor = actors[model_choices]
    max_step = int(max_step)
    origin_shape = (img.shape[1], img.shape[0])

    # Target for the refinement stage: the picture cut into patches.
    patch_img = cv2.resize(img, (width * divide, width * divide))
    patch_img = large2small(patch_img)
    patch_img = np.transpose(patch_img, (0, 3, 1, 2))
    patch_img = torch.tensor(patch_img).to(device).float() / 255.

    # Target for the first stage: the whole picture at network resolution.
    img = cv2.resize(img, (width, width))
    img = img.reshape(1, width, width, 3)
    img = np.transpose(img, (0, 3, 1, 2))
    img = torch.tensor(img).to(device).float() / 255.

    T = torch.ones([1, 1, width, width], dtype=torch.float32).to(device)

    # Coordconv planes: channel 0 grows along the rows, channel 1 along the
    # columns, both from 0 to 1.  Built in float64 and cast once so the
    # values equal those of an element-by-element Python fill.
    axis = torch.arange(width, dtype=torch.float64) / (width - 1.)
    coord = torch.stack([
        axis.view(-1, 1).expand(width, width),
        axis.view(1, -1).expand(width, width),
    ]).unsqueeze(0).float().to(device)

    canvas = torch.zeros([1, 3, width, width]).to(device)

    if divide != 1:
        max_step = max_step // 2

    output = None

    # Stage 1: paint the whole picture on a single canvas.
    for i in range(max_step):
        # Grad mode is thread-local, and whoever drives this generator may
        # resume it on a different thread, so no_grad is entered per step
        # and never kept open across a yield.
        with torch.no_grad():
            stepnum = T * i / max_step
            actions = actor(torch.cat([canvas, img, stepnum, coord], 1))
            canvas, res = decode(actions, canvas, Decoder)
        for partial in res:
            output = _canvas_to_frame(partial, origin_shape, False)
            yield output

    # Stage 2: enlarge the canvas, split it into patches and refine each one.
    if divide != 1:
        canvas = canvas[0].detach().cpu().numpy()
        canvas = np.transpose(canvas, (1, 2, 0))
        canvas = cv2.resize(canvas, (width * divide, width * divide))
        canvas = large2small(canvas)
        canvas = np.transpose(canvas, (0, 3, 1, 2))
        canvas = torch.tensor(canvas).to(device).float()
        coord = coord.expand(canvas_cnt, 2, width, width)
        T = T.expand(canvas_cnt, 1, width, width)
        for i in range(max_step):
            with torch.no_grad():
                stepnum = T * i / max_step
                actions = actor(torch.cat([canvas, patch_img, stepnum, coord], 1))
                canvas, res = decode(actions, canvas, Decoder)
            for partial in res:
                output = _canvas_to_frame(partial, origin_shape, True)
                yield output

    # Repeat the final frame; skipped when no step was executed at all.
    if output is not None:
        yield output
def load_model_if_needed(choice: str):
    """Load the actor / renderer pair for ``choice`` unless already cached.

    Freshly loaded networks are stored in the module-level ``actors`` and
    ``decoders`` dictionaries under ``choice``.  Any label without an entry
    of its own uses the "no trans" Bezier checkpoints.
    """
    # label -> (actor checkpoint, renderer checkpoint)
    checkpoints = {
        "Default": ('ckpts/actor.pkl', 'ckpts/renderer.pkl'),
        "Triangle": ('ckpts/actor_triangle.pkl', 'ckpts/triangle.pkl'),
        "Round": ('ckpts/actor_round.pkl', 'ckpts/round.pkl'),
    }
    fallback = ('ckpts/actor_notrans.pkl', 'ckpts/bezierwotrans.pkl')
    actor_ckpt, renderer_ckpt = checkpoints.get(choice, fallback)

    if choice not in decoders:
        renderer = FCN()
        renderer.load_state_dict(torch.load(renderer_ckpt, map_location= "cpu"))
        decoders[choice] = renderer.to(device).eval()

    if choice not in actors:
        policy = ResNet(9, 18, 65) # action_bundle = 5, 65 = 5 * 13
        policy.load_state_dict(torch.load(actor_ckpt, map_location= "cpu"))
        actors[choice] = policy.to(device).eval()
from typing import Generator
def wrapper(func):
    """Wrap generator function ``func`` so one callable can start or stop it.

    The returned generator function takes a label as its first positional
    argument; the remaining arguments are forwarded to ``func``:

    * ``"Cancel"`` -- start ``func(*rest, **kwargs)`` and relay every value
      it yields.
    * anything else -- close the most recently started generator.  A single
      ``None`` is yielded when the close succeeded; nothing is yielded when
      there was nothing to close or the close failed.

    NOTE: the handle of the running generator lives in this closure, so it
    is shared by every caller of the wrapped function.
    """
    event = None  # generator started by the latest "Cancel" call, if any

    def inner(*args, **kwargs):
        nonlocal event
        val = args[0]
        if val == "Cancel":
            event = func(*args[1:], **kwargs)
            yield from event
            return

        if event is None:
            # Nothing has been started yet.
            return
        try:
            event.close()
        except Exception:
            # Best effort: e.g. ValueError when the generator is executing
            # right now and therefore cannot be closed.  Only ordinary
            # errors are swallowed; GeneratorExit / KeyboardInterrupt are
            # left alone, which is also why the yield sits outside the try.
            return
        yield

    return inner
# Sample pictures offered below the UI (wrapped by gr.Examples further down).
examples = [
    ["image/chaoyue.png"],
    ["image/degang.png"],
    ["image/JayChou.png"],
    ["image/Leslie.png"],
    ["image/mayun.png"],
]
# Created up front so the clear button can reference it before it is placed
# in the layout with output.render().
output = gr.Image(label="Painting Result")
# NOTE(review): the nesting of the layout below was reconstructed from a copy
# of this file that had lost its indentation -- confirm against the intended
# layout.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input image")
            with gr.Row():
                step = gr.Slider(20, 100, value= 40, step = 1, label= 'Painting step')
            with gr.Row():
                dropdown = gr.Dropdown(['Default', 'Round', 'Triangle', 'Bezier wo trans'], value= 'Default', label= 'Stroke choice')
            with gr.Row():
                with gr.Column():
                    clr_btn = gr.ClearButton([input_image, output], variant= "stop")
                with gr.Column():
                    translate_btn = gr.Button(value="Paint", variant="primary")
        with gr.Column():
            output.render()
    # Load the networks of a stroke style as soon as it is selected.
    dropdown.select(load_model_if_needed, dropdown)
    # Click chain of the Paint/Cancel button:
    #   1. flip the label ("Paint" -> "Cancel", anything else -> "Paint");
    #   2. call wrapper(paint_img) with the new label as first input -- it
    #      paints when the label is "Cancel" and otherwise closes the
    #      previously started painting generator;
    #   3. set the label back to "Paint".
    click_event = translate_btn.click(lambda x: gr.Button(value="Cancel", variant="stop") if x == "Paint" else gr.Button(value="Paint", variant="primary"), translate_btn, translate_btn)\
        .then(wrapper(paint_img), inputs=[translate_btn, input_image, step, dropdown], outputs=output, trigger_mode = 'multiple')\
        .then(lambda x: gr.Button(value="Paint", variant="primary"), translate_btn, translate_btn)
    # Clearing also cancels the click chain above.
    clr_btn.click(None, None, cancels=[click_event])
    examples = gr.Examples(examples=examples,
                           inputs=[input_image], cache_examples = False)
# demo = gr.Interface(fn=paint_img, inputs=gr.Image(), outputs="image", examples = examples)
demo.queue(default_concurrency_limit= 4)
# Listen on all network interfaces.
demo.launch(server_name="0.0.0.0", )