Create app.py
app.py
ADDED
@@ -0,0 +1,126 @@
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from albumentations import Compose, HorizontalFlip, ShiftScaleRotate, Resize, Normalize
from albumentations.pytorch import ToTensorV2
import timm
import gradio as gr
import numpy as np
from PIL import Image

# Hyperparameters
h = {
    "num_epochs": 10,
    "batch_size": 64,
    "image_size": 224,
    "lr": 0.001,
    "model": "efficientnetv2",
    "scheduler": "CosineAnnealingLR10",
    "balance": True,
    "early_stopping_patience": 5
}

# Custom Dataset and DataModule for PyTorch Lightning
class CustomImageFolder(torch.utils.data.Dataset):
    def __init__(self, root, transform=None):
        self.dataset = datasets.ImageFolder(root)
        self.transform = transform

    def __getitem__(self, index):
        image, label = self.dataset[index]
        if self.transform:
            image = self.transform(image=np.array(image))["image"]
        return image, label

    def __len__(self):
        return len(self.dataset)

class PneumoniaDataModule(pl.LightningDataModule):
    def __init__(self, h, data_dir):
        super().__init__()
        self.h = h
        self.data_dir = data_dir

    def setup(self, stage=None):
        train_transform = Compose([
            ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=20),
            HorizontalFlip(),
            Resize(self.h["image_size"], self.h["image_size"]),
            Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ToTensorV2()
        ])

        self.train_dataset = CustomImageFolder(self.data_dir + "/train", transform=train_transform)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.h["batch_size"], shuffle=True)

# Model definition using LightningModule
class PneumoniaModel(pl.LightningModule):
    def __init__(self, h):
        super().__init__()
        self.h = h
        self.model = timm.create_model("tf_efficientnetv2_b0", pretrained=True, num_classes=2)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(inputs)
        loss = self.criterion(outputs, labels)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.h["lr"])
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.h["num_epochs"], eta_min=self.h["lr"] * 0.1)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

# Load model after training
def load_model(h):
    model = PneumoniaModel(h)
    model.load_state_dict(torch.load("pneumonia_model.pth", map_location=torch.device('cpu')))
    model.eval()
    return model

trained_model = load_model(h)
# Gradio Prediction Function
def predict_pneumonia(image):
    transform = Compose([
        Resize(h["image_size"], h["image_size"]),
        Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ToTensorV2()
    ])

    # Preprocess the image
    image = np.array(image)
    image = transform(image=image)["image"]
    image = image.unsqueeze(0)  # Add batch dimension

    # Predict with the model
    with torch.no_grad():
        outputs = trained_model(image)
        prediction = torch.argmax(outputs, dim=1).item()

    # Map prediction to label
    label = "Pneumonia Detected" if prediction == 1 else "Normal"
    return label
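# Quick local sanity check of predict_pneumonia outside Gradio — the sample
# file name below is an assumption, not part of the original app:
#   sample = Image.open("sample_xray.jpeg")
#   print(predict_pneumonia(sample))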
# Gradio Interface (gr.Image / gr.Label replace the deprecated gr.inputs / gr.outputs API)
input_image = gr.Image(type="pil", label="Upload Chest X-ray Image")
output_label = gr.Label(label="Diagnosis")

app = gr.Interface(
    fn=predict_pneumonia,
    inputs=input_image,
    outputs=output_label,
    title="Pneumonia Detection",
    description="Upload a chest X-ray image to detect potential pneumonia using AI."
)

# Launch the app
app.launch()
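app.py loads pneumonia_model.pth, but the training run that produces those weights is not part of this Space. Below is a minimal training sketch of one way they could be generated. The dataset location (./chest_xray with a train/ subfolder), the Trainer settings, and the module name model_defs.py (a hypothetical file holding the h, PneumoniaDataModule, and PneumoniaModel definitions above, so importing them does not trigger load_model() or app.launch()) are all assumptions, not part of the original code.

import torch
import pytorch_lightning as pl

# Hypothetical module containing the definitions shown in app.py above,
# split out so importing them has no Gradio side effects.
from model_defs import h, PneumoniaDataModule, PneumoniaModel

if __name__ == "__main__":
    data = PneumoniaDataModule(h, data_dir="./chest_xray")  # assumed dataset path
    model = PneumoniaModel(h)
    trainer = pl.Trainer(max_epochs=h["num_epochs"], accelerator="auto")
    trainer.fit(model, datamodule=data)
    # Save a plain state_dict, which is exactly what load_model() in app.py reads back
    torch.save(model.state_dict(), "pneumonia_model.pth")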