Spaces:
Runtime error
Runtime error
import torch | |
from torchvision.utils import make_grid | |
from torchvision import transforms | |
import torchvision.transforms.functional as TF | |
from torch import nn, optim | |
from torch.optim.lr_scheduler import CosineAnnealingLR | |
from torch.utils.data import DataLoader, Dataset | |
from huggingface_hub import hf_hub_download | |
import requests | |
import gradio as gr | |
class Upsample(nn.Module):
    """Decoder stage: transposed conv -> InstanceNorm -> ReLU, plus optional skip.

    With the default ``kernel_size=4, stride=2, padding=1`` the transposed
    convolution doubles the spatial size: (H - 1) * 2 - 2 + 4 == 2 * H.

    Args:
        in_channels: number of channels of the tensor passed to ``forward``.
        out_channels: channels produced by the transposed convolution (before
            any skip concatenation).
        kernel_size, stride, padding: forwarded to ``nn.ConvTranspose2d``.
        dropout: when True, ``nn.Dropout2d(0.5)`` is applied after the ReLU
            (it only has an effect while the module is in training mode).
    """

    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, dropout=True):
        super().__init__()
        self.dropout = dropout
        # Previously ``bias=nn.InstanceNorm2d`` (a class object, which is
        # truthy), so the conv has always carried a bias. Spell that out so the
        # parameter set stays identical to existing checkpoints.
        # NOTE(review): a per-channel bias right before InstanceNorm is removed
        # by the mean subtraction, so ``bias=False`` may have been the intent;
        # changing it would alter the parameter set, so it is left as True.
        self.block = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=True),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        # Constructed unconditionally; only applied when ``self.dropout`` is True.
        self.dropout_layer = nn.Dropout2d(0.5)

    def forward(self, x, shortcut=None):
        """Upsample ``x`` and, if given, concatenate ``shortcut`` on dim 1.

        ``shortcut`` must match the upsampled tensor's batch and spatial sizes
        for ``torch.cat`` to succeed.
        """
        x = self.block(x)
        if self.dropout:
            x = self.dropout_layer(x)
        if shortcut is not None:
            x = torch.cat([x, shortcut], dim=1)
        return x
class Downsample(nn.Module):
    """Encoder stage: Conv2d -> optional InstanceNorm -> LeakyReLU(0.2).

    With the default ``kernel_size=4, stride=2, padding=1`` an even spatial
    size H becomes (H + 2 - 4) // 2 + 1 == H // 2.

    Args:
        in_channels: number of channels of the tensor passed to ``forward``.
        out_channels: channels produced by the convolution.
        kernel_size, stride, padding: forwarded to ``nn.Conv2d``.
        apply_instancenorm: when False the InstanceNorm step is skipped.
    """

    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, apply_instancenorm=True):
        super().__init__()
        # Previously ``bias=nn.InstanceNorm2d`` (a truthy class object), so the
        # conv has always had a bias; keep that explicitly for checkpoint
        # compatibility.
        # NOTE(review): when the norm is applied the bias is cancelled by
        # InstanceNorm's mean subtraction; ``bias=False`` may have been intended.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=True)
        self.norm = nn.InstanceNorm2d(out_channels)
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        self.apply_norm = apply_instancenorm

    def forward(self, x):
        """Apply conv, the optional norm, then LeakyReLU to ``x``."""
        x = self.conv(x)
        if self.apply_norm:
            x = self.norm(x)
        x = self.relu(x)
        return x
