import os import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader from torchvision.utils import make_grid import mlflow import matplotlib.pyplot as plt from tqdm import tqdm import numpy as np from skimage.color import lab2rgb, rgb2lab import argparse from itertools import islice from PIL import Image import torchvision.transforms as transforms from data_ingestion import ColorizeIterableDataset, create_dataloaders from model import Generator, Discriminator, init_weights EXPERIMENT_NAME = "Colorizer_Experiment" def setup_mlflow(): experiment = mlflow.get_experiment_by_name(EXPERIMENT_NAME) if experiment is None: experiment_id = mlflow.create_experiment(EXPERIMENT_NAME) else: experiment_id = experiment.experiment_id return experiment_id def lab_to_rgb(L, ab): """Convert L and ab channels to RGB image""" L = (L + 1.) * 50. ab = ab * 128. Lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).cpu().numpy() rgb_imgs = [] for img in Lab: img_rgb = lab2rgb(img) rgb_imgs.append(img_rgb) return np.stack(rgb_imgs, axis=0) def preprocess_image(image_path): img = Image.open(image_path).convert('RGB') img = img.resize((256, 256)) # Resize to a consistent size img_lab = rgb2lab(img) img_lab = (img_lab + [0, 128, 128]) / [100, 255, 255] # Normalize LAB values return img_lab[:,:,0], img_lab[:,:,1:] def visualize_results(epoch, generator, train_loader, device): generator.eval() with torch.no_grad(): for inputs, real_AB in train_loader: inputs, real_AB = inputs.to(device), real_AB.to(device) fake_AB = generator(inputs) fake_rgb = lab_to_rgb(inputs.cpu(), fake_AB.cpu()) real_rgb = lab_to_rgb(inputs.cpu(), real_AB.cpu()) img_grid = make_grid(torch.from_numpy(np.concatenate([real_rgb, fake_rgb], axis=3)).permute(0, 3, 1, 2), normalize=True, nrow=4) plt.figure(figsize=(15, 15)) plt.imshow(img_grid.permute(1, 2, 0).cpu()) plt.axis('off') plt.title(f'Epoch {epoch}') plt.savefig(f'results/epoch_{epoch}.png') mlflow.log_artifact(f'results/epoch_{epoch}.png') plt.close() break # Only visualize one batch generator.train() def save_checkpoint(state, filename="checkpoint.pth.tar"): torch.save(state, filename) mlflow.log_artifact(filename) def load_checkpoint(filename, generator, discriminator, optimizerG, optimizerD): if os.path.isfile(filename): print(f"Loading checkpoint '{filename}'") checkpoint = torch.load(filename) start_epoch = checkpoint['epoch'] + 1 generator.load_state_dict(checkpoint['generator_state_dict']) discriminator.load_state_dict(checkpoint['discriminator_state_dict']) optimizerG.load_state_dict(checkpoint['optimizerG_state_dict']) optimizerD.load_state_dict(checkpoint['optimizerD_state_dict']) print(f"Loaded checkpoint '{filename}' (epoch {checkpoint['epoch']})") return start_epoch else: print(f"No checkpoint found at '{filename}'") return 0 def train(generator, discriminator, train_loader, num_epochs, device, lr=0.0002, beta1=0.5): criterion = nn.BCEWithLogitsLoss() l1_loss = nn.L1Loss() optimizerG = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999)) optimizerD = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999)) checkpoint_dir = "checkpoints" os.makedirs(checkpoint_dir, exist_ok=True) os.makedirs("results", exist_ok=True) checkpoint_path = os.path.join(checkpoint_dir, "latest_checkpoint.pth.tar") start_epoch = load_checkpoint(checkpoint_path, generator, discriminator, optimizerG, optimizerD) experiment_id = setup_mlflow() with mlflow.start_run(experiment_id=experiment_id, run_name="training_run") as run: try: for epoch in range(start_epoch, num_epochs): generator.train() discriminator.train() # Use a fixed number of iterations per epoch num_iterations = 1000 pbar = tqdm(enumerate(islice(train_loader, num_iterations)), total=num_iterations, desc=f"Epoch {epoch+1}/{num_epochs}") for i, (real_L, real_AB) in pbar: real_L, real_AB = real_L.to(device), real_AB.to(device) batch_size = real_L.size(0) # Train Discriminator optimizerD.zero_grad() fake_AB = generator(real_L) fake_LAB = torch.cat([real_L, fake_AB], dim=1) real_LAB = torch.cat([real_L, real_AB], dim=1) pred_fake = discriminator(fake_LAB.detach()) loss_D_fake = criterion(pred_fake, torch.zeros_like(pred_fake)) pred_real = discriminator(real_LAB) loss_D_real = criterion(pred_real, torch.ones_like(pred_real)) loss_D = (loss_D_fake + loss_D_real) * 0.5 loss_D.backward() optimizerD.step() # Train Generator optimizerG.zero_grad() fake_AB = generator(real_L) fake_LAB = torch.cat([real_L, fake_AB], dim=1) pred_fake = discriminator(fake_LAB) loss_G_GAN = criterion(pred_fake, torch.ones_like(pred_fake)) loss_G_L1 = l1_loss(fake_AB, real_AB) * 100 # L1 loss weight loss_G = loss_G_GAN + loss_G_L1 loss_G.backward() optimizerG.step() pbar.set_postfix({ 'D_loss': loss_D.item(), 'G_loss': loss_G.item(), 'G_L1': loss_G_L1.item() }) mlflow.log_metrics({ "D_loss": loss_D.item(), "G_loss": loss_G.item(), "G_L1_loss": loss_G_L1.item() }, step=epoch * num_iterations + i) visualize_results(epoch, generator, train_loader, device) checkpoint = { 'epoch': epoch, 'generator_state_dict': generator.state_dict(), 'discriminator_state_dict': discriminator.state_dict(), 'optimizerG_state_dict': optimizerG.state_dict(), 'optimizerD_state_dict': optimizerD.state_dict(), } save_checkpoint(checkpoint, filename=checkpoint_path) print("Training completed successfully.") # Log the generator model mlflow.pytorch.log_model(generator, "generator_model") # Register the model model_uri = f"runs:/{run.info.run_id}/generator_model" mlflow.register_model(model_uri, "colorizer_generator") return run.info.run_id except Exception as e: print(f"Error during training: {str(e)}") mlflow.log_param("error", str(e)) return None def test_training(generator, discriminator, train_loader, device): print("Testing training process...") try: train(generator, discriminator, train_loader, num_epochs=1, device=device) print("Training process test passed.") return True except Exception as e: print(f"Training process test failed: {str(e)}") return False if __name__ == "__main__": parser = argparse.ArgumentParser(description="Train Colorizer model") parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use for training (cuda/cpu)") parser.add_argument("--batch_size", type=int, default=32, help="Batch size for training") parser.add_argument("--num_epochs", type=int, default=50, help="Number of epochs to train") parser.add_argument("--test", action="store_true", help="Run in test mode") args = parser.parse_args() device = torch.device(args.device) print(f"Using device: {device}") try: train_loader = create_dataloaders(batch_size=args.batch_size) generator = Generator().to(device) discriminator = Discriminator().to(device) generator.apply(init_weights) discriminator.apply(init_weights) if args.test: if test_training(generator, discriminator, train_loader, device): print("All tests passed.") else: print("Tests failed.") else: run_id = train(generator, discriminator, train_loader, num_epochs=args.num_epochs, device=device) if run_id: print(f"Training completed. Run ID: {run_id}") # Save the run ID to a file for easy access by the inference script with open("latest_run_id.txt", "w") as f: f.write(run_id) else: print("Training failed.") except Exception as e: print(f"Critical error in main execution: {str(e)}")