# oliverlevn — commit bb029d9 ("remove resize")
import os
import random
import tempfile
import time

import gradio
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image as Image
import torch
import torchvision.transforms as T

from model import create_fasterrcnn_model
# Label table for the detector. Entry 0 is the "creatures" supercategory;
# every other entry is one creature type. Each entry's "id" equals its list
# position, which is what lets `predict` look names up via categories[label].
_CREATURE_NAMES = (
    "fish",
    "jellyfish",
    "penguin",
    "puffin",
    "shark",
    "starfish",
    "stingray",
)
categories = [{"id": 0, "name": "creatures", "supercategory": "none"}]
categories.extend(
    {"id": idx, "name": creature, "supercategory": "creatures"}
    for idx, creature in enumerate(_CREATURE_NAMES, start=1)
)
# 1. Create the title and description strings shown in the Gradio UI.
title = "Ocean creatures detection Faster-R-CNN"
description = "A Faster-RCNN-ResNet-50 backbone feature extractor computer vision model to classify images of fish, penguin, shark, etc"

# Build the detector with one output class per entry in `categories` (8,
# including id 0 "creatures"), so the head size cannot drift from the label
# table that predictions are looked up in.
faster_rcnn = create_fasterrcnn_model(
    num_classes=len(categories),
)

# Load the saved weights onto the CPU so the app runs without a GPU.
faster_rcnn.load_state_dict(
    torch.load(
        f="./third_train.pth",
        map_location=torch.device("cpu"),
    )
)
# This app only runs inference, so switch to eval mode once at load time.
faster_rcnn.eval()
import random
# Create predict function
def predict(img):
    """Detect ocean creatures in ``img`` and render the detections.

    Args:
        img: Image array from the ``gradio.Image(type="numpy")`` input. It is
            converted with ``T.ToPILImage()``, so presumably H x W x C uint8
            (TODO confirm for grayscale / RGBA uploads).

    Returns:
        ``(img_path, pred_time)``: the path of a PNG showing ``img`` with every
        detection scoring above 0.5 boxed and labelled, and the elapsed
        wall-clock seconds, rounded to 5 decimals, as a string.
    """
    # perf_counter is the monotonic clock intended for measuring durations.
    start_time = time.perf_counter()
    device = "cpu"
    transform = T.Compose([T.ToPILImage(), T.ToTensor()])
    # The detector takes a batch of images; add a leading batch dimension of 1.
    image_tensor = transform(img).to(device).unsqueeze(0)

    faster_rcnn.eval()
    with torch.no_grad():
        prediction = faster_rcnn(image_tensor)[0]  # the only image in the batch
    pred_boxes = prediction["boxes"].cpu().numpy()
    pred_scores = prediction["scores"].cpu().numpy()
    pred_labels = prediction["labels"].cpu().numpy()
    # Label ids index directly into `categories`, whose "id" equals its position.
    label_names = [categories[label]["name"] for label in pred_labels]

    fig, ax = plt.subplots(1)
    try:
        ax.imshow(img)
        for box, score, label_name in zip(pred_boxes, pred_scores, label_names):
            if score > 0.5:
                x1, y1, x2, y2 = box
                rect = plt.Rectangle(
                    (x1, y1), x2 - x1, y2 - y1,
                    fill=False, edgecolor="red", linewidth=2,
                )
                ax.add_patch(rect)
                ax.text(
                    x1, y1, f"{label_name}: {score:.2f}",
                    fontsize=5, color="white",
                    bbox=dict(facecolor="red", alpha=0.2),
                )
        # A unique temp file per request, so concurrent requests can never
        # overwrite each other's rendered output.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            fig.savefig(tmp, format="png")
        img_path = tmp.name
    finally:
        # pyplot keeps every figure alive until it is closed; without this
        # each request would leak one figure.
        plt.close(fig)

    pred_time = round(time.perf_counter() - start_time, 5)
    # Return the path to the annotated image and the prediction time.
    return img_path, str(pred_time)
### 4. Gradio app ###
# Collect example images from the "examples/" folder. Sorting makes the order
# stable (os.listdir order is arbitrary); dotfiles and sub-directories are
# skipped so stray entries such as ".gitkeep" are not offered as images.
examples_dir = "examples"
example_list = [
    [os.path.join(examples_dir, name)]
    for name in sorted(os.listdir(examples_dir))
    if not name.startswith(".")
    and os.path.isfile(os.path.join(examples_dir, name))
]

# Create the Gradio demo. Outputs use the same top-level component classes as
# the input; the `gradio.outputs` namespace is deprecated (removed in Gradio 4).
demo = gradio.Interface(
    fn=predict,  # image -> (annotated image path, prediction time)
    inputs=gradio.Image(type="numpy"),
    outputs=[
        gradio.Image(type="filepath", label="Image with Bounding Boxes"),
        gradio.Label(label="Prediction Time"),
    ],
    examples=example_list,
    title=title,
    description=description,
)

# Launch the demo!
demo.launch(debug=True)