class CycleGAN_Unet_Generator(nn.Module):
    """U-Net generator: 7 encoder stages, 6 decoder stages with skips, Tanh head.

    The per-stage shape comments assume a (b, 3, 256, 256) input. The spatial
    size must halve cleanly through all seven encoder stages, otherwise the
    skip concatenations in the decoder will not line up.

    Args:
        filter: base channel width ``f``; stages use f, 2f, 4f and 8f channels.
            (The name shadows the builtin but is kept for caller compatibility.)
    """

    def __init__(self, filter=64):
        super().__init__()
        f = filter

        # Encoder. Only the first stage skips InstanceNorm.
        self.downsamples = nn.ModuleList(
            [Downsample(3, f, kernel_size=4, apply_instancenorm=False)]  # (b, f, 128, 128)
            + [
                Downsample(c_in, c_out)
                for c_in, c_out in (
                    (f, f * 2),      # (b, 2f, 64, 64)
                    (f * 2, f * 4),  # (b, 4f, 32, 32)
                    (f * 4, f * 8),  # (b, 8f, 16, 16)
                    (f * 8, f * 8),  # (b, 8f, 8, 8)
                    (f * 8, f * 8),  # (b, 8f, 4, 4)
                    (f * 8, f * 8),  # (b, 8f, 2, 2)
                )
            ]
        )

        # Decoder as (in_channels, out_channels, dropout). After the first
        # stage the input width doubles because every stage's output is
        # concatenated with the matching encoder activation.
        decoder_spec = (
            (f * 8, f * 8, True),
            (f * 16, f * 8, True),
            (f * 16, f * 8, True),
            (f * 16, f * 4, False),
            (f * 8, f * 2, False),
            (f * 4, f, False),
        )
        self.upsamples = nn.ModuleList(
            [Upsample(c_in, c_out, dropout=drop) for c_in, c_out, drop in decoder_spec]
        )

        # Head: one more 2x upsampling to 3 channels, squashed into [-1, 1].
        self.last = nn.Sequential(
            nn.ConvTranspose2d(f * 2, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Run the encoder, then the decoder with skips, then the Tanh head."""
        encoded = []
        for down in self.downsamples:
            x = down(x)
            encoded.append(x)
        # The deepest activation is already ``x``; the rest are consumed as
        # skips from deepest to shallowest.
        for up, skip in zip(self.upsamples, encoded[-2::-1]):
            x = up(x, skip)
        return self.last(x)
class ImageTransform:
    """Preprocess images into normalized tensors for the generator.

    ``phase='train'`` resizes, applies random horizontal and vertical flips,
    converts to a tensor and normalizes. ``phase='test'`` does the same
    without the flips. Resizing targets (img_size, img_size) and does not
    preserve aspect ratio. Normalize(mean=0.5, std=0.5) maps ``ToTensor``'s
    [0, 1] range to [-1, 1].
    """

    def __init__(self, img_size=256):
        size = (img_size, img_size)

        def to_normalized_tensor():
            # Shared tail of both pipelines: PIL -> [0, 1] tensor -> [-1, 1].
            return [
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5], std=[0.5]),
            ]

        train_steps = [
            transforms.Resize(size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
        ] + to_normalized_tensor()
        test_steps = [transforms.Resize(size)] + to_normalized_tensor()

        self.transform = {
            'train': transforms.Compose(train_steps),
            'test': transforms.Compose(test_steps),
        }

    def __call__(self, img, phase='train'):
        """Apply the pipeline registered for ``phase`` ('train' or 'test')."""
        pipeline = self.transform[phase]
        return pipeline(img)
# Fetch the pretrained NeonGAN checkpoint from the Hugging Face Hub.
path = hf_hub_download('huggan/NeonGAN', 'model.bin')
# SECURITY NOTE(review): torch.load unpickles arbitrary Python objects; only
# load this from a trusted repository. The checkpoint is presumably a pickled
# generator module (get_output_image calls it directly as a model), so the
# class definitions in this file must keep their names and attributes for
# unpickling to work.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which
# rejects pickled modules; on those versions pass weights_only=False here
# (trusted source only).
model_gen_n = torch.load(path, map_location=torch.device('cpu'))
# The generator contains Dropout2d(0.5) layers. Without eval() they stay active
# at inference time, so every request would get a differently-dropped output.
model_gen_n.eval()
transform = ImageTransform(img_size=256)
# gr.inputs / gr.outputs were deprecated in Gradio 3 and removed in Gradio 4;
# gr.Image is the equivalent component on both.
inputs = [
    gr.Image(type="pil", label="Original Image")
]
outputs = [
    gr.Image(type="pil", label="Neon Image")
]
def get_output_image(img):
    """Translate an input image into its "neon" version with the generator.

    Args:
        img: PIL image from the Gradio ``type="pil"`` input, or None when the
            user submits without an image.

    Returns:
        A PIL RGB image produced by ``model_gen_n``, or None if ``img`` is None.
    """
    if img is None:
        return None
    # The generator is presumably built for 3-channel input; convert RGBA or
    # grayscale uploads so they don't fail inside the first convolution.
    img = transform(img.convert("RGB"), phase='test')
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        gen_img = model_gen_n(img.unsqueeze(0))[0]
    # Undo Normalize(mean=0.5, std=0.5): [-1, 1] -> [0, 1]. Clamp guards the
    # uint8 conversion against any float overshoot.
    gen_img = (gen_img * 0.5 + 0.5).clamp(0, 1)
    # to_pil_image scales float tensors by 255, truncates to uint8 and
    # handles CHW -> HWC. This replaces the previous numpy/PIL code, which
    # referenced names (np, Image) that were never imported.
    return TF.to_pil_image(gen_img.cpu())
# NOTE(review): ``examples``, ``title`` and ``description`` are not defined in
# the visible source. They must be assigned before this statement, or the app
# fails at startup with NameError.
gr.Interface(
    get_output_image,
    inputs,
    outputs,
    examples=examples,
    title=title,
    description=description,
    theme="huggingface",
# launch(enable_queue=True) was deprecated in Gradio 3 and removed in Gradio 4;
# .queue() enables the same request queue and returns the interface itself.
).queue().launch()