File size: 3,769 Bytes
ea271d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import gradio as gr
import torch
import torch.nn.functional as F
from PIL import Image
from pathlib import Path
from torchvision import transforms
from src.models.catdog_model_resnet import ResnetClassifier
from src.utils.aws_s3_services import S3Handler
from src.utils.logging_utils import setup_logger
from loguru import logger
import rootutils

# Configure logging first so everything below is captured; the log file path
# (./logs/gradio_app.log, relative to the current working directory) is handed
# to the project's setup_logger helper.
setup_logger(Path("./logs") / "gradio_app.log")
# Locate the project root by searching from this file for the ".project-root"
# indicator file and keep the result in `root`.
# NOTE(review): any environment-variable loading presumably happens inside
# rootutils.setup_root (its dotenv handling) rather than in setup_logger —
# confirm. The `src.*` imports above run before this call, so they do not
# depend on it for import resolution.
root = rootutils.setup_root(__file__, indicator=".project-root")


class ImageClassifier:
    """Single-image classifier backed by a ``ResnetClassifier`` checkpoint.

    Construction downloads the ``checkpoints`` folder from S3, loads the
    checkpoint named by ``cfg.ckpt_path``, moves the model to CUDA when
    available (CPU otherwise) and puts it in eval mode.
    """

    def __init__(self, cfg):
        """Download the checkpoints, load the model and build the transform.

        Args:
            cfg: Configuration object. The attributes read here are
                ``labels`` (indexed by the predicted class id),
                ``ckpt_path`` (passed to ``load_from_checkpoint``) and
                ``data.image_size`` (target height and width for resizing).
        """
        self.cfg = cfg
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.classes = cfg.labels

        # Download and load model from S3
        logger.info("Downloading model from S3...")
        s3_handler = S3Handler(bucket_name="deep-bucket-s3")
        s3_handler.download_folder("checkpoints", "checkpoints")

        logger.info("Loading model checkpoint...")
        self.model = ResnetClassifier.load_from_checkpoint(
            checkpoint_path=cfg.ckpt_path
        )
        self.model = self.model.to(self.device)
        # Inference only: switch the model to evaluation mode.
        self.model.eval()

        # Image transform: square resize, tensor conversion, then per-channel
        # normalisation with three-channel statistics.
        # NOTE(review): these are the widely used ImageNet mean/std values;
        # presumably they match the training-time preprocessing — confirm.
        self.transform = transforms.Compose(
            [
                transforms.Resize((cfg.data.image_size, cfg.data.image_size)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def predict(self, image):
        """Classify one image and return its label with the confidence.

        Args:
            image: Image to classify, or ``None``. A PIL image whose mode is
                not ``"RGB"`` (grayscale, RGBA, palette, ...) is converted to
                RGB before preprocessing.

        Returns:
            A ``(label, confidence)`` tuple, where ``label`` is taken from
            ``self.classes`` and ``confidence`` is the softmax probability of
            the predicted class as a float. When ``image`` is ``None`` the
            tuple is ``("No image provided.", None)``.
        """
        if image is None:
            return "No image provided.", None

        # The normalisation step is defined for exactly three channels, so a
        # non-RGB PIL image would otherwise fail inside the transform.
        if isinstance(image, Image.Image) and image.mode != "RGB":
            image = image.convert("RGB")

        # Preprocess the image and add a leading batch dimension of size 1.
        logger.info("Processing input image...")
        img_tensor = self.transform(image).unsqueeze(0).to(self.device)

        # Inference without gradient tracking.
        with torch.no_grad():
            output = self.model(img_tensor)
            probabilities = F.softmax(output, dim=1)
            predicted_class_idx = torch.argmax(probabilities, dim=1).item()
            confidence = probabilities[0][predicted_class_idx].item()

        predicted_label = self.classes[predicted_class_idx]
        logger.info(f"Prediction: {predicted_label} (Confidence: {confidence:.2f})")
        return predicted_label, confidence


def create_gradio_app(cfg):
    """Build the Gradio Blocks UI for the cat/dog classifier.

    An ``ImageClassifier`` is created once from ``cfg`` and shared by every
    button click.

    Args:
        cfg: Configuration object forwarded unchanged to ``ImageClassifier``.

    Returns:
        The assembled ``gr.Blocks`` demo; the caller is responsible for
        launching it.
    """
    classifier = ImageClassifier(cfg)

    def classify_image(image):
        """Gradio interface function: return the text shown in the textbox."""
        predicted_label, confidence = classifier.predict(image)
        if not predicted_label:
            return "Error during prediction."
        if confidence is None:
            # predict() reports problems such as a missing image as a message
            # with no confidence; show that message instead of formatting
            # None with ":.2f", which would raise TypeError.
            return predicted_label
        return f"Predicted: {predicted_label} (Confidence: {confidence:.2f})"

    # Create Gradio interface
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # Cat vs Dog Classifier
            Upload an image of a cat or a dog to classify it with confidence.
            """
        )

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    label="Input Image", type="pil", image_mode="RGB"
                )
                predict_button = gr.Button("Classify")
            with gr.Column():
                output_text = gr.Textbox(label="Prediction")

        # Define interaction: clicking the button runs the classifier on the
        # uploaded image and writes the result into the textbox.
        predict_button.click(
            fn=classify_image, inputs=[input_image], outputs=[output_text]
        )

    return demo


# Script entry point: Hydra composes the config and hands it to the app.
if __name__ == "__main__":
    import hydra
    from omegaconf import DictConfig

    @hydra.main(version_base="1.3", config_path="configs", config_name="infer")
    def main(cfg: DictConfig):
        """Build the Gradio demo from the Hydra config and serve it."""
        logger.info("Launching Gradio App...")
        # Listen on all interfaces at port 7860 and request a share link.
        launch_options = {
            "share": True,
            "server_name": "0.0.0.0",
            "server_port": 7860,
        }
        create_gradio_app(cfg).launch(**launch_options)

    main()