# demo_B21_AIML / app.py
# Ramendra's picture
# Update app.py
# a4ed31e
### 1. Imports
import gradio as gr
import os
import torch
from PIL import Image
from model import create_model_alexnet
from timeit import default_timer as timer
from typing import Tuple, Dict
### 2. Model and transforms preparation ###
# Build the 2-class AlexNet-based classifier together with the preprocessing
# transforms returned by the project's `create_model_alexnet` helper.
model_alexnet, transforms = create_model_alexnet(num_classes=2)

# Load the saved weights onto the CPU. The checkpoint is consumed as a plain
# state dict (see `load_state_dict`), so `weights_only=True` is sufficient and
# prevents torch.load from unpickling arbitrary Python objects.
model_alexnet.load_state_dict(
    torch.load(
        f="cat_dog_classifier.pth",
        map_location=torch.device("cpu"),
        weights_only=True,
    )
)
# This app only runs inference: switch the model to evaluation mode up front.
model_alexnet.eval()
### 3. Predict function ###
# Create predict function
def predict(img):
    """Classify one image as cat or dog and time the inference.

    Args:
        img: Path to the image file. The Gradio input component is
            ``gr.Image(type='filepath')``, so Gradio passes a file path
            (or ``None`` when the user submits without an image).

    Returns:
        A ``(label, seconds)`` tuple: the predicted class name (``'Cat'`` or
        ``'Dog'``) for the ``gr.Label`` output, and the wall-clock prediction
        time rounded to 5 decimals for the ``gr.Number`` output.

    Raises:
        ValueError: If ``img`` is ``None`` (no image was provided).
    """
    if img is None:
        # Fail with a clear message instead of an AttributeError from PIL.
        raise ValueError("No image provided.")

    # Start the timer
    start_time = timer()
    model_alexnet.eval()

    # Open inside a context manager so the file handle is released, and force
    # 3-channel RGB: RGBA / palette / grayscale uploads would otherwise yield a
    # tensor with the wrong number of channels for the model.
    with Image.open(img) as pil_image:
        rgb_image = pil_image.convert("RGB")

    # Apply the model's preprocessing and add a leading batch dimension of 1.
    batch = transforms(rgb_image).unsqueeze(0)

    with torch.inference_mode():
        logits = model_alexnet(batch)
    # Index of the highest-scoring class for the single image in the batch.
    predicted = logits.argmax(dim=1)

    # NOTE(review): confirm that index 1 really is 'Cat' in the class order used
    # at training time (e.g. torchvision's ImageFolder sorts class folders
    # alphabetically, which would make 'cat' index 0).
    pred_label = 'Cat' if predicted.item() == 1 else 'Dog'

    # Calculate the prediction time
    pred_time = round(timer() - start_time, 5)

    # Return the label string and the prediction time
    return pred_label, pred_time
### 4. Gradio app ###
# Title and description shown above the demo
title = "Classification Demo"
description = "Cat/Dog classification - Transfer Learning "

# Create the Gradio demo: an image (handed to `predict` as a file path) goes
# in; `predict` returns two values, so there are two output components —
# the predicted label and the prediction time.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='filepath'),
    outputs=[
        gr.Label(label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    title=title,
    description=description,
)

# Launch the demo!
demo.launch()