|
import gradio as gr |
|
import torch |
|
from torchvision import transforms |
|
from PIL import Image |
|
import numpy as np |
|
from skimage.color import rgb2lab, lab2rgb |
|
import os |
|
from torch import nn |
|
|
|
|
|
class UNetBlock(nn.Module):
    """One U-Net stage: strided conv (encoder) or transposed conv (decoder).

    Both directions change the spatial size by a factor of 2
    (kernel 4, stride 2, padding 1) and finish with a ReLU.

    Args:
        in_channels: Channels entering the block.
        out_channels: Channels produced by the (transposed) convolution.
        down: True for an encoder stage (Conv2d, downsamples by 2),
            False for a decoder stage (ConvTranspose2d, upsamples by 2).
        bn: Apply BatchNorm2d after the convolution.
        dropout: Apply Dropout(0.5) (used on the innermost decoder stages).
    """

    def __init__(self, in_channels, out_channels, down=True, bn=True, dropout=False):
        super(UNetBlock, self).__init__()
        # bias=False: a conv bias is redundant when BatchNorm follows, and the
        # original kept it off even for the bn=False stages (checkpoint layout).
        self.conv = nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False) if down \
            else nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels) if bn else None
        self.dropout = nn.Dropout(0.5) if dropout else None
        self.down = down

    def forward(self, x):
        x = self.conv(x)
        # Explicit `is not None` instead of truthiness on Module objects.
        if self.bn is not None:
            x = self.bn(x)
        if self.dropout is not None:
            x = self.dropout(x)
        # The original built a new nn.ReLU module on every forward and branched
        # on self.down only to choose between ReLU() and ReLU(inplace=True),
        # which produce identical outputs — collapsed to one functional call.
        # NOTE(review): canonical pix2pix uses LeakyReLU(0.2) on the encoder
        # path; ReLU is deliberately kept so inference matches however the
        # shipped checkpoint was trained — confirm against the training code.
        return torch.relu(x)
|
|
|
class Generator(nn.Module):
    """Pix2pix-style U-Net mapping a 1-channel L image to 2-channel ab.

    Eight encoder stages halve 256x256 input down to a 1x1 bottleneck; eight
    decoder stages mirror them, each (except the first) consuming the matching
    encoder activation through a skip connection. Output is tanh-squashed to
    [-1, 1]. Attribute names down1..up8 are the checkpoint's state_dict keys
    and must not change.
    """

    def __init__(self):
        super(Generator, self).__init__()
        self.down1 = UNetBlock(1, 64, bn=False)
        self.down2 = UNetBlock(64, 128)
        self.down3 = UNetBlock(128, 256)
        self.down4 = UNetBlock(256, 512)
        self.down5 = UNetBlock(512, 512)
        self.down6 = UNetBlock(512, 512)
        self.down7 = UNetBlock(512, 512)
        self.down8 = UNetBlock(512, 512, bn=False)

        # Decoder input channels double after up1 because of concatenated skips.
        self.up1 = UNetBlock(512, 512, down=False, dropout=True)
        self.up2 = UNetBlock(1024, 512, down=False, dropout=True)
        self.up3 = UNetBlock(1024, 512, down=False, dropout=True)
        self.up4 = UNetBlock(1024, 512, down=False)
        self.up5 = UNetBlock(1024, 256, down=False)
        self.up6 = UNetBlock(512, 128, down=False)
        self.up7 = UNetBlock(256, 64, down=False)
        self.up8 = nn.ConvTranspose2d(128, 2, 4, 2, 1)

    def forward(self, x):
        # Encoder pass, remembering every scale for the skip connections.
        skips = []
        feat = x
        for enc in (self.down1, self.down2, self.down3, self.down4,
                    self.down5, self.down6, self.down7, self.down8):
            feat = enc(feat)
            skips.append(feat)

        # Decoder pass: start from the bottleneck, then pair each stage with
        # the mirrored encoder activation (up2<-d7, up3<-d6, ..., up7<-d2).
        feat = self.up1(skips[-1])
        decoders = (self.up2, self.up3, self.up4, self.up5, self.up6, self.up7)
        for dec, skip in zip(decoders, reversed(skips[1:-1])):
            feat = dec(torch.cat([feat, skip], 1))

        # Final upsample fuses the shallowest skip (d1); tanh bounds the
        # predicted ab channels to [-1, 1].
        return torch.tanh(self.up8(torch.cat([feat, skips[0]], 1)))
