"""Gradio demo that colorizes grayscale photos with a pretrained Generator.

The input's lightness (L) channel is fed to the model, whose output is
recombined with L as the a/b channels and converted back to RGB.
"""
import gradio as gr
import kornia.color as color
import torch
from PIL import Image
from torchvision import transforms

from model import Generator
from utils import denormalize_lab, normalize_lab, pad_image

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the trained generator once at import time and switch it to eval mode.
model = Generator()
model_weights = torch.load('model.pth', map_location=device, weights_only=True)
model.load_state_dict(model_weights)
model = model.to(device)
model.eval()

# Built once and reused for every request instead of being rebuilt per call.
_to_tensor = transforms.ToTensor()


def preprocess(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image into the normalized L-channel batch fed to the model.

    Args:
        image: Input picture in any PIL mode; it is converted to RGB first.

    Returns:
        A tensor of shape ``(1, 1, H', W')`` on ``device``, where ``H'``/``W'``
        are the image dimensions after ``pad_image``.
    """
    image = image.convert('RGB')
    image = pad_image(image)
    tensor = _to_tensor(image).to(device)  # (3, H', W'), float values in [0, 1]
    lab = color.rgb_to_lab(tensor)
    # Index with a list so the channel dimension is kept: (1, H', W').
    L = lab[[0], ...]
    # NOTE(review): the second argument appears to be a placeholder for the ab
    # channels, whose normalized result is discarded here — confirm in utils.
    L, _ = normalize_lab(L, 0)
    return L.unsqueeze(0)


def crop_to_original_size(image, original_size):
    """Crop a padded image tensor back to the original picture size.

    Keeps the top-left ``height x width`` region.

    Args:
        image: Image tensor whose last two dimensions are ``(H, W)``.
        original_size: ``(width, height)`` tuple, as given by PIL's ``Image.size``.

    Returns:
        The cropped tensor.
    """
    # NOTE(review): anchoring the crop at (0, 0) assumes pad_image only pads the
    # right and bottom edges — confirm in utils.pad_image.
    width, height = original_size
    return transforms.functional.crop(image, top=0, left=0, height=height, width=width)


def predict(image):
    """Colorize ``image`` and return the result as a PIL RGB image.

    Args:
        image: PIL image supplied by Gradio, or ``None`` when no image was given.

    Returns:
        The colorized image cropped to the input's original size, or ``None``
        if no image was provided.
    """
    if image is None:
        # Gradio passes None when the input component is empty; avoid crashing
        # on ``image.size`` and leave the output empty instead.
        return None
    original_size = image.size
    L = preprocess(image)
    with torch.no_grad():
        output = model(L)
    L, ab = denormalize_lab(L, output)
    lab = torch.cat([L, ab], dim=1)
    rgb = color.lab_to_rgb(lab)
    rgb = crop_to_original_size(rgb, original_size)
    # Take the single batch element explicitly: ``squeeze()`` would also drop a
    # height or width of 1 and hand ToPILImage a tensor of the wrong rank.
    return transforms.ToPILImage()(rgb[0].cpu())


iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Photo Colorizer",
    description="This model colorizes grayscale images. Upload an image and see the magic happen! (works best with 256x256 size)",
)

if __name__ == "__main__":
    iface.launch()