Spaces:
Build error
Build error
import gradio | |
import torch | |
import torchvision.transforms as T | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import os | |
import random | |
import PIL.Image as Image | |
import time | |
from model import create_fasterrcnn_model | |
# Category table (id, name, supercategory) used to turn the detector's
# integer label ids into display names. Entry 0 is the "creatures"
# supercategory itself; ids 1-7 are the individual creature classes, each
# listing "creatures" as its supercategory.
_CREATURE_CLASSES = (
    "fish",
    "jellyfish",
    "penguin",
    "puffin",
    "shark",
    "starfish",
    "stingray",
)
categories = [{"id": 0, "name": "creatures", "supercategory": "none"}]
categories.extend(
    {"id": class_id, "name": class_name, "supercategory": "creatures"}
    for class_id, class_name in enumerate(_CREATURE_CLASSES, start=1)
)
# 1. Title and description text shown on the Gradio page.
title = "Ocean creatures detection Faster-R-CNN"
description = (
    "A Faster-RCNN-ResNet-50 backbone feature extractor computer vision model "
    "to classify images of fish, penguin, shark, etc"
)
# Build the detector with one output class per entry in `categories`
# (8 entries: id 0 "creatures" plus the seven creature classes), so the class
# count cannot drift from the label table used to name predictions.
faster_rcnn = create_fasterrcnn_model(
    num_classes=len(categories),
)
# Load the trained weights from ./third_train.pth onto the CPU.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
# On torch >= 1.13 consider passing weights_only=True for a plain state_dict.
faster_rcnn.load_state_dict(
    torch.load(
        f="./third_train.pth",
        map_location=torch.device("cpu"),  # load to CPU
    )
)
# Switch to inference mode once, right after loading.
faster_rcnn.eval()
import random | |
# Create predict function
def predict(img):
    """Run the Faster R-CNN detector on ``img`` and render its detections.

    Args:
        img: The input image as delivered by the ``gradio.Image(type="numpy")``
            input. Presumably an ``(H, W, 3)`` uint8 RGB array (TODO confirm).
            It is converted with ``ToPILImage`` + ``ToTensor`` and batched.

    Returns:
        A ``(img_path, pred_time)`` tuple:
        - ``img_path`` is the path of a freshly created PNG showing ``img``
          with every detection scoring above 0.5 drawn as a red box and a
          "name: score" caption.
        - ``pred_time`` is the elapsed wall-clock time in seconds, rounded
          to 5 decimals, as a string.
    """
    import tempfile  # local import keeps this block self-contained

    score_threshold = 0.5  # only detections above this confidence are drawn

    # Start the timer
    start_time = time.time()
    device = "cpu"
    transform = T.Compose([T.ToPILImage(), T.ToTensor()])
    image_tensor = transform(img).to(device).unsqueeze(0)  # add batch dimension

    faster_rcnn.eval()
    with torch.no_grad():
        predictions = faster_rcnn(image_tensor)

    # Single-image batch: only predictions[0] is relevant.
    pred_boxes = predictions[0]["boxes"].cpu().numpy()
    pred_scores = predictions[0]["scores"].cpu().numpy()
    pred_labels = predictions[0]["labels"].cpu().numpy()
    # Assumes every predicted label id is a valid index into `categories`.
    label_names = [categories[label]["name"] for label in pred_labels]

    fig, ax = plt.subplots(1)
    try:
        ax.imshow(img)
        for box, score, label_name in zip(pred_boxes, pred_scores, label_names):
            if score > score_threshold:
                x1, y1, x2, y2 = box
                w, h = x2 - x1, y2 - y1
                rect = plt.Rectangle((x1, y1), w, h, fill=False, edgecolor="red", linewidth=2)
                ax.add_patch(rect)
                ax.text(
                    x1,
                    y1,
                    f"{label_name}: {score:.2f}",
                    fontsize=5,
                    color="white",
                    bbox=dict(facecolor="red", alpha=0.2),
                )
        # Use a unique file per request. The previous random 0-99 name could
        # collide and hand one user another request's image.
        fd, img_path = tempfile.mkstemp(suffix=".png")
        os.close(fd)
        fig.savefig(img_path)
    finally:
        # pyplot keeps every open figure alive, so close it to avoid leaking
        # one figure per request.
        plt.close(fig)

    # Calculate the prediction time
    pred_time = round(time.time() - start_time, 5)
    # Return the path to the annotated image and the prediction time
    return img_path, str(pred_time)
### 4. Gradio app ###
# Collect example images from the "examples/" folder.
# - Sorted so the UI order is stable.
# - Non-image files (e.g. .gitkeep) are skipped.
# - A missing folder yields no examples instead of crashing at startup.
_EXAMPLE_DIR = "examples"
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".webp")
if os.path.isdir(_EXAMPLE_DIR):
    example_list = [
        [os.path.join(_EXAMPLE_DIR, name)]
        for name in sorted(os.listdir(_EXAMPLE_DIR))
        if name.lower().endswith(_IMAGE_EXTENSIONS)
    ]
else:
    example_list = []

# Create the Gradio demo. Components come from the top-level `gradio`
# namespace; the legacy `gradio.outputs.*` classes were deprecated in
# Gradio 3 and removed in Gradio 4.
demo = gradio.Interface(
    fn=predict,  # returns (path to annotated PNG, prediction time string)
    inputs=gradio.Image(type="numpy"),
    outputs=[
        gradio.Image(type="filepath", label="Image with Bounding Boxes"),
        gradio.Label(label="Prediction Time"),
    ],
    examples=example_list or None,  # None when no example images exist
    title=title,
    description=description,
)
# Launch the demo!
demo.launch(debug=True)