|
|
|
|
|
def load_checkpoint(filename, generator, map_location):
    """Load generator weights from a training checkpoint, if one exists.

    Args:
        filename: Path to the checkpoint file (a torch.save'd dict with keys
            'generator_state_dict' and 'epoch').
        generator: Module whose parameters are populated in place.
        map_location: Forwarded to torch.load so a CUDA-trained checkpoint
            still loads on a CPU-only machine.

    A missing checkpoint is reported but not fatal: the generator keeps its
    current (random) weights.
    """
    if os.path.isfile(filename):
        # Bug fix: the f-strings printed the literal text '(unknown)' — the
        # {filename} placeholder had been lost, so logs never named the file.
        print(f"Loading checkpoint '{filename}'")
        # NOTE(review): torch.load unpickles; only load trusted checkpoints
        # (weights_only=True would be safer on torch >= 1.13 — confirm the
        # minimum supported torch version before adding it).
        checkpoint = torch.load(filename, map_location=map_location)
        generator.load_state_dict(checkpoint['generator_state_dict'])
        print(f"Loaded checkpoint '{filename}' (epoch {checkpoint['epoch']})")
    else:
        print(f"No checkpoint found at '{filename}'")
|
|
|
|
|
# Inference runs on GPU when available; map_location below lets a
# CUDA-trained checkpoint load on a CPU-only machine.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

generator = Generator().to(device)

# Weights are loaded at import time. If the file is missing, load_checkpoint
# only prints a warning and the generator keeps its random init.
checkpoint_path = "checkpoints/latest_checkpoint.pth.tar"

load_checkpoint(checkpoint_path, generator, map_location=device)

generator.eval()  # disable the decoder's Dropout(0.5) for deterministic output




# Preprocessing: resize to the 256x256 the U-Net expects, collapse to one
# luminance channel, and scale pixel values to [0, 1] via ToTensor.
transform = transforms.Compose([

    transforms.Resize((256, 256)),

    transforms.Grayscale(num_output_channels=1),

    transforms.ToTensor()

])
|
|
|
|
|
def colorize_image(input_image):
    """Colorize a grayscale PIL image with the pretrained generator.

    The image is shrunk to 256x256 for the network, the predicted ab channels
    are combined with the input luminance into a Lab image, converted to RGB,
    and upscaled back to the original resolution. Returns an RGB PIL image,
    or None on any failure (Gradio renders None as an empty output).
    """
    try:
        original_size = input_image.size
        # 1x1x256x256 tensor in [0, 1] (ToTensor), on the inference device.
        lum = transform(input_image).unsqueeze(0).to(device)
        with torch.no_grad():
            prediction = generator(lum)
        # tanh output in [-1, 1] scaled to the ab range [-128, 128].
        ab = prediction.squeeze(0).cpu().numpy() * 128.
        # NOTE(review): this maps the [0, 1] network input to L in [50, 100];
        # the formula matches an input normalized to [-1, 1], which the
        # transform above does not apply — confirm against the training
        # pipeline before "fixing" either side.
        L = (lum.squeeze(0).cpu().numpy() + 1.) * 50.
        lab = np.concatenate([L, ab], axis=0).transpose(1, 2, 0)
        rgb = lab2rgb(lab)
        result = Image.fromarray((rgb * 255).astype(np.uint8))
        # Restore the caller's resolution with high-quality resampling.
        return result.resize(original_size, Image.LANCZOS)
    except Exception as e:
        # Boundary handler: log and return None so the UI degrades gracefully.
        print(f"Error in colorize_image: {str(e)}")
        return None
|
|
|
|
|
# Gradio UI: one image in, one image out; colorize_image does the work.
iface = gr.Interface(

    fn=colorize_image,

    inputs=gr.Image(type="pil"),

    outputs=gr.Image(type="pil"),

    title="Image Colorizer",

    description="Upload a grayscale image to colorize it."

)




if __name__ == "__main__":

    # share=True also publishes a temporary public *.gradio.live URL.
    iface.launch(share=True)