ronnief1 committed
Commit cfa0071 · 1 Parent(s): 1403012

Create app2.py

Files changed (1)
  1. app2.py +274 -0
app2.py ADDED
@@ -0,0 +1,274 @@
+ import streamlit as st
+ import os
+ import numpy as np
+ import cv2
+ import matplotlib.pyplot as plt
+ import torch
+ import albumentations as albu
+ from torch.utils.data import DataLoader
+ from torch.utils.data import Dataset as BaseDataset
+ from catalyst.dl import SupervisedRunner
+ import segmentation_models_pytorch as smp
+ from io import StringIO
+
+
+ # streamlit run c:/Users/ronni/Downloads/polyp_seg_web_app/app.py
+
+
+ x_test_dir = 'test/test/images'
+ y_test_dir = 'test/test/masks'
+ ENCODER = 'mobilenet_v2'
+ ENCODER_WEIGHTS = 'imagenet'
+ CLASSES = ['polyp', 'background']
+ ACTIVATION = 'sigmoid'
+
+ preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
+
+ def visualize(**images):
+     """Plot images in one row and render the figure in Streamlit."""
+     n = len(images)
+     plt.figure(figsize=(16, 5))
+     for i, (name, image) in enumerate(images.items()):
+         plt.subplot(1, n, i + 1)
+         plt.xticks([])
+         plt.yticks([])
+         plt.title(' '.join(name.split('_')).title())
+         plt.imshow(image)
+     plt.savefig('x.png', dpi=400)
+     st.image('x.png')
+
+
+ def get_training_augmentation():
+     # note: the IAA* transforms below come from older albumentations versions (imgaug backend)
+     train_transform = [
+
+         albu.HorizontalFlip(p=0.5),
+
+         albu.ShiftScaleRotate(scale_limit=0.5, rotate_limit=0, shift_limit=0.1, p=1, border_mode=0),
+
+         albu.Resize(576, 736, always_apply=True, p=1),
+
+         albu.IAAAdditiveGaussianNoise(p=0.2),
+         albu.IAAPerspective(p=0.5),
+
+         albu.OneOf(
+             [
+                 albu.CLAHE(p=1),
+                 albu.RandomBrightness(p=1),
+                 albu.RandomGamma(p=1),
+             ],
+             p=0.9,
+         ),
+
+         albu.OneOf(
+             [
+                 albu.IAASharpen(p=1),
+                 albu.Blur(blur_limit=3, p=1),
+                 albu.MotionBlur(blur_limit=3, p=1),
+             ],
+             p=0.9,
+         ),
+
+         albu.OneOf(
+             [
+                 albu.RandomContrast(p=1),
+                 albu.HueSaturationValue(p=1),
+             ],
+             p=0.9,
+         ),
+     ]
+     return albu.Compose(train_transform)
+
+
+ def get_validation_augmentation():
+     """Resize images to a fixed shape divisible by 32."""
+     test_transform = [
+         albu.Resize(576, 736)
+     ]
+     return albu.Compose(test_transform)
+
+
+ def to_tensor(x, **kwargs):
+     # HWC -> CHW, as expected by PyTorch
+     return x.transpose(2, 0, 1).astype('float32')
+
+ def get_preprocessing(preprocessing_fn):
+     """Construct preprocessing transform
+     Args:
+         preprocessing_fn (callable): data normalization function
+             (can be specific for each pretrained neural network)
+     Return:
+         transform: albumentations.Compose
+     """
+
+     _transform = [
+         albu.Lambda(image=preprocessing_fn),
+         albu.Lambda(image=to_tensor, mask=to_tensor),
+     ]
+     return albu.Compose(_transform)
107
+
108
+ class Dataset(BaseDataset):
109
+ """Args:
110
+ images_dir (str): path to images folder
111
+ masks_dir (str): path to segmentation masks folder
112
+ class_values (list): values of classes to extract from segmentation mask
113
+ augmentation (albumentations.Compose): data transfromation pipeline
114
+ (e.g. flip, scale, etc.)
115
+ preprocessing (albumentations.Compose): data preprocessing
116
+ (e.g. noralization, shape manipulation, etc.)
117
+ """
118
+
119
+ CLASSES = ['polyp', 'background']
120
+
121
+ def __init__(
122
+ self,
123
+ images_dir,
124
+ masks_dir,
125
+ classes=None,
126
+ augmentation=None,
127
+ preprocessing=None,
128
+ single_file=False
129
+ ):
130
+
131
+ if single_file:
132
+ self.ids = images_dir
133
+ self.images_fps = os.path.join('test/test/images', self.ids)
134
+ self.masks_fps = os.path.join('test/test/masks', self.ids)
135
+ else:
136
+ self.ids = os.listdir(images_dir)
137
+ self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
138
+ self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]
139
+
140
+ # convert str names to class values on masks
141
+ self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes]
142
+
143
+ self.augmentation = augmentation
144
+ self.preprocessing = preprocessing
145
+
+     def __getitem__(self, i):
+
+         # read data (single_file mode stores one path, otherwise index into the lists)
+         image_fp = self.images_fps if isinstance(self.images_fps, str) else self.images_fps[i]
+         mask_fp = self.masks_fps if isinstance(self.masks_fps, str) else self.masks_fps[i]
+         image = cv2.imread(image_fp)
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+         mask = cv2.imread(mask_fp, 0)
+
+         # binarize the mask
+         mask[np.where(mask < 8)] = 0
+         mask[np.where(mask >= 8)] = 255
+         # extract certain classes from mask (e.g. polyp)
+         masks = [(mask == v) for v in self.class_values]
+         mask = np.stack(masks, axis=-1).astype('float')
+
+         # apply augmentations
+         if self.augmentation:
+             sample = self.augmentation(image=image, mask=mask)
+             image, mask = sample['image'], sample['mask']
+
+         # apply preprocessing
+         if self.preprocessing:
+             sample = self.preprocessing(image=image, mask=mask)
+             image, mask = sample['image'], sample['mask']
+
+         return image, mask
+
+     def __len__(self):
+         # a single file counts as one sample
+         return 1 if isinstance(self.ids, str) else len(self.ids)
+
+ def model_infer(img_name):
+
+     # build the model with the same configuration used for training
+     model = smp.UnetPlusPlus(
+         encoder_name=ENCODER,
+         encoder_weights=ENCODER_WEIGHTS,
+         encoder_depth=5,
+         decoder_channels=(256, 128, 64, 32, 16),
+         classes=len(CLASSES),
+         activation=ACTIVATION,
+         decoder_attention_type=None,
+     )
+
+     # load the trained weights on CPU
+     model.load_state_dict(torch.load('best.pth', map_location=torch.device('cpu'))['model_state_dict'])
+     model.eval()
+
+     test_dataset = Dataset(
+         img_name,
+         img_name,
+         augmentation=get_validation_augmentation(),
+         preprocessing=get_preprocessing(preprocessing_fn),
+         classes=CLASSES,
+         single_file=True
+     )
+
+     test_dataloader = DataLoader(test_dataset)
+
+     loaders = {"infer": test_dataloader}
+
+     runner = SupervisedRunner()
+
+     # collect at most 3 predictions from the inference loader
+     logits = []
+     f = 0
+     for prediction in runner.predict_loader(model=model, loader=loaders['infer'], cpu=True):
+         if f < 3:
+             logits.append(prediction['logits'])
+             f = f + 1
+         else:
+             break
+
+     threshold = 0.5
+     break_at = 1
+
+     # take the first sample only: threshold the model output into a binary mask
+     for i, (input, output) in enumerate(zip(test_dataset, logits)):
+         image, mask = input
+
+         image_vis = image.transpose(1, 2, 0)
+         gt_mask = mask[0].astype('uint8')
+         pr_mask = (output[0].numpy() > threshold).astype('uint8')[0]
+         i = i + 1
+         if i >= break_at:
+             break
+
+     return image_vis, gt_mask, pr_mask
+
+ PAGE_TITLE = "Polyp Segmentation"
+
+ def file_selector(folder_path='.'):
+     filenames = os.listdir(folder_path)
+     selected_filename = st.selectbox('Select a file', filenames)
+     return os.path.join(folder_path, selected_filename)
+
+ def file_selector_ui():
+     folder_path = './test/test/images'
+     filename = file_selector(folder_path=folder_path)
+     # display the path with forward slashes (os.path.join uses '\\' on Windows)
+     st.write('You selected `%s`' % filename.replace('\\', '/'))
+     return filename
+
+ def file_upload(folder_path='./test/test/images'):
+     uploaded_file = st.file_uploader("Choose a file")
+     if uploaded_file is None:
+         return None
+     filename = os.path.join(folder_path, uploaded_file.name)
+     st.write('You selected `%s`' % filename.replace('\\', '/'))
+     return filename
+
+
+ def main():
+     st.set_page_config(page_title=PAGE_TITLE, layout="wide")
+     st.title(PAGE_TITLE)
+     image_path = file_selector_ui()
+     # image_path = file_upload()
+     image_path = os.path.abspath(image_path)
+     # the Dataset expects a bare file name relative to the test folders
+     to_infer = os.path.basename(image_path)
+
+     if os.path.isfile(image_path):
+         _, file_extension = os.path.splitext(image_path)
+         if file_extension == ".jpg":
+             image_vis, gt_mask, pr_mask = model_infer(to_infer)
+             visualize(
+                 image=image_vis,
+                 # ground_truth_mask=gt_mask,
+                 predicted_mask=pr_mask
+             )
+
+ if __name__ == "__main__":
+     main()