### 1. Imports
import gradio as gr
import os
import torch

from PIL import Image
from model import create_model_alexnet
from timeit import default_timer as timer
from typing import Tuple, Dict

### 2. Model and transforms preparation ###

# Create the AlexNet model with 2 output classes, plus the preprocessing
# transforms that create_model_alexnet returns alongside it.
model_alexnet, transforms = create_model_alexnet(num_classes=2)

# Load the saved weights onto the CPU.
# NOTE(review): torch.load unpickles the checkpoint; only load files you trust.
model_alexnet.load_state_dict(
    torch.load(f="cat_dog_classifier.pth", map_location=torch.device("cpu"))
)
# Put the model in evaluation mode once at startup rather than on every request;
# nothing below switches it back to training mode.
model_alexnet.eval()


### 3. Predict function ###

def predict(img: str) -> Tuple[str, float]:
    """Classify an image file as 'Cat' or 'Dog'.

    Args:
        img: Path to the uploaded image, as delivered by
            ``gr.Image(type='filepath')``. ``None`` when nothing was uploaded.

    Returns:
        A ``(label, seconds)`` tuple: the predicted class name ('Cat' or 'Dog')
        and the elapsed prediction time in seconds, rounded to 5 decimals.
        The two values feed the Label and Number outputs of the Gradio app.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")

    # Start the timer
    start_time = timer()

    # Open the file in a context manager so the handle is closed promptly, and
    # force 3 channels: RGBA, grayscale or palette uploads would otherwise
    # yield a tensor whose channel count does not match the model's input.
    with Image.open(img) as pil_image:
        rgb_image = pil_image.convert("RGB")
    batch = transforms(rgb_image).unsqueeze(0)  # add a leading batch dimension

    with torch.inference_mode():
        logits = model_alexnet(batch)
    # Index of the highest-scoring class for the single image in the batch.
    predicted_index = logits.argmax(dim=1).item()

    # NOTE(review): assumes class index 1 is 'Cat' and index 0 is 'Dog'. Confirm
    # against the label order used when training cat_dog_classifier.pth.
    pred_label = "Cat" if predicted_index == 1 else "Dog"

    # Calculate the prediction time
    pred_time = round(timer() - start_time, 5)

    return pred_label, pred_time


### 4. Gradio app ###

# Title and description strings shown on the demo page
title = "Classification Demo"
description = "Cat/Dog classification - Transfer Learning "

# Create the Gradio demo: one image input (passed to predict as a file path)
# and two outputs matching predict's (label, seconds) return value.
# TODO: pass examples=... once an example list is defined.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="filepath"),
    outputs=[
        gr.Label(label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    title=title,
    description=description,
)

# Launch the demo!
demo.launch()