File size: 4,130 Bytes
02f3f24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
import torch
import mlflow
import mlflow.pytorch
from PIL import Image
import numpy as np
from skimage.color import rgb2lab, lab2rgb
from torchvision import transforms
import argparse

from model import Generator

# Name of the MLflow experiment that setup_mlflow() looks up (or creates)
# and under which the inference run in the __main__ section is started.
EXPERIMENT_NAME = "Colorizer_Experiment"

def setup_mlflow():
    """Return the ID of the ``EXPERIMENT_NAME`` MLflow experiment.

    The experiment is looked up by name; if no experiment with that name
    exists yet, it is created and the new ID is returned.
    """
    existing = mlflow.get_experiment_by_name(EXPERIMENT_NAME)
    if existing is not None:
        return existing.experiment_id
    return mlflow.create_experiment(EXPERIMENT_NAME)

def load_model(run_id, device):
    """Load the ``generator_model`` artifact logged under an MLflow run.

    Args:
        run_id: ID of the MLflow run that holds the model artifact.
        device: Passed to the loader as ``map_location``.

    Returns:
        The object returned by ``mlflow.pytorch.load_model``.
    """
    print(f"Loading model from run: {run_id}")
    return mlflow.pytorch.load_model(
        f"runs:/{run_id}/generator_model",
        map_location=device,
    )

# Configuration variables
# NOTE(review): none of these constants are read by the code visible in this
# file -- the __main__ section takes its settings from argparse instead.
# Confirm whether another module imports them before removing.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # CUDA when available, otherwise CPU
RUN_ID = "your_run_id_here"  # Replace with the actual run ID
IMAGE_PATH = "path/to/your/image.jpg"  # Replace with the path to your input image
SAVE_MODEL = False  # presumably toggles save_model(); not checked anywhere in view
SERVE_MODEL = False  # presumably toggles serve_model(); not checked anywhere in view
SERVE_PORT = 5000  # same value as serve_model()'s default port

def preprocess_image(image_path):
    """Read an image and return its normalized lightness channel as a tensor.

    The image is opened as RGB, resized to 256x256, converted to Lab, and
    only the first (L) channel is kept. That channel is shifted and scaled
    as ``(L - 50) / 50``.

    Args:
        image_path: Path of the image file to open.

    Returns:
        A float32 tensor of shape ``(1, 1, 256, 256)``.
    """
    source = Image.open(image_path).convert("RGB")
    pipeline = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])
    rgb_tensor = pipeline(source)

    # Move the channel axis last before handing the array to rgb2lab.
    rgb_array = rgb_tensor.permute(1, 2, 0).numpy()
    lightness = rgb2lab(rgb_array)[:, :, 0]
    lightness = (lightness - 50) / 50

    # Prepend two singleton axes (batch and channel).
    return torch.from_numpy(lightness).unsqueeze(0).unsqueeze(0).float()

def postprocess_output(L, ab):
    """Combine normalized L and ab tensors into an 8-bit RGB image array.

    Both inputs are squeezed, moved to the CPU and converted to NumPy. The
    normalization is undone (``L`` -> ``(L + 1) * 50``, ``ab`` -> ``ab * 128``),
    the channels are stacked into an ``(H, W, 3)`` Lab image and converted
    to RGB.

    Args:
        L: Lightness tensor that squeezes to ``(H, W)``.
        ab: Chroma tensor that squeezes to either ``(2, H, W)``
            (channels-first) or ``(H, W, 2)`` (channels-last).

    Returns:
        A ``uint8`` NumPy array of shape ``(H, W, 3)``.
    """
    # detach() makes the conversion safe even if a tensor still tracks grads.
    L = L.detach().squeeze().cpu().numpy()
    ab = ab.detach().squeeze().cpu().numpy()

    # np.concatenate along axis 2 needs ab as (H, W, 2). A channels-first
    # (2, H, W) array would fail with a dimension mismatch, so reorder it.
    # An array that is already channels-last is left untouched.
    if ab.ndim == 3 and ab.shape[:2] != L.shape and ab.shape[1:] == L.shape:
        ab = np.transpose(ab, (1, 2, 0))

    L = (L + 1.) * 50.
    ab = ab * 128.
    Lab = np.concatenate([L[..., np.newaxis], ab], axis=2)
    rgb_img = lab2rgb(Lab)
    return (rgb_img * 255).astype(np.uint8)

