# Hugging Face Spaces page header (scraped with this file): status "Runtime error"
import gradio as gr | |
import torch | |
from PIL import Image | |
from torchvision import transforms | |
from utils import normalize_lab, denormalize_lab, pad_image | |
from model import Generator | |
import kornia.color as color | |
# Run inference on the GPU when one is visible to torch, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the generator and restore its trained parameters from model.pth.
# map_location lets CUDA-saved weights load on a CPU-only host;
# weights_only=True restricts unpickling to tensors and plain containers.
model = Generator()
model_weights = torch.load('model.pth', map_location=device, weights_only=True)
model.load_state_dict(model_weights)

# Module.eval() returns the module itself, so the move and the switch to
# evaluation mode can be chained.
model = model.to(device).eval()
def preprocess(image):
    """Turn a PIL image into the lightness-only tensor the generator consumes.

    Steps: force 3-channel RGB, pad with ``pad_image``, convert to a float
    tensor on ``device``, convert RGB -> Lab with kornia, keep only the first
    (L) channel, and normalize it with ``normalize_lab``.

    Args:
        image: a PIL image in any mode (it is converted to RGB first).

    Returns:
        The normalized L channel with a batch axis prepended, presumably
        shaped (1, 1, H_padded, W_padded) — assumes ``normalize_lab`` keeps
        the shape; TODO confirm.
    """
    padded = pad_image(image.convert('RGB'))

    # ToTensor yields a float (3, H, W) tensor scaled to [0, 1].
    rgb_tensor = transforms.ToTensor()(padded).to(device)
    lab_tensor = color.rgb_to_lab(rgb_tensor)

    # A list index ([0]) keeps the channel axis, giving (1, H, W) rather
    # than (H, W). The 0 passed as the second argument looks like a
    # placeholder for the unused ab channels — verify against normalize_lab.
    lightness, _ = normalize_lab(lab_tensor[[0], ...], 0)
    return lightness.unsqueeze(0)
def crop_to_original_size(image, original_size):
    """Trim a padded image tensor back to its pre-padding extent.

    Args:
        image: image tensor whose last two axes are (height, width).
        original_size: ``(width, height)`` as reported by ``PIL.Image.size``.

    Returns:
        The top-left ``height x width`` region of ``image``. This assumes
        ``pad_image`` only adds padding on the bottom/right — TODO confirm.
    """
    target_width, target_height = original_size
    # torchvision's crop takes (img, top, left, height, width).
    return transforms.functional.crop(image, 0, 0, target_height, target_width)
def predict(image):
    """Colorize a PIL image with the generator and return a PIL image.

    ``preprocess`` reduces the input to its normalized L (lightness) channel.
    The generator's output is denormalized together with L, concatenated
    along the channel axis into a Lab tensor, converted back to RGB, and
    cropped to undo the padding added during preprocessing.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        The colorized image at the input's original size, or ``None`` when
        no image was provided (the Gradio output then stays empty).
    """
    # Gradio hands over None for an empty image input; without this guard
    # the ``image.size`` access below raises AttributeError.
    if image is None:
        return None

    # PIL reports (width, height); remembered so the padding can be cropped.
    original_size = image.size
    L = preprocess(image)

    with torch.no_grad():
        output = model(L)

    # presumably ``output`` holds the normalized ab channels — verify
    # against denormalize_lab / Generator.
    L, ab = denormalize_lab(L, output)
    lab = torch.cat([L, ab], dim=1)  # dim 0 is the batch axis added in preprocess
    rgb = color.lab_to_rgb(lab)
    rgb = crop_to_original_size(rgb, original_size)

    # Remove only the batch axis. A bare squeeze() would also drop a height or
    # width of 1, and ToPILImage would then misread the 2-D result as a
    # single-channel image of the wrong shape.
    return transforms.ToPILImage()(rgb.squeeze(0).cpu())
# Wire ``predict`` into a single-image-in / single-image-out web UI. Both
# sides use PIL images, matching what ``predict`` accepts and returns.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Photo Colorizer",
    description="This model colorizes grayscale images. Upload an image and see the magic happen! (works best with 256x256 size)",
)

# Start the Gradio server (runs at import time, as in the original script).
iface.launch()