"""Gradio web app that classifies an uploaded image as a cat or a dog."""

from pathlib import Path

import gradio as gr
import rootutils
import torch
import torch.nn.functional as F
from loguru import logger
from PIL import Image
from torchvision import transforms

from src.models.catdog_model_resnet import ResnetClassifier
from src.utils.aws_s3_services import S3Handler
from src.utils.logging_utils import setup_logger

# Configure logging to a file under ./logs
setup_logger(Path("./logs") / "gradio_app.log")

# Setup root directory
root = rootutils.setup_root(__file__, indicator=".project-root")


class ImageClassifier:
    """Wrap a ``ResnetClassifier`` checkpoint for single-image inference.

    On construction the ``checkpoints`` folder is downloaded from S3, the
    checkpoint at ``cfg.ckpt_path`` is loaded, and the model is moved to the
    GPU when one is available (CPU otherwise) and put in eval mode.
    """

    def __init__(self, cfg):
        """Download the checkpoints, load the model and build the transform.

        Args:
            cfg: Configuration object. The attributes read here are
                ``labels`` (index -> class name), ``ckpt_path`` and
                ``data.image_size``.
        """
        self.cfg = cfg
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.classes = cfg.labels

        # Download and load model from S3
        logger.info("Downloading model from S3...")
        s3_handler = S3Handler(bucket_name="deep-bucket-s3")
        s3_handler.download_folder("checkpoints", "checkpoints")

        logger.info("Loading model checkpoint...")
        # NOTE(review): if the checkpoint was saved from a GPU run and this
        # process is CPU-only, loading may need an explicit map_location --
        # confirm against ResnetClassifier.load_from_checkpoint.
        self.model = ResnetClassifier.load_from_checkpoint(
            checkpoint_path=cfg.ckpt_path
        )
        self.model = self.model.to(self.device)
        self.model.eval()

        # Image transform: square resize, tensor conversion, then
        # per-channel normalization with three-channel mean/std.
        self.transform = transforms.Compose(
            [
                transforms.Resize((cfg.data.image_size, cfg.data.image_size)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def predict(self, image):
        """Classify a single image.

        Args:
            image: PIL image to classify, or ``None`` when nothing was
                uploaded.

        Returns:
            ``(label, confidence)`` where ``confidence`` is the softmax
            probability of the predicted class. When ``image`` is ``None``
            the result is ``("No image provided.", None)``.
        """
        if image is None:
            return "No image provided.", None

        # Normalize uses three-channel statistics, so grayscale / RGBA / palette
        # inputs are converted up front instead of failing inside the transform.
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Preprocess the image
        logger.info("Processing input image...")
        img_tensor = self.transform(image).unsqueeze(0).to(self.device)

        # Inference
        with torch.no_grad():
            output = self.model(img_tensor)
            probabilities = F.softmax(output, dim=1)
            predicted_class_idx = torch.argmax(probabilities, dim=1).item()
            confidence = probabilities[0][predicted_class_idx].item()

        predicted_label = self.classes[predicted_class_idx]
        logger.info(f"Prediction: {predicted_label} (Confidence: {confidence:.2f})")
        return predicted_label, confidence


def create_gradio_app(cfg):
    """Build the Gradio Blocks UI around an ``ImageClassifier``.

    Args:
        cfg: Configuration object forwarded to ``ImageClassifier``.

    Returns:
        The ``gr.Blocks`` demo, not yet launched.
    """
    classifier = ImageClassifier(cfg)

    def classify_image(image):
        """Gradio interface function."""
        predicted_label, confidence = classifier.predict(image)
        if not predicted_label:
            return "Error during prediction."
        if confidence is None:
            # predict() returned a message instead of a prediction (e.g. no
            # image uploaded); formatting None with ':.2f' would raise
            # TypeError, so show the message as-is.
            return predicted_label
        return f"Predicted: {predicted_label} (Confidence: {confidence:.2f})"

    # Create Gradio interface
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # Cat vs Dog Classifier
            Upload an image of a cat or a dog to classify it with confidence.
            """
        )

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    label="Input Image", type="pil", image_mode="RGB"
                )
                predict_button = gr.Button("Classify")
            with gr.Column():
                output_text = gr.Textbox(label="Prediction")

        # Define interaction
        predict_button.click(
            fn=classify_image, inputs=[input_image], outputs=[output_text]
        )

    return demo


# Hydra config wrapper for launching Gradio app
if __name__ == "__main__":
    import hydra
    from omegaconf import DictConfig

    @hydra.main(config_path="configs", config_name="infer", version_base="1.3")
    def main(cfg: DictConfig):
        """Build the app from the Hydra config and serve it on port 7860."""
        logger.info("Launching Gradio App...")
        demo = create_gradio_app(cfg)
        demo.launch(share=True, server_name="0.0.0.0", server_port=7860)

    main()