# Hugging Face Spaces status at the time this file was captured: Runtime error
# Importing Libraries | |
import os | |
import torch | |
import numpy as np | |
import gradio as gr | |
from model import Model | |
import albumentations as A | |
# Creating a model instance.
# According to the original author, the checkpoint comes from multi-GPU
# training with the network wrapped in nn.DataParallel. The model is wrapped
# the same way here; otherwise load_state_dict reports mismatched
# state_dict keys.
efficientnet_b5_model = torch.nn.DataParallel(Model())

# Read the checkpoint with every tensor mapped onto the CPU, copy it into the
# model's parameters, then drop the temporary module-level reference.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
_checkpoint_state = torch.load(
    f='efficientnet_b5_checkpoint_fold_0.pt',
    map_location=torch.device("cpu"),
)
efficientnet_b5_model.load_state_dict(_checkpoint_state)
del _checkpoint_state
# Predict on a single image

# Inference preprocessing, built once at import time instead of on every
# request. Images are resized to 512x512, scaled by max_pixel_value=255, and
# normalised with the given mean/std. The mean/std are the usual ImageNet
# statistics; presumably they match the training pipeline (TODO confirm).
# `p=1.0` replaces the deprecated `always_apply=True`. Normalize's default
# probability is already 1.0, so the transform is still always applied.
_INFERENCE_TRANSFORMS = A.Compose([
    A.Resize(512, 512),
    A.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
        max_pixel_value=255.0,
        p=1.0,
    ),
])


def _to_rgb_array(img):
    """Return *img* as an (H, W, 3) numpy array.

    The 3-channel mean/std in the preprocessing pipeline, and the
    (2, 0, 1) transpose, both require exactly three channels. PIL images
    (any mode: RGBA, greyscale, palette, ...) are converted to RGB. Plain
    arrays are fixed up by replicating a single channel or dropping a
    fourth (alpha) channel.
    """
    if hasattr(img, "convert"):  # PIL.Image.Image
        img = img.convert("RGB")
    arr = np.asarray(img)
    if arr.ndim == 2:
        arr = np.stack([arr, arr, arr], axis=-1)
    elif arr.ndim == 3 and arr.shape[-1] == 1:
        arr = np.repeat(arr, 3, axis=-1)
    elif arr.ndim == 3 and arr.shape[-1] == 4:
        arr = arr[..., :3]
    return arr


def predict_on_single_image(img):
    """
    Return the model's melanoma probabilities for a single image.

    The image is coerced to 3-channel RGB, resized to 512x512 and
    normalised, then converted to a (1, 3, 512, 512) float tensor. That
    tensor is passed through the model, and a sigmoid turns the first
    output into the probability of melanoma.

    Args:
        img: a PIL image (what ``gr.Image(type='pil')`` delivers) or an
            array-like image of shape (H, W), (H, W, 1), (H, W, 3) or
            (H, W, 4).

    Returns:
        dict mapping "Probability of Having Melanoma" and
        "Probability of Not having Melanoma" to floats that sum to 1,
        suitable as the value of a ``gr.Label`` output.

    Raises:
        gr.Error: if ``img`` is None (Gradio may pass None when the user
            submits without uploading an image).
    """
    if img is None:
        raise gr.Error("Please upload an image first.")

    rgb = _to_rgb_array(img)
    normalised = _INFERENCE_TRANSFORMS(image=rgb)['image']
    # HWC -> CHW, then add a leading batch dimension of size 1.
    chw = np.transpose(normalised, (2, 0, 1)).astype(np.float32)
    image = torch.tensor(chw, dtype=torch.float).unsqueeze(dim=0)

    efficientnet_b5_model.eval()
    with torch.inference_mode():
        probs = torch.sigmoid(efficientnet_b5_model(image))
    prob_of_melanoma = probs[0].item()
    prob_of_not_having_melanoma = 1 - prob_of_melanoma
    return {"Probability of Having Melanoma": prob_of_melanoma,
            "Probability of Not having Melanoma": prob_of_not_having_melanoma}
# Gradio App
# Examples directory path
melanoma_app_examples_path = "examples"
# Creating the title and description strings
title = "Melanoma Cancer Detection App"
description = 'An efficientnet-b5 model that predicts the probability of a patient having melanoma skin cancer or not.'


def _collect_examples(directory):
    """Return Gradio example rows ``[[path], ...]`` for the files in *directory*.

    Entries are sorted so the example order is stable, because
    ``os.listdir`` returns names in arbitrary order. Hidden files (e.g.
    ``.gitkeep``, ``.DS_Store``) and sub-directories are skipped, since
    they cannot be loaded as image examples. A missing directory yields an
    empty list instead of crashing at startup.
    """
    if not os.path.isdir(directory):
        return []
    return [
        [os.path.join(directory, name)]
        for name in sorted(os.listdir(directory))
        if not name.startswith(".")
        and os.path.isfile(os.path.join(directory, name))
    ]


example_list = _collect_examples(melanoma_app_examples_path)
# Create the Gradio demo
demo = gr.Interface(fn=predict_on_single_image,
                    inputs=gr.Image(type='pil'),
                    outputs=[gr.Label(label='Probabilities')],
                    # Pass None (no examples section) rather than an empty list.
                    examples=example_list or None,
                    title=title,
                    description=description)
# Launch the demo!
demo.launch()