aaravlovescodes commited on
Commit
c7cd1d7
·
verified ·
1 Parent(s): dc1b882

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -0
app.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import pytorch_lightning as pl
4
+ from torch.utils.data import DataLoader
5
+ from torchvision import transforms, datasets
6
+ from albumentations import Compose, HorizontalFlip, ShiftScaleRotate, Resize, Normalize
7
+ from albumentations.pytorch import ToTensorV2
8
+ import timm
9
+ import gradio as gr
10
+ import numpy as np
11
+ from PIL import Image
12
+
13
+ # Hyperparameters
14
+ h = {
15
+ "num_epochs": 10,
16
+ "batch_size": 64,
17
+ "image_size": 224,
18
+ "lr": 0.001,
19
+ "model": "efficientnetv2",
20
+ "scheduler": "CosineAnnealingLR10",
21
+ "balance": True,
22
+ "early_stopping_patience": 5
23
+ }
24
+
25
+ # Custom Dataset and DataModule for PyTorch Lightning
26
+ class CustomImageFolder(torch.utils.data.Dataset):
27
+ def __init__(self, root, transform=None):
28
+ self.dataset = datasets.ImageFolder(root)
29
+ self.transform = transform
30
+
31
+ def __getitem__(self, index):
32
+ image, label = self.dataset[index]
33
+ if self.transform:
34
+ image = self.transform(image=np.array(image))["image"]
35
+ return image, label
36
+
37
+ def __len__(self):
38
+ return len(self.dataset)
39
+
40
+ class PneumoniaDataModule(pl.LightningDataModule):
41
+ def __init__(self, h, data_dir):
42
+ super().__init__()
43
+ self.h = h
44
+ self.data_dir = data_dir
45
+
46
+ def setup(self, stage=None):
47
+ train_transform = Compose([
48
+ ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=20),
49
+ HorizontalFlip(),
50
+ Resize(self.h["image_size"], self.h["image_size"]),
51
+ Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
52
+ ToTensorV2()
53
+ ])
54
+
55
+ self.train_dataset = CustomImageFolder(self.data_dir + "/train", transform=train_transform)
56
+
57
+ def train_dataloader(self):
58
+ return DataLoader(self.train_dataset, batch_size=self.h["batch_size"], shuffle=True)
59
+
60
+ # Model definition using LightningModule
61
+ class PneumoniaModel(pl.LightningModule):
62
+ def __init__(self, h):
63
+ super().__init__()
64
+ self.h = h
65
+ self.model = timm.create_model("tf_efficientnetv2_b0", pretrained=True, num_classes=2)
66
+ self.criterion = nn.CrossEntropyLoss()
67
+
68
+ def forward(self, x):
69
+ return self.model(x)
70
+
71
+ def training_step(self, batch, batch_idx):
72
+ inputs, labels = batch
73
+ outputs = self(inputs)
74
+ loss = self.criterion(outputs, labels)
75
+ return loss
76
+
77
+ def configure_optimizers(self):
78
+ optimizer = torch.optim.Adam(self.parameters(), lr=self.h["lr"])
79
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.h["num_epochs"], eta_min=self.h["lr"] * 0.1)
80
+ return {"optimizer": optimizer, "lr_scheduler": scheduler}
81
+
82
+ # Load model after training
83
+ def load_model(h):
84
+ model = PneumoniaModel(h)
85
+ model.load_state_dict(torch.load("pneumonia_model.pth", map_location=torch.device('cpu')))
86
+ model.eval()
87
+ return model
88
+
89
+ trained_model = load_model(h)
90
+
91
+ # Gradio Prediction Function
92
+ def predict_pneumonia(image):
93
+ transform = Compose([
94
+ Resize(h["image_size"], h["image_size"]),
95
+ Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
96
+ ToTensorV2()
97
+ ])
98
+
99
+ # Preprocess the image
100
+ image = np.array(image)
101
+ image = transform(image=image)["image"]
102
+ image = image.unsqueeze(0) # Add batch dimension
103
+
104
+ # Predict with the model
105
+ with torch.no_grad():
106
+ outputs = trained_model(image)
107
+ prediction = torch.argmax(outputs, dim=1).item()
108
+
109
+ # Map prediction to label
110
+ label = "Pneumonia Detected" if prediction == 1 else "Normal"
111
+ return label
112
+
113
+ # Gradio Interface
114
+ input_image = gr.inputs.Image(type="pil", label="Upload Chest X-ray Image")
115
+ output_label = gr.outputs.Label(label="Diagnosis")
116
+
117
+ app = gr.Interface(
118
+ fn=predict_pneumonia,
119
+ inputs=input_image,
120
+ outputs=output_label,
121
+ title="Pneumonia Detection",
122
+ description="Upload a chest X-ray image to detect potential pneumonia using AI."
123
+ )
124
+
125
+ # Launch the app
126
+ app.launch()