import utils
import torch
import gradio as gr
import numpy as np
from PIL import Image
from network import ImageTransformNet_dpws
from torch.autograd import Variable
from torchvision import transforms

with gr.Blocks() as demo:
    with gr.Row():
        gr.HTML('<h1>스타일 변환기</h1>')  # page title: "Style Transformer"
    with gr.Row():
        with gr.Column():
            style_radio = gr.Radio(
                ['La muse', 'Mosaic', 'Starry Night Crop', 'Wave Crop'],
                label='원하는 스타일 선택!',  # "Pick the style you want!"
            )
            image_input = gr.Image(label='콘텐츠 이미지')  # "Content image"
            convert_button = gr.Button('변환!')  # "Convert!"
        with gr.Column():
            result_image = gr.Image(label='결과 이미지')  # "Result image"

    def transform_image(style, img):
        dtype = torch.FloatTensor

        # Preprocess the content image: scale the shortest side to 512,
        # convert [0, 255] pixels to a [0, 1] tensor, and normalize with
        # ImageNet statistics.
        img_transform_512 = transforms.Compose([
            transforms.Resize(512),
            transforms.ToTensor(),
            utils.normalize_tensor_transform(),
        ])

        content = Image.fromarray(img)
        content = img_transform_512(content)
        content = content.unsqueeze(0)  # add a batch dimension
        # Variable is a legacy wrapper and a no-op in modern PyTorch.
        content = Variable(content, requires_grad=False).type(dtype)

        # Load the compressed transform network for the selected style,
        # e.g. 'Starry Night Crop' -> models/starry_night_crop/compressed.model
        model_folder_name = '_'.join(style.lower().split())
        model_path = 'models/' + model_folder_name + '/compressed.model'
        checkpoint_lw = torch.load(model_path)
        style_model = ImageTransformNet_dpws().type(dtype)
        style_model.load_state_dict(checkpoint_lw)

        # Stylize the content image and write the result to disk.
        stylized = style_model(content).cpu()
        utils.save_image('results.jpg', stylized.data[0])
        return 'results.jpg'

    convert_button.click(
        transform_image,
        inputs=[style_radio, image_input],
        outputs=[result_image],
    )

demo.launch()
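
# --- Reference sketch (assumptions, not used by the app) --------------------
# The script relies on a local `utils` module for `normalize_tensor_transform()`
# and `save_image()`; their real implementations live in utils.py and are not
# shown here. The functions below are a minimal, assumed sketch of what they
# might look like (standard ImageNet normalization and a denormalizing save),
# included only to make the preprocessing/postprocessing steps above concrete.

def _normalize_tensor_transform_sketch():
    # Hypothetical stand-in for utils.normalize_tensor_transform():
    # normalize a [0, 1] tensor with ImageNet channel statistics.
    return transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])

def _save_image_sketch(filename, data):
    # Hypothetical stand-in for utils.save_image(filename, tensor):
    # undo the ImageNet normalization, clamp to [0, 1], and save as an image.
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    img = (data.clone() * std + mean).clamp(0, 1)
    transforms.ToPILImage()(img).save(filename)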