"""Gradio demo that stylises photos with a pretrained CartoonGAN generator."""

import os

import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
from torch.autograd import Variable  # noqa: F401  (kept for compatibility; unused)

from network.Transformer import Transformer

LOAD_SIZE = 1280  # upper bound on the longer image side fed to the generator
STYLE = "shinkai_makoto"  # checkpoint file name (without .pth) inside MODEL_PATH
MODEL_PATH = "models"
COLOUR_MODEL = "RGB"

# Set to False to run on CUDA; falls back to CPU when CUDA is not available.
disable_gpu = True
DEVICE = torch.device(
    "cuda" if not disable_gpu and torch.cuda.is_available() else "cpu"
)

model = Transformer()
# map_location="cpu" lets a checkpoint saved from a GPU session load on a
# CPU-only host; the model is then moved to the chosen device so that its
# weights and the input tensor always live on the same device.
model.load_state_dict(
    torch.load(os.path.join(MODEL_PATH, f"{STYLE}.pth"), map_location="cpu")
)
model.to(DEVICE)
model.eval()


def inference(img):
    """Stylise *img* with the loaded CartoonGAN generator.

    Args:
        img: PIL image supplied by the Gradio input component.

    Returns:
        A new PIL image holding the generator's output. Inputs whose longer
        side exceeds LOAD_SIZE are first downscaled (aspect ratio preserved).
    """
    # convert() returns a new image, so the in-place thumbnail() below never
    # mutates the caller's object.
    input_image = img.convert(COLOUR_MODEL)
    # Shrink only (thumbnail never upscales) to bound memory use on very
    # large uploads.
    input_image.thumbnail((LOAD_SIZE, LOAD_SIZE))
    input_image = np.asarray(input_image)

    # RGB -> BGR; presumably the pretrained weights were trained on BGR input
    # (the channels are swapped back on the output below).
    input_image = input_image[:, :, [2, 1, 0]]

    # ToTensor gives a CxHxW float tensor in [0, 1]; add a batch dimension and
    # rescale to the (-1, 1) range the generator is fed.
    input_tensor = transforms.ToTensor()(input_image).unsqueeze(0)
    input_tensor = (-1 + 2 * input_tensor).float().to(DEVICE)

    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        output_image = model(input_tensor)[0]

    # BGR -> RGB, then map (-1, 1) back to (0, 1). clamp() keeps any slightly
    # out-of-range values from overflowing when ToPILImage casts to uint8.
    output_image = output_image[[2, 1, 0], :, :]
    output_image = output_image.cpu().float() * 0.5 + 0.5
    return transforms.ToPILImage()(output_image.clamp(0, 1))


title = "Anime Background GAN"
description = "Gradio Demo for CartoonGAN by Chen Et. Al. Models are Shinkai Makoto, Hosoda Mamoru, Kon Satoshi, and Miyazaki Hayao."
# NOTE(review): the entries below look like the text of links/a badge image
# whose markup may have been lost — confirm the intended HTML/URLs.
article = """

CartoonGAN from Chen et.al

Github Repo

Original Implementation from Yijunmaverick

visitor badge

"""

examples = [
    ["examples/garden_in.jpg"],
    ["examples/library_in.jpg"],
]

# NOTE(review): gr.inputs / gr.outputs, allow_screenshot and enable_queue are
# legacy Gradio arguments; confirm against the pinned gradio version before
# upgrading it.
gr.Interface(
    fn=inference,
    inputs=[gr.inputs.Image(type="pil")],
    outputs=gr.outputs.Image(type="pil"),
    title=title,
    description=description,
    article=article,
    examples=examples,
    allow_flagging=False,
    allow_screenshot=False,
    enable_queue=True,
).launch()