# pix2pixcolorizer / colorizer_pipeline.py
# Rohil Bansal
# New training
# 09ae4e4
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import IterableDataset, DataLoader
from torchvision import transforms
from torchvision.utils import make_grid
import mlflow
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
from skimage.color import rgb2lab, lab2rgb
from datasets import load_dataset
from PIL import Image
from itertools import islice
import traceback
# MLflow setup
EXPERIMENT_NAME = "Colorizer_Experiment"


def setup_mlflow():
    """Return the ID of the MLflow experiment named ``EXPERIMENT_NAME``.

    The experiment is looked up by name; when the lookup returns ``None`` the
    experiment is created and the newly created ID is returned instead.
    """
    existing = mlflow.get_experiment_by_name(EXPERIMENT_NAME)
    if existing is not None:
        return existing.experiment_id
    return mlflow.create_experiment(EXPERIMENT_NAME)
# Data ingestion
class ColorizeIterableDataset(IterableDataset):
    """Stream ``(L, ab)`` tensor pairs built from an iterable of image records.

    Every record of ``dataset`` is indexed with ``'image'``; the image is
    converted to RGB when its ``mode`` says otherwise, passed through
    ``transform`` (when one is given), converted to Lab, and split into a
    1-channel L tensor and a 2-channel ab tensor. Records that fail a shape
    check or raise during processing are reported with ``print`` and skipped.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def _to_lab_pair(self, record):
        """Turn one record into an ``(L, ab)`` pair, or ``None`` on a shape mismatch."""
        picture = record['image']
        if picture.mode != 'RGB':
            picture = picture.convert('RGB')
        if self.transform:
            picture = self.transform(picture)

        # Only channel-first 3x256x256 tensors are accepted past this point.
        if picture.shape != (3, 256, 256):
            print(f"Unexpected image shape after transform: {picture.shape}")
            return None

        # rgb2lab works on channel-last arrays, hence the permute.
        lab = rgb2lab(picture.permute(1, 2, 0).numpy())
        if lab.shape != (256, 256, 3):
            print(f"Unexpected lab shape: {lab.shape}")
            return None

        # L is shifted/scaled by 50, ab is scaled by 128 (undone in lab_to_rgb).
        lightness = (lab[:, :, 0] - 50) / 50
        chroma = lab[:, :, 1:] / 128
        return torch.Tensor(lightness).unsqueeze(0), torch.Tensor(chroma).permute(2, 0, 1)

    def __iter__(self):
        for record in self.dataset:
            try:
                pair = self._to_lab_pair(record)
                if pair is not None:
                    yield pair
            except Exception as e:
                # Best effort: one bad record must not stop the stream.
                print(f"Error processing image: {str(e)}")
def create_dataloaders(batch_size=32, num_workers=4):
    """Build the training DataLoader over ImageNet loaded in streaming mode.

    Images are resized to 256x256 and converted to tensors before being
    wrapped in ``ColorizeIterableDataset``.

    Args:
        batch_size: Number of samples per batch.
        num_workers: Worker processes used by the DataLoader. Defaults to 4,
            the value that used to be hard-coded; pass 0 to load in the main
            process (e.g. while debugging).

    Returns:
        The training ``DataLoader``, or ``None`` when anything in the setup
        raised. The error message and traceback are printed in that case.
    """
    try:
        print("Loading ImageNet dataset in streaming mode...")
        dataset = load_dataset("imagenet-1k", split="train", streaming=True)
        print("Dataset loaded in streaming mode.")

        print("Creating custom dataset...")
        transform = transforms.Compose([
            transforms.Resize((256, 256)),  # Resize all images to 256x256
            transforms.ToTensor(),
        ])
        train_dataset = ColorizeIterableDataset(dataset, transform=transform)
        print("Custom dataset created.")

        print("Creating dataloader...")
        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers)
        print("Dataloader created.")
        return train_dataloader
    except Exception as e:
        print(f"Error in create_dataloaders: {str(e)}")
        # The message alone rarely identifies the failing call; show the stack too.
        traceback.print_exc()
        return None
# Model definition
class UNetBlock(nn.Module):
    """One U-Net stage: (transposed) conv -> optional BatchNorm -> optional Dropout -> ReLU.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        down: ``True`` uses ``nn.Conv2d``; ``False`` uses ``nn.ConvTranspose2d``.
            Both use kernel 4, stride 2, padding 1 and no bias.
        bn: Whether to apply ``nn.BatchNorm2d`` after the convolution.
        dropout: Whether to apply ``nn.Dropout(0.5)`` before the activation.
    """

    def __init__(self, in_channels, out_channels, down=True, bn=True, dropout=False):
        super().__init__()
        conv_cls = nn.Conv2d if down else nn.ConvTranspose2d
        self.conv = conv_cls(in_channels, out_channels, 4, 2, 1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels) if bn else None
        self.dropout = nn.Dropout(0.5) if dropout else None
        self.down = down

    def forward(self, x):
        x = self.conv(x)
        # Explicit None checks instead of relying on module truthiness.
        if self.bn is not None:
            x = self.bn(x)
        if self.dropout is not None:
            x = self.dropout(x)
        # Both directions end in a plain ReLU; the functional form avoids
        # constructing a new nn.ReLU module on every forward pass.
        # NOTE(review): down and up blocks share the same activation here -
        # confirm that a different encoder activation was not intended.
        return torch.relu(x)
class Generator(nn.Module):
    """U-Net generator mapping a 1-channel input to a 2-channel tanh output.

    Eight ``UNetBlock`` down stages are followed by seven ``UNetBlock`` up
    stages and a final ``nn.ConvTranspose2d``. From ``up2`` onwards each up
    stage receives the previous up output concatenated (on the channel axis)
    with the mirrored down stage output.
    """

    def __init__(self):
        super(Generator, self).__init__()

        # (in_channels, out_channels, extra UNetBlock options) for down1..down8.
        encoder_specs = [
            (1, 64, {"bn": False}),
            (64, 128, {}),
            (128, 256, {}),
            (256, 512, {}),
            (512, 512, {}),
            (512, 512, {}),
            (512, 512, {}),
            (512, 512, {"bn": False}),
        ]
        for idx, (c_in, c_out, opts) in enumerate(encoder_specs, start=1):
            setattr(self, f"down{idx}", UNetBlock(c_in, c_out, **opts))

        # Same layout for up1..up7; inputs after up1 are doubled by the skip concat.
        decoder_specs = [
            (512, 512, {"dropout": True}),
            (1024, 512, {"dropout": True}),
            (1024, 512, {"dropout": True}),
            (1024, 512, {}),
            (1024, 256, {}),
            (512, 128, {}),
            (256, 64, {}),
        ]
        for idx, (c_in, c_out, opts) in enumerate(decoder_specs, start=1):
            setattr(self, f"up{idx}", UNetBlock(c_in, c_out, down=False, **opts))

        self.up8 = nn.ConvTranspose2d(128, 2, 4, 2, 1)

    def forward(self, x):
        # Run the encoder, remembering every stage output for the skip links.
        skips = []
        feat = x
        for idx in range(1, 9):
            feat = getattr(self, f"down{idx}")(feat)
            skips.append(feat)

        # The bottleneck (down8 output) feeds up1 on its own.
        feat = self.up1(skips.pop())
        # up2..up7 each consume the previous output plus the mirrored encoder output.
        for idx in range(2, 8):
            feat = getattr(self, f"up{idx}")(torch.cat([feat, skips.pop()], 1))

        # Last skip is down1's output; tanh bounds the prediction to (-1, 1).
        return torch.tanh(self.up8(torch.cat([feat, skips.pop()], 1)))
class Discriminator(nn.Module):
    """Convolutional discriminator mapping a 3-channel input to a 1-channel score map.

    The layers live in ``self.model`` (an ``nn.Sequential``): one conv +
    LeakyReLU stage, three conv + BatchNorm + LeakyReLU stages, and a final
    1-channel convolution with no activation.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        layers = [
            nn.Conv2d(3, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # (in_channels, out_channels, stride) of the normalised stages.
        for c_in, c_out, stride in ((64, 128, 2), (128, 256, 2), (256, 512, 1)):
            layers += [
                nn.Conv2d(c_in, c_out, 4, stride=stride, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers.append(nn.Conv2d(512, 1, 4, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
def init_weights(model):
    """Initialise one module in place; meant to be passed to ``nn.Module.apply``.

    Modules whose class name contains ``'Conv'`` get weights drawn from
    N(0, 0.02). Modules whose class name contains ``'BatchNorm'`` get weights
    drawn from N(1, 0.02) and a zero bias. Every other module is left alone.

    Parameters that are missing or ``None`` (for example
    ``BatchNorm2d(affine=False)``, or a container class that merely has
    "Conv" in its name) are skipped instead of raising ``AttributeError``.

    Args:
        model: The module to initialise.
    """
    classname = model.__class__.__name__
    weight = getattr(model, "weight", None)
    bias = getattr(model, "bias", None)

    # nn.init functions already run without gradient tracking, so the
    # parameters can be passed directly.
    if "Conv" in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif "BatchNorm" in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)
# Training utilities
def lab_to_rgb(L, ab):
    """Convert batches of scaled L and ab tensors into one stacked RGB array.

    ``L`` is mapped back with ``(L + 1) * 50`` and ``ab`` with ``ab * 128``
    (the inverse of the scaling applied in ``ColorizeIterableDataset``). The
    two are concatenated on dim 1, moved to channel-last layout, and each
    image is converted with ``lab2rgb``.

    Returns:
        The per-image RGB arrays stacked along axis 0.
    """
    lightness = (L + 1.) * 50.
    chroma = ab * 128.
    lab_batch = torch.cat([lightness, chroma], dim=1).permute(0, 2, 3, 1).cpu().numpy()
    return np.stack([lab2rgb(lab_img) for lab_img in lab_batch], axis=0)
def visualize_results(epoch, generator, train_loader, device):
    """Save and log a grid comparing real and generated colorizations.

    The generator is put in eval mode and the first batch obtained by
    iterating ``train_loader`` is colorized without gradient tracking. Real
    and generated RGB images are joined side by side, arranged in a grid,
    written to ``results/epoch_{epoch}.png`` and logged as an MLflow artifact.

    Any ``Exception`` is printed together with its traceback and swallowed so
    that a failed preview does not abort the caller. The figure is always
    closed and the generator is always returned to train mode, including when
    saving or logging fails.

    Args:
        epoch: Epoch number used in the plot title and the output file name.
        generator: Model mapping L batches to ab batches.
        train_loader: Iterable yielding ``(inputs, real_AB)`` batches; only
            the first batch is used.
        device: Device the batch is moved to before running the generator.
    """
    generator.eval()
    try:
        with torch.no_grad():
            for inputs, real_AB in train_loader:
                print(f"Input shape: {inputs.shape}, real_AB shape: {real_AB.shape}")

                # Ensure inputs have the correct shape (B, 1, H, W)
                if inputs.shape[1] != 1:
                    inputs = inputs.unsqueeze(1)

                inputs, real_AB = inputs.to(device), real_AB.to(device)
                fake_AB = generator(inputs)
                print(f"fake_AB shape: {fake_AB.shape}")

                # Ensure fake_AB and real_AB have the correct shape (B, 2, H, W)
                if fake_AB.shape[1] != 2:
                    fake_AB = fake_AB.view(fake_AB.shape[0], 2, fake_AB.shape[2], fake_AB.shape[3])
                if real_AB.shape[1] != 2:
                    real_AB = real_AB.view(real_AB.shape[0], 2, real_AB.shape[2], real_AB.shape[3])

                fake_rgb = lab_to_rgb(inputs.cpu(), fake_AB.cpu())
                real_rgb = lab_to_rgb(inputs.cpu(), real_AB.cpu())
                print(f"fake_rgb shape: {fake_rgb.shape}, real_rgb shape: {real_rgb.shape}")

                # lab_to_rgb returns channel-last arrays, so axis 2 is the
                # width: real on the left, generated on the right.
                concatenated = np.concatenate([real_rgb, fake_rgb], axis=2)
                print(f"Concatenated shape: {concatenated.shape}")

                img_grid = make_grid(torch.from_numpy(concatenated).permute(0, 3, 1, 2), normalize=True, nrow=4)

                fig = plt.figure(figsize=(15, 15))
                try:
                    plt.imshow(img_grid.permute(1, 2, 0).cpu())
                    plt.axis('off')
                    plt.title(f'Epoch {epoch}')
                    plt.savefig(f'results/epoch_{epoch}.png')
                    mlflow.log_artifact(f'results/epoch_{epoch}.png')
                finally:
                    # Close even when saving/logging fails, otherwise one
                    # figure would be leaked per failing epoch.
                    plt.close(fig)
                break
    except Exception as e:
        print(f"Error in visualize_results: {str(e)}")
        traceback.print_exc()
    finally:
        # Restore training mode no matter how the preview ended.
        generator.train()
def save_checkpoint(state, filename="checkpoint.pth.tar"):
    """Write the training state to ``filename`` and log the file to MLflow.

    Only the epoch number and the generator/discriminator/optimizer state
    dicts are taken from ``state``; any other entries are ignored.

    The data is first written to ``<filename>.tmp`` and then moved over
    ``filename`` with ``os.replace``, so an interruption during the write
    cannot leave a truncated file where the previous good checkpoint was.

    Args:
        state: Mapping holding at least the keys ``'epoch'``,
            ``'generator_state_dict'``, ``'discriminator_state_dict'``,
            ``'optimizerG_state_dict'`` and ``'optimizerD_state_dict'``.
        filename: Destination path of the checkpoint file.

    Raises:
        KeyError: If ``state`` lacks one of the required keys.
    """
    # Only save the necessary state
    wanted_keys = (
        'epoch',
        'generator_state_dict',
        'discriminator_state_dict',
        'optimizerG_state_dict',
        'optimizerD_state_dict',
    )
    save_state = {key: state[key] for key in wanted_keys}

    tmp_filename = f"{filename}.tmp"
    try:
        torch.save(save_state, tmp_filename)
        os.replace(tmp_filename, filename)
    finally:
        # After a successful replace the temp file is gone; this only cleans
        # up a partial file left behind by a failed write.
        if os.path.exists(tmp_filename):
            os.remove(tmp_filename)

    mlflow.log_artifact(filename)
def load_checkpoint(filename, generator, discriminator, optimizerG, optimizerD, device):
    """Restore models and optimizers from ``filename`` and return the epoch to resume at.

    Returns ``0`` when ``filename`` is not an existing file. Otherwise the
    checkpoint is loaded onto ``device``, its four state dicts are loaded
    into the given objects, and the stored epoch plus one is returned.
    """
    if not os.path.isfile(filename):
        print(f"No checkpoint found at '{filename}'")
        return 0

    print(f"Loading checkpoint '{filename}'")
    # Use weights_only=True for safer loading
    checkpoint = torch.load(filename, map_location=device, weights_only=True)
    start_epoch = checkpoint['epoch'] + 1

    # Each target object paired with the checkpoint entry that restores it.
    restore_plan = (
        (generator, 'generator_state_dict'),
        (discriminator, 'discriminator_state_dict'),
        (optimizerG, 'optimizerG_state_dict'),
        (optimizerD, 'optimizerD_state_dict'),
    )
    for target, key in restore_plan:
        target.load_state_dict(checkpoint[key])

    print(f"Loaded checkpoint '{filename}' (epoch {checkpoint['epoch']})")
    return start_epoch
# Global variables
# Directory in which train() keeps its resumable checkpoint file.
checkpoint_dir = "checkpoints"
# Both output directories are created when the module is loaded, so later
# saves do not fail on a missing path.
os.makedirs(checkpoint_dir, exist_ok=True)
# "results" receives the per-epoch preview images written by visualize_results().
os.makedirs("results", exist_ok=True)
# Training function
def train(generator, discriminator, train_loader, num_epochs, device, lr=0.0002, beta1=0.5,
          num_iterations=2000, l1_weight=100):
    """Run adversarial + L1 training inside an MLflow run, with checkpointing.

    Training resumes from ``checkpoints/latest_checkpoint.pth.tar`` when that
    file exists. Each epoch takes at most ``num_iterations`` batches from
    ``train_loader``, performs one discriminator step and one generator step
    per batch, logs the losses to MLflow, then writes a preview image and a
    checkpoint. After the last epoch the generator is logged and registered
    as the MLflow model ``colorizer_generator``.

    Args:
        generator: Model mapping L batches ``(B, 1, 256, 256)`` to ab batches.
        discriminator: Model scoring concatenated ``[L, ab]`` batches.
        train_loader: Iterable yielding ``(real_L, real_AB)`` batches.
        num_epochs: Epoch index at which training stops (exclusive).
        device: Device the batches are moved to.
        lr: Learning rate for both Adam optimizers.
        beta1: First Adam beta for both optimizers.
        num_iterations: Maximum number of batches consumed per epoch
            (previously hard-coded to 2000).
        l1_weight: Multiplier applied to the generator's L1 loss
            (previously hard-coded to 100).

    Returns:
        The MLflow run ID on success, or ``None`` when an ``Exception`` was
        raised inside the run (the error is printed and logged as the MLflow
        parameter ``error``).
    """
    criterion = nn.BCEWithLogitsLoss()
    l1_loss = nn.L1Loss()

    optimizerG = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerD = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

    checkpoint_path = os.path.join(checkpoint_dir, "latest_checkpoint.pth.tar")
    start_epoch = load_checkpoint(checkpoint_path, generator, discriminator, optimizerG, optimizerD, device)

    experiment_id = setup_mlflow()
    with mlflow.start_run(experiment_id=experiment_id, run_name="training_run") as run:
        try:
            for epoch in range(start_epoch, num_epochs):
                generator.train()
                discriminator.train()

                # islice creates a fresh iterator over train_loader every epoch.
                pbar = tqdm(
                    enumerate(islice(train_loader, num_iterations)),
                    total=num_iterations,
                    desc=f"Epoch {epoch+1}/{num_epochs}",
                )
                for i, (real_L, real_AB) in pbar:
                    # Add shape check
                    if real_L.shape[1:] != (1, 256, 256) or real_AB.shape[1:] != (2, 256, 256):
                        print(f"Unexpected tensor shapes: real_L {real_L.shape}, real_AB {real_AB.shape}")
                        continue

                    real_L, real_AB = real_L.to(device), real_AB.to(device)

                    # Train Discriminator
                    optimizerD.zero_grad()

                    fake_AB = generator(real_L)
                    fake_LAB = torch.cat([real_L, fake_AB], dim=1)
                    real_LAB = torch.cat([real_L, real_AB], dim=1)

                    # detach() keeps the discriminator loss from back-propagating
                    # into the generator.
                    pred_fake = discriminator(fake_LAB.detach())
                    loss_D_fake = criterion(pred_fake, torch.zeros_like(pred_fake))

                    pred_real = discriminator(real_LAB)
                    loss_D_real = criterion(pred_real, torch.ones_like(pred_real))

                    loss_D = (loss_D_fake + loss_D_real) * 0.5
                    loss_D.backward()
                    optimizerD.step()

                    # Train Generator
                    # NOTE(review): the generator is run a second time here
                    # rather than reusing the fake batch from above - confirm
                    # this is intended before optimising it away.
                    optimizerG.zero_grad()

                    fake_AB = generator(real_L)
                    fake_LAB = torch.cat([real_L, fake_AB], dim=1)
                    pred_fake = discriminator(fake_LAB)

                    loss_G_GAN = criterion(pred_fake, torch.ones_like(pred_fake))
                    loss_G_L1 = l1_loss(fake_AB, real_AB) * l1_weight  # L1 loss weight

                    loss_G = loss_G_GAN + loss_G_L1
                    loss_G.backward()
                    optimizerG.step()

                    # Read each scalar once; every .item() call copies the
                    # value back to the host.
                    d_loss_value = loss_D.item()
                    g_loss_value = loss_G.item()
                    g_l1_value = loss_G_L1.item()

                    pbar.set_postfix({
                        'D_loss': d_loss_value,
                        'G_loss': g_loss_value,
                        'G_L1': g_l1_value
                    })

                    mlflow.log_metrics({
                        "D_loss": d_loss_value,
                        "G_loss": g_loss_value,
                        "G_L1_loss": g_l1_value
                    }, step=epoch * num_iterations + i)

                visualize_results(epoch, generator, train_loader, device)

                checkpoint = {
                    'epoch': epoch,
                    'generator_state_dict': generator.state_dict(),
                    'discriminator_state_dict': discriminator.state_dict(),
                    'optimizerG_state_dict': optimizerG.state_dict(),
                    'optimizerD_state_dict': optimizerD.state_dict(),
                }
                save_checkpoint(checkpoint, filename=checkpoint_path)

            print("Training completed successfully.")
            mlflow.pytorch.log_model(generator, "generator_model")

            model_uri = f"runs:/{run.info.run_id}/generator_model"
            mlflow.register_model(model_uri, "colorizer_generator")

            return run.info.run_id

        except Exception as e:
            print(f"Error during training: {str(e)}")
            traceback.print_exc()
            # MLflow limits the length of parameter values (the exact cap
            # depends on the version); truncate so that recording the error
            # cannot itself raise from inside this handler.
            mlflow.log_param("error", str(e)[:250])
            return None
# Main execution
if __name__ == "__main__":
    # Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    try:
        batch_size = 32
        num_epochs = 50

        train_loader = create_dataloaders(batch_size=batch_size)
        if train_loader is None:
            raise Exception("Failed to create dataloader")

        # Build both networks on the target device, then initialise them in
        # the same order (generator first).
        generator, discriminator = Generator().to(device), Discriminator().to(device)
        for network in (generator, discriminator):
            network.apply(init_weights)

        run_id = train(generator, discriminator, train_loader, num_epochs=num_epochs, device=device)

        if not run_id:
            print("Training failed!")
        else:
            print(f"Training completed successfully. Run ID: {run_id}")
            # Persist the run ID to a file in the working directory.
            with open("latest_run_id.txt", "w") as run_id_file:
                run_id_file.write(run_id)
    except Exception as err:
        print(f"Critical error in main execution: {str(err)}")