|
|
|
import os |
|
|
|
from PIL import Image |
|
from torchvision import transforms as T |
|
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, RandomCrop, RandomHorizontalFlip |
|
from torchvision.utils import make_grid |
|
from torch.utils.data import DataLoader |
|
from huggan.pytorch.cyclegan.modeling_cyclegan import GeneratorResNet |
|
import torch.nn as nn |
|
import torch |
|
import gradio as gr |
|
|
|
from collections import OrderedDict |
|
import glob |
|
|
|
|
|
|
|
|
|
def pred_pipeline(img, transforms):
    """Translate an image to the synthetic domain and back to the real one.

    The input goes through the module-level ``real2sim`` generator. That
    result then goes through ``sim2real``. Each output is rescaled for
    display and resized back to the input's spatial size.

    Args:
        img: Image array supplied by the Gradio input. ``img.shape[:2]`` is
            used as the output size, so it is presumably ``(H, W, C)``.
            TODO confirm against the Gradio input type.
        transforms: Callable turning ``img`` into the tensor the generators
            consume, without a batch dimension (one is added here).

    Returns:
        Tuple ``(synthetic, real)`` of PIL images: the real2sim prediction
        and the sim2real translation of that prediction.
    """
    orig_shape = img.shape

    # Add the leading batch dimension (batch of one).
    batch = transforms(img).unsqueeze(0)

    # Inference only. Without no_grad, autograd records the graph through
    # both generators and keeps intermediate activations alive for the
    # whole call.
    with torch.no_grad():
        output_syn = real2sim(batch)
        output_real = sim2real(output_syn)

    # normalize=True min/max-rescales each result into [0, 1], the value
    # range ToPILImage expects for float tensors.
    out_img_syn = make_grid(output_syn, nrow=1, normalize=True)
    out_img_real = make_grid(output_real, nrow=1, normalize=True)

    # Resize back to orig_shape[:2] and convert to PIL for the "pil" outputs.
    out_transform = Compose([
        T.Resize(orig_shape[:2]),
        T.ToPILImage(),
    ])

    return out_transform(out_img_syn), out_transform(out_img_real)
|
|
|
|
|
|
|
|
|
# Shared configuration for the generators and the preprocessing pipeline.
n_channels = 3  # RGB
image_size = 512
input_shape = (image_size,) * 2  # square (H, W) model resolution

# Preprocessing applied to each incoming image: convert to PIL, resize to
# the model resolution, convert to a tensor in [0, 1], then map each
# channel to [-1, 1] via (x - 0.5) / 0.5.
transform = Compose(
    [
        T.ToPILImage(),
        T.Resize(input_shape),
        ToTensor(),
        Normalize((0.5,) * n_channels, (0.5,) * n_channels),
    ]
)
|
|
|
|
|
# Both generators share one architecture: a ResNet generator with 9
# residual blocks for (n_channels, image_size, image_size) inputs. Going by
# the repo names, sim2real maps synthetic -> real and real2sim maps
# real -> synthetic.
sim2real = GeneratorResNet.from_pretrained(
    'Chris1/sim2real-512',
    input_shape=(n_channels, image_size, image_size),
    num_residual_blocks=9,
)
real2sim = GeneratorResNet.from_pretrained(
    'Chris1/real2sim-512',
    input_shape=(n_channels, image_size, image_size),
    num_residual_blocks=9,
)
|
|
|
# Gradio UI: one input image yields the real2sim prediction and the sim2real
# translation of that prediction (both computed by pred_pipeline).
# NOTE(review): gr.inputs / gr.outputs are Gradio's legacy component
# namespaces (believed removed in Gradio 4). Confirm the pinned gradio version.
gr.Interface(
    lambda image: pred_pipeline(image, transform),
    # pred_pipeline applies real2sim to this image first, and the title
    # describes the source side as real, so the input is a real image.
    inputs=gr.inputs.Image(label='input real image'),
    outputs=[
        gr.outputs.Image(type="pil", label='GAN real2sim prediction: style transfer of the input to the synthetic world '),
        gr.outputs.Image(type="pil", label='GAN sim2real prediction: translation to real of the above prediction'),
    ],
    title="Cityscapes (real) to GTA5(simulated) translation",
    # glob order depends on the filesystem; sort so the examples are stable.
    examples=[[example] for example in sorted(glob.glob('./samples/*.png'))],
).launch()  # chained on the closing paren: no fragile backslash continuation
|
|
|
|
|
|