Spaces:
Running
Running
import gradio as gr | |
import os | |
import cv2 | |
import torch | |
import numpy as np | |
import argparse | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import gc | |
from baseline.DRL.actor import * | |
from baseline.Renderer.stroke_gen import * | |
from baseline.Renderer.model import * | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Side length in pixels of one canvas / renderer output (see decode's views).
width = 128
actor_path = 'ckpts/actor.pkl'
renderer_path = 'ckpts/renderer.pkl'
# The refinement pass cuts the image into a divide x divide grid of patches.
divide = 4
canvas_cnt = divide * divide
# Load checkpoints onto the CPU first (as load_model_if_needed does) so that
# weights saved from a GPU also load on CPU-only hosts; then move to `device`.
Decoder = FCN()
Decoder.load_state_dict(torch.load(renderer_path, map_location="cpu"))
# Actor input has 9 channels: canvas (3) + target (3) + step (1) + coords (2).
actor = ResNet(9, 18, 65)  # action_bundle = 5, 65 = 5 * 13
actor.load_state_dict(torch.load(actor_path, map_location="cpu"))
actor = actor.to(device).eval()
Decoder = Decoder.to(device).eval()
# Model caches keyed by stroke style; load_model_if_needed adds other styles.
decoders = {"Default": Decoder}
actors = {"Default": actor}
def decode(x, canvas, decoder = Decoder):  # b * (10 + 3)
    """Composite one bundle of five strokes onto ``canvas``.

    ``x`` is reshaped to 13 values per stroke: the first 10 are fed to
    ``decoder`` to produce the stroke mask, the last 3 are multiplied into the
    mask as the stroke colour. Strokes are grouped five at a time and alpha-
    composited in order.

    Returns:
        (final_canvas, intermediates) where ``intermediates`` holds the canvas
        after each of the five strokes.
    """
    params = x.view(-1, 13)
    mask = (1 - decoder(params[:, :10])).view(-1, width, width, 1)
    colored = mask * params[:, -3:].view(-1, 1, 1, 3)
    # Channels-last -> channels-first, then group strokes into bundles of 5.
    mask = mask.permute(0, 3, 1, 2).view(-1, 5, 1, width, width)
    colored = colored.permute(0, 3, 1, 2).view(-1, 5, 3, width, width)
    intermediates = []
    for k in range(5):
        canvas = canvas * (1 - mask[:, k]) + colored[:, k]
        intermediates.append(canvas)
    # Release Python garbage and cached CUDA blocks after every call.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return canvas, intermediates
def small2large(x):
    """Stitch a stack of ``divide * divide`` square patches into one image.

    ``x`` has shape ``(divide * divide, width, width, C)`` with patches in
    row-major grid order; the result has shape
    ``(divide * width, divide * width, C)``.
    """
    grid = x.reshape(divide, divide, width, width, -1)
    # (grid_row, grid_col, y, x, C) -> (grid_row, y, grid_col, x, C)
    grid = grid.transpose(0, 2, 1, 3, 4)
    side = divide * width
    return grid.reshape(side, side, -1)
def large2small(x):
    """Cut a square 3-channel image into a stack of patches.

    ``x`` of shape ``(divide * width, divide * width, 3)`` becomes
    ``(canvas_cnt, width, width, 3)``, patches in row-major grid order
    (the inverse of the layout produced by small2large).
    """
    tiles = x.reshape(divide, width, divide, width, 3)
    # (grid_row, y, grid_col, x, C) -> (grid_row, grid_col, y, x, C)
    tiles = tiles.transpose(0, 2, 1, 3, 4)
    return tiles.reshape(canvas_cnt, width, width, 3)
def smooth(img):
    """Blur the seams between patches of an image stitched on a grid.

    Along every internal patch boundary (the last row/column of each
    ``width``-sized patch, plus the first row/column of the neighbouring
    patch) each pixel is replaced by the mean of its 3x3 neighbourhood. The
    update is done in place and in a fixed order, so later pixels see values
    already smoothed earlier. Pixels whose row or column index is 0 or
    ``divide * width - 1`` are skipped. Returns the mutated ``img``.
    """
    last = divide * width - 1

    def blend(r, c):
        # Outer-border pixels lack a full 3x3 neighbourhood; leave them.
        if r in (0, last) or c in (0, last):
            return
        # Keep this exact summation order: float addition is not associative.
        total = (img[r, c] + img[r + 1, c] + img[r, c + 1] + img[r - 1, c]
                 + img[r, c - 1] + img[r + 1, c - 1] + img[r - 1, c + 1]
                 + img[r - 1, c - 1] + img[r + 1, c + 1])
        img[r, c] = total / 9

    for p in range(divide):
        for q in range(divide):
            top, left = p * width, q * width
            right_col = left + width - 1
            bottom_row = top + width - 1
            # Vertical seam: this patch's last column, then the neighbour's first.
            for k in range(width):
                blend(top + k, right_col)
                if q != divide - 1:
                    blend(top + k, right_col + 1)
            # Horizontal seam: this patch's last row, then the next patch's first.
            for k in range(width):
                blend(bottom_row, left + k)
                if p != divide - 1:
                    blend(bottom_row + 1, left + k)
    return img
def save_img(res, imgid, origin_shape, output_name, divide=False):
    """Write one painting frame to disk as ``<output_name>/<imgid>.jpg``.

    ``res`` is a 4-D (N, C, H, W) tensor. When ``divide`` is truthy (inside
    this function the parameter shadows the module-level grid size) the N
    patches are stitched with small2large and their seams smoothed; otherwise
    only the first canvas is used. The frame is scaled by 255, cast to uint8
    and resized to ``origin_shape`` (cv2 ``dsize`` order: width, height).
    """
    frame = np.transpose(res.detach().cpu().numpy(), (0, 2, 3, 1))  # NCHW -> NHWC
    frame = smooth(small2large(frame)) if divide else frame[0]
    frame = cv2.resize((frame * 255).astype('uint8'), origin_shape)
    cv2.imwrite(output_name +"/" + str(imgid) + '.jpg', frame)
def paint_img(img, max_step = 40, model_choices = "Default"):
    """Progressively paint ``img`` with the selected stroke model.

    Generator used for Gradio streaming: after every rendered stroke it yields
    a uint8 frame resized to the input's size. Painting runs in two passes: a
    coarse pass on the whole image downscaled to ``width`` x ``width``, then,
    when ``divide != 1``, a refinement pass on a ``divide`` x ``divide`` grid
    of patches whose stitched seams are smoothed. In that case the
    ``max_step`` budget is split evenly between the two passes.

    Args:
        img: input image as an (H, W, 3) array (presumably uint8 RGB from
            ``gr.Image`` -- TODO confirm for other image modes).
        max_step: number of actor steps (cast to ``int``); every step renders
            a bundle of five strokes.
        model_choices: stroke-style key into ``actors`` / ``decoders``.

    Raises:
        gr.Error: if no input image was provided.
    """
    if img is None:
        raise gr.Error("Please provide an input image.")
    max_step = int(max_step)
    # The dropdown loads models in its own queued event, so the chosen style
    # may not be cached yet when Paint is clicked; make sure it is.
    if model_choices not in decoders or model_choices not in actors:
        load_model_if_needed(model_choices)
    decoder = decoders[model_choices]
    actor = actors[model_choices]
    origin_shape = (img.shape[1], img.shape[0])  # cv2 dsize order: (width, height)

    def to_frame(canvas_batch, stitched):
        # (N, C, width, width) tensor -> uint8 frame of the input's size;
        # stitched=True joins the N patches and smooths the seams between them.
        frame = np.transpose(canvas_batch.detach().cpu().numpy(), (0, 2, 3, 1))
        frame = smooth(small2large(frame)) if stitched else frame[0]
        return cv2.resize((frame * 255).astype('uint8'), origin_shape)

    # Refinement-pass targets: the image cut into a divide x divide grid.
    patch_img = cv2.resize(img, (width * divide, width * divide))
    patch_img = np.transpose(large2small(patch_img), (0, 3, 1, 2))
    patch_img = torch.tensor(patch_img).to(device).float() / 255.

    # Coarse-pass target: the whole image at width x width.
    img = cv2.resize(img, (width, width)).reshape(1, width, width, 3)
    img = torch.tensor(np.transpose(img, (0, 3, 1, 2))).to(device).float() / 255.

    T = torch.ones([1, 1, width, width], dtype=torch.float32).to(device)
    # CoordConv channels: channel 0 = normalised row index, channel 1 =
    # normalised column index. Broadcasting replaces a width*width Python
    # loop; the ramp is computed in float64 and stored as float32, giving the
    # same values as assigning Python floats element by element.
    ramp = torch.arange(width, dtype=torch.float64) / (width - 1.)
    coord = torch.zeros([1, 2, width, width])
    coord[0, 0] = ramp.view(width, 1)
    coord[0, 1] = ramp.view(1, width)
    coord = coord.to(device)
    canvas = torch.zeros([1, 3, width, width]).to(device)

    output = None
    with torch.no_grad():
        if divide != 1:
            max_step = max_step // 2  # half the budget for each pass
        for i in range(max_step):
            stepnum = T * i / max_step
            actions = actor(torch.cat([canvas, img, stepnum, coord], 1))
            canvas, res = decode(actions, canvas, decoder)
            for step_canvas in res:
                output = to_frame(step_canvas, stitched=False)
                yield output
        if divide != 1:
            # Upscale the coarse canvas and cut it into patches to refine.
            canvas = canvas[0].detach().cpu().numpy()
            canvas = np.transpose(canvas, (1, 2, 0))
            canvas = cv2.resize(canvas, (width * divide, width * divide))
            canvas = np.transpose(large2small(canvas), (0, 3, 1, 2))
            canvas = torch.tensor(canvas).to(device).float()
            coord = coord.expand(canvas_cnt, 2, width, width)
            T = T.expand(canvas_cnt, 1, width, width)
            for i in range(max_step):
                stepnum = T * i / max_step
                actions = actor(torch.cat([canvas, patch_img, stepnum, coord], 1))
                canvas, res = decode(actions, canvas, decoder)
                for step_canvas in res:
                    output = to_frame(step_canvas, stitched=True)
                    yield output
    if output is not None:
        # Emit the final frame once more (presumably so the UI settles on it);
        # skipped when no step ran, where `output` was previously unbound.
        yield output
def load_model_if_needed(choice: str):
    """Ensure the actor/renderer pair for stroke style ``choice`` is cached.

    Checkpoints are read with ``map_location="cpu"`` and then moved to
    ``device``; styles already present in ``decoders`` / ``actors`` are not
    reloaded. "Default", "Triangle" and "Round" map to their own checkpoints;
    any other name falls back to ``actor_notrans.pkl`` / ``bezierwotrans.pkl``
    (the dropdown's 'Bezier wo trans' option).
    """
    checkpoints = {
        "Default": ('ckpts/actor.pkl', 'ckpts/renderer.pkl'),
        "Triangle": ('ckpts/actor_triangle.pkl', 'ckpts/triangle.pkl'),
        "Round": ('ckpts/actor_round.pkl', 'ckpts/round.pkl'),
    }
    fallback = ('ckpts/actor_notrans.pkl', 'ckpts/bezierwotrans.pkl')
    actor_path, renderer_path = checkpoints.get(choice, fallback)

    if choice not in decoders:
        renderer = FCN()
        renderer.load_state_dict(torch.load(renderer_path, map_location= "cpu"))
        decoders[choice] = renderer.to(device).eval()
    if choice not in actors:
        policy = ResNet(9, 18, 65)  # action_bundle = 5, 65 = 5 * 13
        policy.load_state_dict(torch.load(actor_path, map_location= "cpu"))
        actors[choice] = policy.to(device).eval()
from typing import Generator | |
def wrapper(func):
    """Turn a Paint/Cancel toggle button into a start/stop switch for ``func``.

    The returned generator function receives the toggle button's current
    value as its first argument, followed by ``func``'s own arguments:

    * ``"Cancel"`` -- the label was just switched to "Cancel", i.e. a start
      was requested: call ``func(*rest)`` and stream everything it yields.
    * anything else -- a stop was requested: close the most recently started
      generator, if there is one, and yield a single ``None``.

    NOTE(review): the running generator lives in a closure shared by every
    caller of the returned function, so with several concurrent sessions a
    stop request from one session closes whichever run started last.
    Per-session state (e.g. ``gr.State``) would be needed to isolate users.
    functools.wraps is intentionally not applied: it would make Gradio's
    signature inspection see ``func``'s parameters instead of ``inner``'s.
    """
    event = None  # generator from the most recent start request, if any

    def inner(*args, **kwargs):
        nonlocal event
        if args[0] == "Cancel":
            event = func(*args[1:], **kwargs)
            yield from event
            return
        if event is None:
            # Nothing has been started yet, so there is nothing to stop.
            return
        try:
            event.close()
        except (ValueError, RuntimeError):
            # ValueError: the generator is executing right now (e.g. mid-step
            # in another worker thread); RuntimeError: it ignored
            # GeneratorExit. Stopping is best-effort, so give up quietly.
            return
        yield

    return inner
examples = [
    ["image/chaoyue.png"],
    ["image/degang.png"],
    ["image/JayChou.png"],
    ["image/Leslie.png"],
    ["image/mayun.png"],
]
output = gr.Image(label="Painting Result")
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input image")
            with gr.Row():
                step = gr.Slider(20, 100, value= 40, step = 1, label= 'Painting step')
            with gr.Row():
                dropdown = gr.Dropdown(['Default', 'Round', 'Triangle', 'Bezier wo trans'], value= 'Default', label= 'Stroke choice')
            with gr.Row():
                with gr.Column():
                    clr_btn = gr.ClearButton([input_image, output], variant= "stop")
                with gr.Column():
                    translate_btn = gr.Button(value="Paint", variant="primary")
        with gr.Column():
            output.render()
    dropdown.select(load_model_if_needed, dropdown)
    # 1) flip the button label; 2) paint, or cancel, depending on the new
    # label (the wrapped function reads it as its first input); 3) restore
    # the "Paint" label when painting ends.
    toggle_event = translate_btn.click(lambda x: gr.Button(value="Cancel", variant="stop") if x == "Paint" else gr.Button(value="Paint", variant="primary"), translate_btn, translate_btn)
    paint_event = toggle_event.then(wrapper(paint_img), inputs=[translate_btn, input_image, step, dropdown], outputs=output, trigger_mode = 'multiple')
    click_event = paint_event.then(lambda x: gr.Button(value="Paint", variant="primary"), translate_btn, translate_btn)
    # `click_event` is only the final label-reset step of the chain, so the
    # painting event itself must be cancelled; then restore the label, since
    # the cancelled chain will not reach its reset step.
    clr_btn.click(lambda: gr.Button(value="Paint", variant="primary"), None, translate_btn, cancels=[paint_event])
    examples = gr.Examples(examples=examples,
                           inputs=[input_image], cache_examples = False)
demo.queue(default_concurrency_limit= 4)
demo.launch(server_name="0.0.0.0", )