import gradio as gr
import os
import cv2
import torch
import numpy as np
import argparse
import torch.nn as nn
import torch.nn.functional as F
import gc
from typing import Generator

from baseline.DRL.actor import *
from baseline.Renderer.stroke_gen import *
from baseline.Renderer.model import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
width = 128
actor_path = 'ckpts/actor.pkl'
renderer_path = 'ckpts/renderer.pkl'
divide = 4  # split the canvas into divide x divide patches for the fine-detail phase
canvas_cnt = divide * divide

Decoder = FCN()
Decoder.load_state_dict(torch.load(renderer_path, map_location="cpu"))
actor = ResNet(9, 18, 65)  # action_bundle = 5, 65 = 5 * 13
actor.load_state_dict(torch.load(actor_path, map_location="cpu"))
actor = actor.to(device).eval()
Decoder = Decoder.to(device).eval()

decoders = {"Default": Decoder}
actors = {"Default": actor}


def decode(x, canvas, decoder=Decoder):  # x: b * (10 + 3)
    # Each action bundle holds 5 strokes with 13 parameters apiece:
    # 10 renderer inputs (shape/opacity) plus 3 RGB values.
    x = x.view(-1, 10 + 3)
    stroke = 1 - decoder(x[:, :10])
    stroke = stroke.view(-1, width, width, 1)
    color_stroke = stroke * x[:, -3:].view(-1, 1, 1, 3)
    stroke = stroke.permute(0, 3, 1, 2)
    color_stroke = color_stroke.permute(0, 3, 1, 2)
    stroke = stroke.view(-1, 5, 1, width, width)
    color_stroke = color_stroke.view(-1, 5, 3, width, width)
    res = []
    for i in range(5):
        # Alpha-composite each stroke onto the canvas, keeping every frame.
        canvas = canvas * (1 - stroke[:, i]) + color_stroke[:, i]
        res.append(canvas)
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return canvas, res


def small2large(x):
    # (d * d, width, width) -> (d * width, d * width)
    x = x.reshape(divide, divide, width, width, -1)
    x = np.transpose(x, (0, 2, 1, 3, 4))
    x = x.reshape(divide * width, divide * width, -1)
    return x


def large2small(x):
    # (d * width, d * width) -> (d * d, width, width)
    x = x.reshape(divide, width, divide, width, 3)
    x = np.transpose(x, (0, 2, 1, 3, 4))
    x = x.reshape(canvas_cnt, width, width, 3)
    return x


def smooth(img):
    def smooth_pix(img, tx, ty):
        if tx == divide * width - 1 or ty == divide * width - 1 or tx == 0 or ty == 0:
            return img
        img[tx, ty] = (img[tx, ty] + img[tx + 1, ty] + img[tx, ty + 1] + img[tx - 1, ty]
                       + img[tx, ty - 1] + img[tx + 1, ty - 1] + img[tx - 1, ty + 1]
                       + img[tx - 1, ty - 1] + img[tx + 1, ty + 1]) / 9
        return img

    # Average a 3x3 neighbourhood along every patch seam to hide the block borders.
    for p in range(divide):
        for q in range(divide):
            x = p * width
            y = q * width
            for k in range(width):
                img = smooth_pix(img, x + k, y + width - 1)
                if q != divide - 1:
                    img = smooth_pix(img, x + k, y + width)
            for k in range(width):
                img = smooth_pix(img, x + width - 1, y + k)
                if p != divide - 1:
                    img = smooth_pix(img, x + width, y + k)
    return img


def save_img(res, imgid, origin_shape, output_name, divide=False):
    # NOTE: the boolean `divide` parameter shadows the module-level patch count;
    # small2large/smooth still read the global. Kept for reference, unused by the UI.
    output = res.detach().cpu().numpy()  # d * d, 3, width, width
    output = np.transpose(output, (0, 2, 3, 1))
    if divide:
        output = small2large(output)
        output = smooth(output)
    else:
        output = output[0]
    output = (output * 255).astype('uint8')
    output = cv2.resize(output, origin_shape)
    cv2.imwrite(output_name + "/" + str(imgid) + '.jpg', output)
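
# Hedged sanity-check sketch: large2small and small2large are intended to be
# exact inverses on (divide * width)-sized images, so a round trip should
# reproduce the input. `patch_roundtrip_check` is an illustrative helper added
# here for verification; it is not called anywhere in the demo flow.
def patch_roundtrip_check():
    x = np.random.rand(divide * width, divide * width, 3).astype(np.float32)
    assert np.allclose(small2large(large2small(x)), x)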

def paint_img(img, max_step=40, model_choices="Default"):
    Decoder = decoders[model_choices]
    actor = actors[model_choices]
    max_step = int(max_step)
    # imgid = 0
    # output_name = os.path.join('output', str(len(os.listdir('output'))) if os.path.exists('output') else '0')
    # os.makedirs(output_name, exist_ok=True)
    # img = cv2.imread(args.img, cv2.IMREAD_COLOR)
    origin_shape = (img.shape[1], img.shape[0])

    # Patch view of the target image for the divided (fine-detail) phase.
    patch_img = cv2.resize(img, (width * divide, width * divide))
    patch_img = large2small(patch_img)
    patch_img = np.transpose(patch_img, (0, 3, 1, 2))
    patch_img = torch.tensor(patch_img).to(device).float() / 255.

    img = cv2.resize(img, (width, width))
    img = img.reshape(1, width, width, 3)
    img = np.transpose(img, (0, 3, 1, 2))
    img = torch.tensor(img).to(device).float() / 255.

    T = torch.ones([1, 1, width, width], dtype=torch.float32).to(device)
    coord = torch.zeros([1, 2, width, width])
    for i in range(width):
        for j in range(width):
            coord[0, 0, i, j] = i / (width - 1.)
            coord[0, 1, i, j] = j / (width - 1.)
    coord = coord.to(device)  # CoordConv

    canvas = torch.zeros([1, 3, width, width]).to(device)

    with torch.no_grad():
        if divide != 1:
            max_step = max_step // 2  # split the step budget between the two phases
        # Phase 1: paint the whole image on a single width x width canvas.
        for i in range(max_step):
            stepnum = T * i / max_step
            actions = actor(torch.cat([canvas, img, stepnum, coord], 1))
            canvas, res = decode(actions, canvas, Decoder)
            for j in range(5):
                # save_img(res[j], imgid)
                # imgid += 1
                output = res[j].detach().cpu().numpy()  # 1, 3, width, width
                output = np.transpose(output, (0, 2, 3, 1))
                output = output[0]
                output = (output * 255).astype('uint8')
                output = cv2.resize(output, origin_shape)
                yield output
        # Phase 2: cut the canvas into divide x divide patches and refine each.
        if divide != 1:
            canvas = canvas[0].detach().cpu().numpy()
            canvas = np.transpose(canvas, (1, 2, 0))
            canvas = cv2.resize(canvas, (width * divide, width * divide))
            canvas = large2small(canvas)
            canvas = np.transpose(canvas, (0, 3, 1, 2))
            canvas = torch.tensor(canvas).to(device).float()
            coord = coord.expand(canvas_cnt, 2, width, width)
            T = T.expand(canvas_cnt, 1, width, width)
            for i in range(max_step):
                stepnum = T * i / max_step
                actions = actor(torch.cat([canvas, patch_img, stepnum, coord], 1))
                canvas, res = decode(actions, canvas, Decoder)
                # print('divided canvas step {}, L2Loss = {}'.format(i, ((canvas - patch_img) ** 2).mean()))
                for j in range(5):
                    # save_img(res[j], imgid, True)
                    # imgid += 1
                    output = res[j].detach().cpu().numpy()  # d * d, 3, width, width
                    output = np.transpose(output, (0, 2, 3, 1))
                    output = small2large(output)
                    output = smooth(output)
                    output = (output * 255).astype('uint8')
                    output = cv2.resize(output, origin_shape)
                    yield output
    yield output


def load_model_if_needed(choice: str):
    # Lazily load the checkpoint pair for the selected stroke type, caching it
    # in the module-level dicts so each pair is only read from disk once.
    if choice == "Default":
        actor_path = 'ckpts/actor.pkl'
        renderer_path = 'ckpts/renderer.pkl'
    elif choice == "Triangle":
        actor_path = 'ckpts/actor_triangle.pkl'
        renderer_path = 'ckpts/triangle.pkl'
    elif choice == "Round":
        actor_path = 'ckpts/actor_round.pkl'
        renderer_path = 'ckpts/round.pkl'
    else:
        actor_path = 'ckpts/actor_notrans.pkl'
        renderer_path = 'ckpts/bezierwotrans.pkl'
    if choice not in decoders:
        Decoder = FCN()
        Decoder.load_state_dict(torch.load(renderer_path, map_location="cpu"))
        decoders[choice] = Decoder.to(device).eval()
    if choice not in actors:
        actor = ResNet(9, 18, 65)  # action_bundle = 5, 65 = 5 * 13
        actor.load_state_dict(torch.load(actor_path, map_location="cpu"))
        actors[choice] = actor.to(device).eval()


def wrapper(func):
    # The Paint button is flipped to "Cancel" before this handler runs, so a
    # value of "Cancel" means a run was just requested; any other value means
    # the user clicked while painting, and the live generator is closed.
    event: Generator = range(0)  # placeholder until the first run starts
    def inner(*args, **kwargs):
        nonlocal event
        val = args[0]
        if val == "Cancel":
            event = func(*args[1:], **kwargs)
            yield from event
        else:
            try:
                event.close()
                yield
            except Exception:
                pass
    return inner
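
# Hedged sketch (illustrative, unused by the UI): decode() consumes 65 numbers
# per canvas, a bundle of 5 strokes at 13 parameters each, and returns the
# final canvas plus the 5 intermediate frames. `random_stroke_demo` is a name
# introduced here purely for illustration.
def random_stroke_demo():
    actions = torch.rand(1, 65).to(device)        # one random 5-stroke bundle
    canvas = torch.zeros(1, 3, width, width).to(device)
    with torch.no_grad():
        canvas, frames = decode(actions, canvas)  # len(frames) == 5
    return canvas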

examples = [
    ["image/chaoyue.png"],
    ["image/degang.png"],
    ["image/JayChou.png"],
    ["image/Leslie.png"],
    ["image/mayun.png"],
]

output = gr.Image(label="Painting Result")

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input image")
            with gr.Row():
                step = gr.Slider(20, 100, value=40, step=1, label='Painting step')
            with gr.Row():
                dropdown = gr.Dropdown(['Default', 'Round', 'Triangle', 'Bezier wo trans'],
                                       value='Default', label='Stroke choice')
            with gr.Row():
                with gr.Column():
                    clr_btn = gr.ClearButton([input_image, output], variant="stop")
                with gr.Column():
                    translate_btn = gr.Button(value="Paint", variant="primary")
        with gr.Column():
            output.render()

    dropdown.select(load_model_if_needed, dropdown)

    # Clicking flips the button label first, then streams frames from
    # paint_img via wrapper(), then restores the label. A click while the
    # label reads "Cancel" therefore closes the running generator instead.
    click_event = (
        translate_btn.click(
            lambda x: gr.Button(value="Cancel", variant="stop") if x == "Paint"
            else gr.Button(value="Paint", variant="primary"),
            translate_btn, translate_btn,
        )
        .then(wrapper(paint_img), inputs=[translate_btn, input_image, step, dropdown],
              outputs=output, trigger_mode='multiple')
        .then(lambda x: gr.Button(value="Paint", variant="primary"),
              translate_btn, translate_btn)
    )
    clr_btn.click(None, None, cancels=[click_event])

    examples = gr.Examples(examples=examples, inputs=[input_image], cache_examples=False)
    # demo = gr.Interface(fn=paint_img, inputs=gr.Image(), outputs="image", examples=examples)

demo.queue(default_concurrency_limit=4)
demo.launch(server_name="0.0.0.0")
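
# Hedged usage sketch (comments only, since launch() blocks above): paint_img
# is a plain generator and can be driven without the UI. cv2.imread returns
# BGR while the Gradio image component supplies RGB, hence the conversions;
# the paths below are illustrative.
#
#   os.makedirs('output', exist_ok=True)
#   img = cv2.cvtColor(cv2.imread('image/chaoyue.png', cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
#   for i, frame in enumerate(paint_img(img, max_step=40)):
#       cv2.imwrite(f'output/{i}.jpg', cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))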