|
import argparse |
|
import os |
|
import torch |
|
import mlflow |
|
from data_ingestion import create_dataloaders, test_data_ingestion |
|
from model import Generator, Discriminator, init_weights, test_models |
|
from train import train, test_training |
|
from app import setup_gradio_app |
|
|
|
# Name of the MLflow experiment that setup_mlflow() looks up or creates.
EXPERIMENT_NAME = "Colorizer_Experiment"


def setup_mlflow():
    """Return the id of the ``EXPERIMENT_NAME`` MLflow experiment.

    The experiment is looked up by name; when the lookup yields ``None`` a
    new experiment with that name is created and its id is returned instead.
    """
    existing = mlflow.get_experiment_by_name(EXPERIMENT_NAME)
    if existing is not None:
        return existing.experiment_id
    return mlflow.create_experiment(EXPERIMENT_NAME)
|
|
|
# File used to hand the most recent training run id over to the serving step.
_LATEST_RUN_ID_FILE = "latest_run_id.txt"


def _build_models(device):
    """Create the generator and discriminator on *device* and apply init_weights."""
    generator = Generator().to(device)
    discriminator = Discriminator().to(device)
    generator.apply(init_weights)
    discriminator.apply(init_weights)
    return generator, discriminator


def _ensure_dataloader(train_loader, batch_size, purpose):
    """Return *train_loader*, creating one when it is None (None if creation fails)."""
    if train_loader is not None:
        return train_loader
    print(f"Creating dataloader for {purpose}...")
    train_loader = create_dataloaders(batch_size=batch_size)
    if train_loader is None:
        print(f"Failed to create dataloader for {purpose}.")
    return train_loader


def run_pipeline(args):
    """Run the pipeline steps selected on the command line.

    Steps run in a fixed order: data ingestion, model creation, training,
    training self-test, serving. A failing step prints a message and stops
    the pipeline by returning early.

    Args:
        args: Parsed command-line namespace. Reads ``device``, ``batch_size``,
            ``num_epochs``, ``run_id``, ``share`` and the boolean step flags
            ``ingest_data``, ``create_model``, ``train``, ``test_training``,
            ``serve`` and ``run_all``. ``args.run_id`` is filled in from
            ``latest_run_id.txt`` when serving without an explicit run id.

    Returns:
        None.
    """
    device = torch.device(args.device)
    print(f"Using device: {device}")

    # Called so the experiment exists before any step runs; the returned id
    # is not needed by this function.
    setup_mlflow()

    # --- Data ingestion -------------------------------------------------
    train_loader = None
    if args.ingest_data or args.run_all:
        print("Starting data ingestion...")
        train_loader = create_dataloaders(batch_size=args.batch_size)
        if train_loader is None:
            print("Data ingestion failed.")
            return

    # --- Model creation -------------------------------------------------
    generator = None
    discriminator = None
    if args.create_model or args.train or args.run_all:
        print("Creating and testing models...")
        generator, discriminator = _build_models(device)
        if not test_models():
            print("Model creation or testing failed.")
            return

    # --- Training -------------------------------------------------------
    if args.train or args.run_all:
        print("Starting model training...")
        train_loader = _ensure_dataloader(train_loader, args.batch_size, "training")
        if train_loader is None:
            return
        # The models always exist here: the creation step above runs
        # whenever --train or --run_all is set, so no fallback is needed.
        run_id = train(
            generator,
            discriminator,
            train_loader,
            num_epochs=args.num_epochs,
            device=device,
        )
        if not run_id:
            print("Training failed.")
            return
        print(f"Training completed. Run ID: {run_id}")
        with open(_LATEST_RUN_ID_FILE, "w", encoding="utf-8") as f:
            f.write(run_id)

    # --- Training self-test ---------------------------------------------
    if args.test_training:
        print("Testing training process...")
        train_loader = _ensure_dataloader(train_loader, args.batch_size, "testing")
        if train_loader is None:
            return
        if generator is None or discriminator is None:
            # Reached when --test_training is used without a step that
            # already built the models.
            print("Creating models for testing...")
            generator, discriminator = _build_models(device)
        if test_training(generator, discriminator, train_loader, device):
            print("Training process test passed.")
        else:
            print("Training process test failed.")

    # --- Serving --------------------------------------------------------
    if args.serve or args.run_all:
        print("Setting up Gradio app for serving...")
        if not args.run_id:
            try:
                with open(_LATEST_RUN_ID_FILE, "r", encoding="utf-8") as f:
                    args.run_id = f.read().strip()
            except FileNotFoundError:
                print("No run ID provided and couldn't find latest_run_id.txt")
                return
            if not args.run_id:
                # An empty file would otherwise hand "" to the app as a run id.
                print("No run ID provided and latest_run_id.txt is empty")
                return
        iface = setup_gradio_app(args.run_id, device)
        iface.launch(share=args.share)
|
|
|
if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, and hand the
    # resulting namespace to run_pipeline().
    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    cli_parser = argparse.ArgumentParser(description="Run Colorizer Pipeline")

    # Options that take a value.
    cli_parser.add_argument(
        "--device",
        type=str,
        default=default_device,
        help="Device to use (cuda/cpu)",
    )
    cli_parser.add_argument(
        "--batch_size", type=int, default=32, help="Batch size for training"
    )
    cli_parser.add_argument(
        "--num_epochs", type=int, default=50, help="Number of epochs to train"
    )
    cli_parser.add_argument(
        "--run_id",
        type=str,
        help="MLflow run ID of the trained model for inference",
    )

    # Boolean switches, registered in the order they appear in --help.
    switches = (
        ("--ingest_data", "Run data ingestion"),
        ("--create_model", "Create and test the model"),
        ("--train", "Train the model"),
        ("--test_training", "Test the training process"),
        ("--serve", "Serve the model using Gradio"),
        ("--run_all", "Run all steps"),
        ("--share", "Share the Gradio app publicly"),
    )
    for switch, help_text in switches:
        cli_parser.add_argument(switch, action="store_true", help=help_text)

    run_pipeline(cli_parser.parse_args())