def colorize_image(model, image_path, device):
    """Colorize the image at ``image_path`` with ``model``.

    The image's lightness channel is produced by ``preprocess_image``, fed
    through ``model`` without gradient tracking, and the result is merged
    back with the lightness channel by ``postprocess_output``.

    Args:
        model: Callable mapping the L tensor to the predicted ab tensor.
        image_path: Path of the input image.
        device: Device the input tensor is moved to before the forward pass.

    Returns:
        Whatever ``postprocess_output`` returns for the (L, ab) pair.
    """
    L = preprocess_image(image_path).to(device)

    # Layers such as Dropout/BatchNorm behave differently in training mode,
    # so run the forward pass in eval mode. The previous mode is restored
    # afterwards so the caller's model is not left altered.
    was_training = bool(getattr(model, "training", False))
    if was_training:
        model.eval()
    try:
        with torch.no_grad():
            ab = model(L)
    finally:
        if was_training:
            model.train()

    return postprocess_output(L, ab)

def save_model(model, run_id):
    """Log ``model`` to an existing MLflow run and register it.

    The run identified by ``run_id`` is resumed, the model is logged under
    the artifact path ``model``, and that artifact is registered as
    ``colorizer_model``.

    Args:
        model: PyTorch model to log.
        run_id: ID of the MLflow run to attach the model to.
    """
    logged_model_uri = f"runs:/{run_id}/model"
    with mlflow.start_run(run_id=run_id):
        # Log first: the registration below points at this artifact.
        mlflow.pytorch.log_model(model, "model")
        mlflow.register_model(logged_model_uri, "colorizer_model")
        print(f"Model saved and registered with run_id: {run_id}")

def serve_model(run_id, port=5000):
    """Serve the ``model`` artifact of an MLflow run over HTTP.

    ``mlflow.pytorch`` has no ``serve`` function, so this invokes the
    ``mlflow models serve`` command-line tool instead. The call blocks
    until the server process exits.

    Args:
        run_id: ID of the MLflow run whose ``model`` artifact is served.
        port: TCP port the server listens on.

    Raises:
        FileNotFoundError: If the ``mlflow`` executable is not on ``PATH``.
        subprocess.CalledProcessError: If the server exits with a non-zero
            status.
    """
    # Local import: subprocess is not imported at the top of this file.
    import subprocess

    model_uri = f"runs:/{run_id}/model"
    # Argument list (no shell) so run_id/port are never shell-interpreted.
    subprocess.run(
        ["mlflow", "models", "serve", "--model-uri", model_uri, "--port", str(port)],
        check=True,
    )

if __name__ == "__main__":
    # Command-line entry point: colorize one image with a model loaded from
    # an MLflow run, save the result next to the working directory, and
    # record the inference as its own MLflow run.
    parser = argparse.ArgumentParser(description="Colorizer Inference")
    parser.add_argument("--run_id", type=str, help="MLflow run ID of the trained model")
    parser.add_argument("--image_path", type=str, required=True, help="Path to the input grayscale image")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device to use for inference (cuda/cpu)")
    args = parser.parse_args()

    device = torch.device(args.device)
    print(f"Using device: {device}")

    # If run_id is not provided, try to load it from the file
    if not args.run_id:
        try:
            with open("latest_run_id.txt", "r") as f:
                args.run_id = f.read().strip()
        except FileNotFoundError:
            print("No run ID provided and couldn't find latest_run_id.txt")
            # SystemExit instead of exit(): the latter is injected by the
            # site module and is not guaranteed to exist.
            raise SystemExit(1)
        if not args.run_id:
            # An empty file would otherwise surface later as an obscure
            # model-loading error.
            print("No run ID provided and latest_run_id.txt is empty")
            raise SystemExit(1)

    experiment_id = setup_mlflow()
    with mlflow.start_run(experiment_id=experiment_id, run_name="inference_run"):
        try:
            # Record the inputs first so they are attached to the run even
            # when inference fails.
            mlflow.log_param("input_image", args.image_path)
            mlflow.log_param("model_run_id", args.run_id)

            model = load_model(args.run_id, device)
            colorized = colorize_image(model, args.image_path, device)
            output_path = f"colorized_{os.path.basename(args.image_path)}"
            Image.fromarray(colorized).save(output_path)
            print(f"Colorized image saved as: {output_path}")

            mlflow.log_artifact(output_path)
        except Exception as e:
            print(f"Error during inference: {str(e)}")
            # MLflow caps the length of parameter values (the limit depends
            # on the version); truncate so recording the error cannot fail.
            mlflow.log_param("error", str(e)[:250])
            # Leaving the `with` block through an exception lets the run
            # context manager close the run as failed, and the process
            # exits non-zero instead of reporting success.
            raise SystemExit(1)