Munzali's picture
Update app.py
07f0645 verified
### 1. Imports and class names setup ###
import gradio as gr
import os
import torch
from model import create_mobilenet_model
from timeit import default_timer as timer
from typing import Tuple, Dict
# Setup class names.
# predict() pairs model output i with class_names[i], so this order must match
# the label order used when the checkpoint was trained.
class_names = ['bacterial', 'blast', 'brownspot', 'tungro']

### 2. Model and transforms preparation ###
# Build the architecture with one output per class. Deriving the size from
# class_names keeps the classifier head and the label list from drifting apart.
mobilenet, manual_transforms = create_mobilenet_model(
    num_classes=len(class_names)
)

# Load the trained weights onto the CPU. The file is consumed directly by
# load_state_dict, so it only needs tensors: weights_only=True avoids
# unpickling arbitrary objects (and is the default from PyTorch 2.6 onward).
mobilenet.load_state_dict(
    torch.load(
        f="mobilenet_5_epochs.pth",
        map_location=torch.device("cpu"),
        weights_only=True,
    )
)
### 3. Predict function ###
def predict(img) -> Tuple[Dict, float]:
    """Classify a rice-leaf image and time the inference.

    Args:
        img: A PIL image, as delivered by the ``gr.Image(type="pil")`` input.
            It is converted to RGB before preprocessing, so RGBA, grayscale
            and palette images are accepted too.

    Returns:
        A tuple ``(pred_labels_and_probs, pred_time)``. The first item maps
        every class name to its softmax probability. The second is the
        wall-clock prediction time in seconds, rounded to 5 decimals.
    """
    start_time = timer()

    # NOTE(review): manual_transforms presumably normalizes with 3-channel
    # (RGB) statistics -- confirm in model.py. Uploaded PNGs are often RGBA
    # and some photos are grayscale, which would otherwise yield a tensor
    # with the wrong channel count and make the forward pass fail.
    img = img.convert("RGB")

    # Preprocess, then add a leading batch dimension of size 1.
    batch = manual_transforms(img).unsqueeze(0)

    mobilenet.eval()
    with torch.inference_mode():
        pred_probs = torch.softmax(mobilenet(batch), dim=1)

    # Pair each class name with the probability at the same output index.
    pred_labels_and_probs = {
        name: float(prob)
        for name, prob in zip(class_names, pred_probs[0].tolist())
    }

    pred_time = round(timer() - start_time, 5)
    return pred_labels_and_probs, pred_time
### 4. Gradio app ###
# Folder holding example images, and the file types offered as examples.
# Anything else in the folder (e.g. .gitkeep, README files) is ignored.
EXAMPLE_DIR = "examples"
EXAMPLE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".webp")

# Build the examples list in a stable order (os.listdir order is arbitrary)
# and tolerate a missing folder instead of crashing at startup.
if os.path.isdir(EXAMPLE_DIR):
    example_list = [
        [os.path.join(EXAMPLE_DIR, file_name)]
        for file_name in sorted(os.listdir(EXAMPLE_DIR))
        if file_name.lower().endswith(EXAMPLE_EXTENSIONS)
    ]
else:
    example_list = []

# Create a Blocks app (only one!)
with gr.Blocks() as gradio_app:
    gr.HTML(
        """
<h1 style='text-align: center'>
Rice Disease Classification - MobileNet Model
</h1>
"""
    )
    gr.HTML(
        """
<h3 style='text-align: center'>
Follow me for more!
<!-- <a href='https://twitter.com/kadirnar_ai' target='_blank'>Twitter</a> | -->
<a href='https://github.com/ExplorerGumel' target='_blank'>Github</a> |
<a href='https://www.linkedin.com/in/munzali-alhassan/' target='_blank'>Linkedin</a> |
<!-- <a href='https://www.huggingface.co/kadirnar/' target='_blank'>HuggingFace</a> -->
</h3>
"""
    )
    with gr.Row():
        with gr.Column():
            image = gr.Image(type="pil", label="Upload Image")
            infer = gr.Button(value="Predict")
            # Examples linked to the input component 'image'; only rendered
            # when there is at least one example file to show.
            if example_list:
                gr.Examples(
                    examples=example_list,
                    inputs=[image],  # Pass the actual input component
                )
        with gr.Column():
            # Show every class: the label count follows class_names.
            label = gr.Label(num_top_classes=len(class_names), label="Predictions")
            pred_time = gr.Number(label="Prediction Time (s)")
    # Event listeners must be registered inside the Blocks context.
    infer.click(
        fn=predict,
        inputs=[image],
        outputs=[label, pred_time],
    )

# Launch the app
gradio_app.launch(debug=True, share=True)
# gradio_app.launch(debug=True, share=True)
# # Create title, description and article strings
# title = "RICE DISEASES CLASSIFICATION"
# description = "A MobileNetV2 feature extractor computer vision model to classify images of Rice diseases."
# article = "Created by Munzali Alhassan."
# # Create examples list from "examples/" directory
# example_list = [["examples/" + example] for example in os.listdir("examples")]
# # Create the Gradio demo
# demo = gr.Interface(fn=predict, # mapping function from input to output
# inputs=gr.Image(type="pil"), # what are the inputs?
# outputs=[gr.Label(num_top_classes=4, label="Predictions"), # what are the outputs?
# gr.Number(label="Prediction time (s)")], # our fn has two outputs, therefore we have two outputs
# # Create examples list from "examples/" directory
# examples=example_list,
# title=title,
# description=description,
# article=article)
# # Launch the demo!
# demo.launch(share=True)