|
import torch |
|
import gradio as gr |
|
from src.model import DRModel |
|
from torchvision import transforms as T |
|
|
|
# Location of the trained checkpoint; a relative path, so it is resolved
# against the current working directory.
CHECKPOINT_PATH = "artifacts/dr-model.ckpt"

# Restore the weights onto the CPU and put the network in evaluation mode.
# DRModel is presumably a PyTorch Lightning module (it exposes
# load_from_checkpoint); nn.Module.eval() returns the module itself, so the
# call can be chained onto the load.
model = DRModel.load_from_checkpoint(
    CHECKPOINT_PATH,
    map_location="cpu",
).eval()
|
|
|
# Class index -> human-readable severity grade. The position of each name
# must line up with the model's output unit of the same index
# (assumed to follow the 0-4 grading used for training — confirm).
labels = dict(
    enumerate(
        [
            "No DR",
            "Mild",
            "Moderate",
            "Severe",
            "Proliferative DR",
        ]
    )
)
|
|
|
# Preprocessing applied to every incoming image before inference:
# resize to a fixed 224x224, convert to a float tensor, then normalize each
# channel. The mean/std values are the common ImageNet statistics; presumably
# the same ones used during training — verify against the training pipeline.
_preprocessing_steps = [
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = T.Compose(_preprocessing_steps)
|
|
|
|
|
|
|
def predict(input_img):
    """Classify a retinal image into one of the diabetic-retinopathy grades.

    Args:
        input_img: A PIL image as delivered by ``gr.Image(type="pil")``, or
            ``None`` when the user submits without uploading an image.

    Returns:
        A dict mapping every label name in ``labels`` to its softmax
        probability, which is the format ``gr.Label`` displays.

    Raises:
        gr.Error: If no image was provided, so Gradio shows a readable
            message instead of a traceback from ``transform(None)``.
    """
    if input_img is None:
        raise gr.Error("Please upload a retinal image before submitting.")

    # The Normalize step uses 3-channel statistics, so grayscale ("L") or
    # RGBA images would fail; force three channels before preprocessing.
    rgb_img = input_img.convert("RGB")

    # transform yields (3, 224, 224); unsqueeze adds the batch dimension.
    batch = transform(rgb_img).unsqueeze(0)

    with torch.no_grad():
        # Assumes the model returns (1, num_classes) unnormalized scores —
        # take the single sample and turn its scores into probabilities.
        scores = model(batch)[0]
        probabilities = torch.nn.functional.softmax(scores, dim=0)

    return {labels[idx]: float(probabilities[idx]) for idx in labels}
|
|
|
|
|
|
|
# Gradio UI: one PIL image in, a label/confidence panel out.
dr_app = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(),
    title="Diabetic Retinopathy Detection App",
    # Built from adjacent string literals (concatenated at compile time) so the
    # rendered text has exactly one space between sentences; a backslash
    # continuation inside a literal would embed the next line's indentation.
    description=(
        "Welcome to our Diabetic Retinopathy Detection App! "
        "This app utilizes deep learning models to detect diabetic retinopathy in retinal images. "
        "Diabetic retinopathy is a common complication of diabetes and early detection "
        "is crucial for effective treatment."
    ),
    # Sample images shown as clickable examples; relative paths, so they are
    # resolved against the current working directory.
    examples=[
        "data/sample/10_left.jpeg",
        "data/sample/10_right.jpeg",
        "data/sample/15_left.jpeg",
        "data/sample/16_right.jpeg",
    ],
)


if __name__ == "__main__":
    # Only start the server when run as a script, not when imported.
    dr_app.launch()
|
|