import gradio
import torch
import torchvision.transforms as T
import numpy as np
import matplotlib.pyplot as plt
import os
import random
import tempfile
import PIL.Image as Image
import time
from model import create_fasterrcnn_model

# COCO-style category table for the detector's label ids.
categories = [
    {"id": 0, "name": "creatures", "supercategory": "none"},
    {"id": 1, "name": "fish", "supercategory": "creatures"},
    {"id": 2, "name": "jellyfish", "supercategory": "creatures"},
    {"id": 3, "name": "penguin", "supercategory": "creatures"},
    {"id": 4, "name": "puffin", "supercategory": "creatures"},
    {"id": 5, "name": "shark", "supercategory": "creatures"},
    {"id": 6, "name": "starfish", "supercategory": "creatures"},
    {"id": 7, "name": "stingray", "supercategory": "creatures"},
]

# Label id -> display name, keyed on "id" rather than list position so the
# lookup does not silently depend on the list being ordered by id.
CATEGORY_NAMES = {category["id"]: category["name"] for category in categories}

# Detections whose score is not strictly above this value are not drawn.
SCORE_THRESHOLD = 0.5

# Inference runs on the CPU; the weights are also loaded onto the CPU below.
DEVICE = torch.device("cpu")

# 1. Create title and description strings
title = "Ocean creatures detection Faster-R-CNN"
description = "A Faster-RCNN-ResNet-50 backbone feature extractor computer vision model to classify images of fish, penguin, shark, etc"

faster_rcnn = create_fasterrcnn_model(
    num_classes=8,  # len(categories) would also work
)

# Load saved weights
faster_rcnn.load_state_dict(
    torch.load(
        f="./third_train.pth",
        map_location=DEVICE,  # load to CPU
    )
)
# Put the model in inference mode once; nothing in this file switches it back
# to training mode, so there is no need to repeat this on every request.
faster_rcnn.eval()

# Image -> tensor conversion, built once instead of on every call.
# Assumes the uint8 HxWxC array that gradio.Image(type="numpy") delivers.
_TRANSFORM = T.Compose([T.ToPILImage(), T.ToTensor()])


# Create predict function
def predict(img, score_threshold=SCORE_THRESHOLD):
    """Detect ocean creatures in ``img`` and render the detections.

    Args:
        img: Image array as delivered by ``gradio.Image(type="numpy")``, or
            ``None`` when the user submitted without an image.
        score_threshold: Only detections with a score strictly above this
            value are drawn. Defaults to ``SCORE_THRESHOLD`` (0.5).

    Returns:
        A ``(img_path, pred_time)`` tuple. ``img_path`` is the path of a PNG
        file showing ``img`` with the kept boxes and "name: score" labels
        drawn on it. ``pred_time`` is the elapsed time in seconds, rounded to
        5 decimals, as a string. When ``img`` is ``None`` the result is
        ``(None, "No image provided")``.
    """
    if img is None:
        return None, "No image provided"

    # Start the timer
    start_time = time.perf_counter()

    # Add the batch dimension the detector expects.
    image_tensor = _TRANSFORM(img).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        prediction = faster_rcnn(image_tensor)[0]

    pred_boxes = prediction["boxes"].cpu().numpy()
    pred_scores = prediction["scores"].cpu().numpy()
    pred_labels = prediction["labels"].cpu().numpy()

    fig, ax = plt.subplots(1)
    try:
        ax.imshow(img)
        for box, score, label in zip(pred_boxes, pred_scores, pred_labels):
            if score <= score_threshold:
                continue
            label_id = int(label)
            label_name = CATEGORY_NAMES.get(label_id, str(label_id))
            x1, y1, x2, y2 = box
            w, h = x2 - x1, y2 - y1
            rect = plt.Rectangle((x1, y1), w, h, fill=False, edgecolor='red', linewidth=2)
            ax.add_patch(rect)
            ax.text(x1, y1, f'{label_name}: {score:.2f}', fontsize=5, color='white',
                    bbox=dict(facecolor='red', alpha=0.2))

        # Give every request its own file. The previous random 0-99 name in
        # the cwd could collide and overwrite another request's result.
        fd, img_path = tempfile.mkstemp(suffix=".png")
        os.close(fd)
        fig.savefig(img_path)
    finally:
        # pyplot keeps every open figure alive until it is closed; without
        # this a long-running server leaks one figure per request.
        plt.close(fig)

    # Calculate the prediction time
    pred_time = round(time.perf_counter() - start_time, 5)

    # return the path to the saved image and the prediction time
    return img_path, str(pred_time)


### 4. Gradio app ###
# Every file in ./examples becomes an example input; sorted for a stable order.
example_list = [[os.path.join("examples", example)] for example in sorted(os.listdir("examples"))]

# Create the Gradio demo
demo = gradio.Interface(
    fn=predict,  # mapping function from input to output
    inputs=gradio.Image(type="numpy"),
    outputs=[
        gradio.Image(type="filepath", label="Image with Bounding Boxes"),
        gradio.Label(label="Prediction Time"),
    ],  # predict has two outputs
    examples=example_list,
    title=title,
    description=description,
)

# Launch the demo!
if __name__ == "__main__":
    demo.launch(debug=True)