Chirag1994's picture
demo.launch() fix
422fb0d
raw
history blame
2.62 kB
# Importing Libraries
import functools
import os

import albumentations as A
import gradio as gr
import numpy as np
import torch

from model import Model
# Creating a model instance.
# The network is wrapped in nn.DataParallel *before* the weights are loaded:
# the checkpoint was saved from a DataParallel-wrapped model (multi-GPU
# training), and DataParallel prefixes every state_dict key with "module.",
# so loading into an unwrapped model would fail with a key mismatch.
efficientnet_b5_model = torch.nn.DataParallel(Model())

# Load the fold-0 checkpoint onto the CPU so no GPU is needed at inference.
_checkpoint_state = torch.load(
    f='efficientnet_b5_checkpoint_fold_0.pt',
    map_location=torch.device("cpu"),
)
efficientnet_b5_model.load_state_dict(_checkpoint_state)
# The weights now live in the model; drop the extra reference to the dict.
del _checkpoint_state
# Predict on a single image
def _to_rgb_array(img):
    """Return *img* as a C-contiguous (H, W, 3) numpy array.

    PIL images are converted to RGB mode first, so grayscale ("L"), palette
    ("P") and RGBA uploads all yield three channels. Plain array-likes are
    accepted too: a 2-D or (H, W, 1) array is replicated across three
    channels, and a fourth (alpha) channel is dropped.

    Raises:
        ValueError: if *img* is None or is not a 1/3/4-channel image.
    """
    if img is None:
        raise ValueError("No image was provided.")
    # Anything with a ``convert`` method is treated as a PIL image and
    # normalised to RGB by Pillow itself.
    if hasattr(img, "convert"):
        img = img.convert("RGB")
    arr = np.asarray(img)
    if arr.ndim == 2:
        # Grayscale: replicate the single channel three times.
        arr = np.repeat(arr[..., np.newaxis], 3, axis=-1)
    elif arr.ndim == 3 and arr.shape[-1] == 1:
        arr = np.repeat(arr, 3, axis=-1)
    elif arr.ndim == 3 and arr.shape[-1] == 4:
        # Drop the alpha channel.
        arr = arr[..., :3]
    if arr.ndim != 3 or arr.shape[-1] != 3:
        raise ValueError(
            f"Expected an image with 1, 3 or 4 channels, got shape {arr.shape}."
        )
    # The alpha slice above is a strided view; hand back a plain buffer.
    return np.ascontiguousarray(arr)


@functools.lru_cache(maxsize=1)
def _get_transforms():
    """Build (once) and return the inference preprocessing pipeline.

    Resizes to 512x512 and normalises each channel with the mean/std below
    (the widely used ImageNet statistics), treating 255 as the maximum
    pixel value. Assumed to match the preprocessing used when the
    checkpoint was trained -- TODO confirm against the training code.
    """
    return A.Compose(
        [
            A.Resize(512, 512),
            A.Normalize(
                mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
                max_pixel_value=255.0,
                # ``always_apply`` is deprecated in recent albumentations
                # releases; p=1.0 applies the transform on every call too.
                p=1.0,
            ),
        ]
    )


def predict_on_single_image(img):
    """Return melanoma / not-melanoma probabilities for a single image.

    The image is converted to 3-channel RGB, resized to 512x512 and
    normalised, turned into a (1, 3, 512, 512) float32 tensor and passed
    through ``efficientnet_b5_model``. The sigmoid of the first output
    element is reported as the probability of melanoma.

    Args:
        img: a PIL image (as delivered by ``gr.Image(type='pil')``) or an
            array-like of shape (H, W), (H, W, 1), (H, W, 3) or (H, W, 4)
            with pixel values in 0-255.

    Returns:
        dict mapping "Probability of Having Melanoma" and
        "Probability of Not having Melanoma" to floats summing to 1.

    Raises:
        ValueError: if no image is given or it has an unsupported shape.
    """
    rgb = _to_rgb_array(img)
    normalized = _get_transforms()(image=rgb)['image']
    # HWC -> CHW, then add a leading batch dimension of 1.
    chw = np.ascontiguousarray(np.transpose(normalized, (2, 0, 1)), dtype=np.float32)
    batch = torch.from_numpy(chw).unsqueeze(dim=0)

    # Switch to eval mode on every call so layers such as dropout /
    # batch-norm (if the model has them) behave as at inference time.
    efficientnet_b5_model.eval()
    with torch.inference_mode():
        logits = efficientnet_b5_model(batch)
        prob_of_melanoma = torch.sigmoid(logits)[0].item()
    prob_of_not_having_melanoma = 1 - prob_of_melanoma
    return {
        "Probability of Having Melanoma": prob_of_melanoma,
        "Probability of Not having Melanoma": prob_of_not_having_melanoma,
    }
# Gradio App
# Directory holding the example images shown beneath the input widget.
melanoma_app_examples_path = "examples"

# Creating the title and description strings
title = "Melanoma Cancer Detection App"
description = 'An efficientnet-b5 model that predicts the probability of a patient having melanoma skin cancer or not.'

# File extensions (lower-case) accepted as example images.
_EXAMPLE_IMAGE_EXTENSIONS = (
    ".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp", ".tif", ".tiff",
)


def _list_example_images(directory):
    """Return ``[[path], ...]`` for the image files in *directory*.

    Entries are sorted by file name so the examples appear in a stable
    order. Sub-directories and files without an image extension (e.g.
    ``.DS_Store``) are skipped. A missing directory yields an empty list
    instead of crashing app start-up.
    """
    if not os.path.isdir(directory):
        return []
    examples = []
    for name in sorted(os.listdir(directory)):
        path = os.path.join(directory, name)
        if name.lower().endswith(_EXAMPLE_IMAGE_EXTENSIONS) and os.path.isfile(path):
            examples.append([path])
    return examples


example_list = _list_example_images(melanoma_app_examples_path)

# Create the Gradio demo
demo = gr.Interface(
    fn=predict_on_single_image,
    inputs=gr.Image(type='pil'),
    outputs=[gr.Label(label='Probabilities')],
    # Pass None rather than an empty list when no examples were found.
    examples=example_list or None,
    title=title,
    description=description,
)
# Launch the demo!
demo.launch()