|
|
|
import os |
|
|
|
from PIL import Image |
|
from torchvision import transforms as T |
|
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, RandomCrop, RandomHorizontalFlip |
|
from torchvision.utils import make_grid |
|
from torch.utils.data import DataLoader |
|
from huggan.pytorch.cyclegan.modeling_cyclegan import GeneratorResNet |
|
import torch.nn as nn |
|
import torch |
|
import gradio as gr |
|
|
|
from collections import OrderedDict |
|
import glob |
|
|
|
|
|
|
|
|
|
def pred_pipeline(img, transforms, generator=None):
    """Translate one synthetic image with the generator and return a PIL image.

    Args:
        img: Input image exposing ``.shape`` whose first two entries are
            (height, width). Presumably the HxWxC numpy array delivered by the
            Gradio ``Image`` input; TODO confirm. ``None`` (no image submitted)
            is tolerated.
        transforms: Callable mapping ``img`` to a single (C, H, W) tensor that
            the generator accepts (the module-level ``transform`` in this app).
        generator: Module to run. Defaults to the module-level ``model`` so
            existing callers keep working. Pass another module to reuse the
            pipeline.

    Returns:
        The generator output as a PIL image resized back to the input's
        original height and width, or ``None`` when ``img`` is ``None``.
    """
    if img is None:
        # Gradio calls the function with None when the user submits nothing.
        return None
    if generator is None:
        generator = model

    height, width = img.shape[:2]
    batch = transforms(img).unsqueeze(0)  # add a leading batch dimension

    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        output = generator(batch)

    # normalize=True asks make_grid to rescale values into [0, 1], the range
    # ToPILImage expects for float tensors.
    out_img = make_grid(output, nrow=1, normalize=True)

    out_transform = Compose([
        T.Resize((height, width)),  # restore the caller's original size
        T.ToPILImage(),
    ])
    return out_transform(out_img)
|
|
|
|
|
|
|
|
|
# Generator input configuration: 3-channel images at a fixed square size.
n_channels = 3
image_size = 512
input_shape = (image_size, image_size)

# Per-request preprocessing pipeline:
#   array -> PIL image -> resize to input_shape -> float tensor
#   -> per-channel (x - 0.5) / 0.5 normalisation.
transform = Compose([
    T.ToPILImage(),
    T.Resize(input_shape),
    T.ToTensor(),
    T.Normalize((0.5,) * n_channels, (0.5,) * n_channels),
])
|
|
|
|
|
# Pretrained CycleGAN ResNet generator (9 residual blocks) loaded from
# 'Chris1/sim2real' (presumably a Hugging Face Hub repo id). Its input_shape is
# built from the same n_channels / image_size used by the preprocessing.
model = GeneratorResNet.from_pretrained(
    'Chris1/sim2real',
    input_shape=(n_channels, image_size, image_size),
    num_residual_blocks=9,
)
# This script only runs inference. Force eval mode explicitly instead of
# relying on whatever mode from_pretrained leaves the module in.
model.eval()
|
|
|
# Gradio demo: one synthetic image in, the translated PIL image out.
# Examples are every PNG in ./samples (resolved against the working directory).
# glob returns files in arbitrary order, so sort them for a stable gallery.
# NOTE(review): gr.inputs / gr.outputs are Gradio's legacy namespaces
# (deprecated in 3.x, removed in 4.x). Migrate to gr.Image(...) if the pinned
# Gradio version allows; verify before changing.
demo = gr.Interface(
    lambda image: pred_pipeline(image, transform),
    inputs=gr.inputs.Image(label='input synthetic image'),
    outputs=gr.outputs.Image(type="pil", label='style transfer to the real world'),
    title="GTA5(simulated) to Cityscapes (real) translation",
    examples=[[example] for example in sorted(glob.glob('./samples/*.png'))],
)
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|