# Photo_Colorizer / app.py
# Author: matikosowy
# Last commit: 1f5788e ("description and title")
import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
from utils import normalize_lab, denormalize_lab, pad_image
from model import Generator
import kornia.color as color
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Load the trained generator weights from 'model.pth'. weights_only=True restricts
# unpickling to tensors/primitive containers, and map_location places the tensors
# on the selected device.
model_weights = torch.load('model.pth', map_location=device, weights_only=True)
model = Generator()
model.load_state_dict(model_weights)
# nn.Module.to() and .eval() both return the module itself, so they can be chained;
# eval() switches layers such as dropout/batch-norm to inference behaviour.
model = model.to(device).eval()
def preprocess(image):
    """Convert a PIL image into the normalized lightness batch fed to the generator.

    Steps: force RGB, pad via ``utils.pad_image``, convert to a float tensor on
    ``device``, convert RGB -> Lab with kornia, keep only the L channel and
    normalize it with ``utils.normalize_lab``.

    Args:
        image: a PIL image (any mode; it is converted to RGB first).

    Returns:
        A tensor with a leading batch axis of size 1 and a single channel,
        presumably shaped (1, 1, H_pad, W_pad) where H_pad/W_pad are the padded
        dimensions. This assumes normalize_lab preserves the shape (TODO confirm
        against utils).
    """
    rgb_image = pad_image(image.convert('RGB'))
    # ToTensor yields a (3, H, W) float tensor scaled to [0, 1].
    rgb_tensor = transforms.ToTensor()(rgb_image).to(device)
    lab_tensor = color.rgb_to_lab(rgb_tensor)
    # Index with a list ([0]) rather than an int so the channel axis is kept: (1, H, W).
    # The second argument/returned value of normalize_lab is presumably the ab part;
    # only L is needed here.
    lightness, _ = normalize_lab(lab_tensor[[0], ...], 0)
    return lightness.unsqueeze(0)
def crop_to_original_size(image, original_size):
    """Crop ``image`` back to ``original_size``, anchored at the top-left corner.

    Args:
        image: an image tensor accepted by ``torchvision.transforms.functional.crop``
            (its last two axes are height and width).
        original_size: a PIL-style ``(width, height)`` pair, e.g. ``Image.size``.

    Returns:
        The cropped tensor. torchvision zero-pads if the requested region is
        larger than the image.

    NOTE(review): anchoring at (0, 0) assumes pad_image adds its padding on the
    bottom/right edges only. Confirm against utils.pad_image.
    """
    target_w, target_h = original_size
    return transforms.functional.crop(
        image,
        height=target_h,
        width=target_w,
        top=0,
        left=0,
    )
def predict(image):
    """Colorize ``image`` with the generator and return the result as a PIL image.

    The lightness (L) channel of the input is fed to ``model``, whose output is
    used as the a/b chroma channels. The recombined Lab image is converted back
    to RGB and cropped to the input's original size.

    Args:
        image: a PIL image supplied by the Gradio ``gr.Image(type="pil")`` input,
            or ``None`` when the user submits without uploading anything.

    Returns:
        A PIL RGB image with the same (width, height) as ``image``.

    Raises:
        gr.Error: if no image was provided. Gradio shows the message in the UI.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    original_size = image.size  # PIL (width, height), captured before padding
    L = preprocess(image)
    with torch.no_grad():
        predicted_ab = model(L)
    L, ab = denormalize_lab(L, predicted_ab)
    lab = torch.cat([L, ab], dim=1)  # channel axis: L followed by a, b
    rgb = color.lab_to_rgb(lab)
    rgb = crop_to_original_size(rgb, original_size)
    # squeeze(0) removes only the batch axis. A bare squeeze() would also drop a
    # height or width of 1 and hand ToPILImage a malformed 2-D tensor.
    return transforms.ToPILImage()(rgb.squeeze(0).cpu())
# Gradio UI: one PIL image in, one PIL image out, both routed through predict().
input_component = gr.Image(type="pil")
output_component = gr.Image(type="pil")
iface = gr.Interface(
    fn=predict,
    inputs=input_component,
    outputs=output_component,
    title="Photo Colorizer",
    description="This model colorizes grayscale images. Upload an image and see the magic happen! (works best with 256x256 size)",
)
# Start the web server. This runs unconditionally whenever the module executes.
iface.launch()