ronnief1 commited on
Commit
012083a
1 Parent(s): 89eba5c

first commit

Browse files
Files changed (2) hide show
  1. app.py +275 -0
  2. requirements.txt.txt +14 -0
app.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import numpy as np
4
+ import cv2
5
+ import matplotlib.pyplot as plt
6
+ import torch
7
+ import albumentations as albu
8
+ from torch.utils.data import DataLoader
9
+ from torch.utils.data import Dataset as BaseDataset
10
+ from catalyst.runners import SupervisedRunner
11
+ import segmentation_models_pytorch as smp
12
+ from io import StringIO
13
+
14
+ # streamlit run c:/Users/ronni/Downloads/polyp_seg_web_app/app.py
15
+
16
+
17
+ x_test_dir = 'test/test/images'
18
+ y_test_dir = 'test/test/masks'
19
+ ENCODER = 'mobilenet_v2'
20
+ ENCODER_WEIGHTS = 'imagenet'
21
+ CLASSES = ['polyp', 'background']
22
+ ACTIVATION = 'sigmoid'
23
+
24
+ preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
25
+
26
+ def visualize(**images):
27
+ """Plot images in one row."""
28
+ n = len(images)
29
+ plt.figure(figsize=(16, 5))
30
+ for i, (name, image) in enumerate(images.items()):
31
+ plt.subplot(1, n, i + 1)
32
+ plt.xticks([])
33
+ plt.yticks([])
34
+ plt.title(' '.join(name.split('_')).title())
35
+ plt.imshow(image)
36
+ plt.savefig('x',dpi=400)
37
+ st.image('x.png')
38
+
39
+
40
+ def get_training_augmentation():
41
+ train_transform = [
42
+
43
+ albu.HorizontalFlip(p=0.5),
44
+
45
+ albu.ShiftScaleRotate(scale_limit=0.5, rotate_limit=0, shift_limit=0.1, p=1, border_mode=0),
46
+
47
+ albu.Resize(576, 736, always_apply=True, p=1),
48
+
49
+ albu.IAAAdditiveGaussianNoise(p=0.2),
50
+ albu.IAAPerspective(p=0.5),
51
+
52
+ albu.OneOf(
53
+ [
54
+ albu.CLAHE(p=1),
55
+ albu.RandomBrightness(p=1),
56
+ albu.RandomGamma(p=1),
57
+ ],
58
+ p=0.9,
59
+ ),
60
+
61
+ albu.OneOf(
62
+ [
63
+ albu.IAASharpen(p=1),
64
+ albu.Blur(blur_limit=3, p=1),
65
+ albu.MotionBlur(blur_limit=3, p=1),
66
+ ],
67
+ p=0.9,
68
+ ),
69
+
70
+ albu.OneOf(
71
+ [
72
+ albu.RandomContrast(p=1),
73
+ albu.HueSaturationValue(p=1),
74
+ ],
75
+ p=0.9,
76
+ ),
77
+ ]
78
+ return albu.Compose(train_transform)
79
+
80
+
81
+ def get_validation_augmentation():
82
+ """Add paddings to make image shape divisible by 32"""
83
+ test_transform = [
84
+ albu.Resize(576, 736)
85
+ ]
86
+ return albu.Compose(test_transform)
87
+
88
+
89
+ def to_tensor(x, **kwargs):
90
+ return x.transpose(2, 0, 1).astype('float32')
91
+
92
+ def get_preprocessing(preprocessing_fn):
93
+ """Construct preprocessing transform
94
+
95
+ Args:
96
+ preprocessing_fn (callbale): data normalization function
97
+ (can be specific for each pretrained neural network)
98
+ Return:
99
+ transform: albumentations.Compose
100
+
101
+ """
102
+
103
+ _transform = [
104
+ albu.Lambda(image=preprocessing_fn),
105
+ albu.Lambda(image=to_tensor, mask=to_tensor),
106
+ ]
107
+ return albu.Compose(_transform)
108
+
109
+ class Dataset(BaseDataset):
110
+ """Args:
111
+ images_dir (str): path to images folder
112
+ masks_dir (str): path to segmentation masks folder
113
+ class_values (list): values of classes to extract from segmentation mask
114
+ augmentation (albumentations.Compose): data transfromation pipeline
115
+ (e.g. flip, scale, etc.)
116
+ preprocessing (albumentations.Compose): data preprocessing
117
+ (e.g. noralization, shape manipulation, etc.)
118
+
119
+ """
120
+
121
+ CLASSES = ['polyp', 'background']
122
+
123
+ def __init__(
124
+ self,
125
+ images_dir,
126
+ masks_dir,
127
+ classes=None,
128
+ augmentation=None,
129
+ preprocessing=None,
130
+ single_file=False
131
+ ):
132
+
133
+ if single_file:
134
+ self.ids = images_dir
135
+ self.images_fps = os.path.join('test/test/images', self.ids)
136
+ self.masks_fps = os.path.join('test/test/masks', self.ids)
137
+ else:
138
+ self.ids = os.listdir(images_dir)
139
+ self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
140
+ self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]
141
+
142
+ # convert str names to class values on masks
143
+ self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes]
144
+
145
+ self.augmentation = augmentation
146
+ self.preprocessing = preprocessing
147
+
148
+ def __getitem__(self, i):
149
+
150
+ # read data
151
+ image = cv2.imread(self.images_fps)
152
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
153
+ mask = cv2.imread(self.masks_fps, 0)
154
+ mask[np.where(mask < 8)] = 0
155
+ mask[np.where(mask > 8)] = 255
156
+ # extract certain classes from mask (e.g. polyp)
157
+ masks = [(mask == v) for v in self.class_values]
158
+ mask = np.stack(masks, axis=-1).astype('float')
159
+
160
+ # apply augmentations
161
+ if self.augmentation:
162
+ sample = self.augmentation(image=image, mask=mask)
163
+ image, mask = sample['image'], sample['mask']
164
+
165
+ # apply preprocessing
166
+ if self.preprocessing:
167
+ sample = self.preprocessing(image=image, mask=mask)
168
+ image, mask = sample['image'], sample['mask']
169
+
170
+ return image, mask
171
+
172
+ def __len__(self):
173
+ return len(self.ids)
174
+
175
+ def model_infer(img_name):
176
+
177
+ model = smp.UnetPlusPlus(
178
+ encoder_name=ENCODER,
179
+ encoder_weights=ENCODER_WEIGHTS,
180
+ encoder_depth=5,
181
+ decoder_channels=(256, 128, 64, 32, 16),
182
+ classes=len(CLASSES),
183
+ activation=ACTIVATION,
184
+ decoder_attention_type=None,
185
+ )
186
+
187
+
188
+ model.load_state_dict(torch.load('best.pth', map_location=torch.device('cpu'))['model_state_dict'])
189
+ model.eval()
190
+
191
+ test_dataset = Dataset(
192
+ img_name,
193
+ img_name,
194
+ augmentation=get_validation_augmentation(),
195
+ preprocessing=get_preprocessing(preprocessing_fn),
196
+ classes=CLASSES,
197
+ single_file=True
198
+ )
199
+
200
+ test_dataloader = DataLoader(test_dataset)
201
+
202
+ loaders = {"infer": test_dataloader}
203
+
204
+ runner = SupervisedRunner()
205
+
206
+ logits = []
207
+ f = 0
208
+ for prediction in runner.predict_loader(model=model, loader=loaders['infer'],cpu=True):
209
+ if f < 3:
210
+ logits.append(prediction['logits'])
211
+ f = f + 1
212
+ else:
213
+ break
214
+
215
+ threshold = 0.5
216
+ break_at = 1
217
+
218
+ for i, (input, output) in enumerate(zip(
219
+ test_dataset, logits)):
220
+ image, mask = input
221
+
222
+ image_vis = image.transpose(1, 2, 0)
223
+ gt_mask = mask[0].astype('uint8')
224
+ pr_mask = (output[0].numpy() > threshold).astype('uint8')[0]
225
+ i = i + 1
226
+ if i >= break_at:
227
+ break
228
+
229
+ return image_vis, gt_mask, pr_mask
230
+ PAGE_TITLE = "Polyp Segmentation"
231
+
232
+ def file_selector(folder_path='.'):
233
+ filenames = os.listdir(folder_path)
234
+ selected_filename = st.selectbox('Select a file', filenames)
235
+ return os.path.join(folder_path, selected_filename)
236
+
237
+ def file_selector_ui():
238
+ folder_path = './test/test/images'
239
+ filename = file_selector(folder_path=folder_path)
240
+ printname = list(filename)
241
+ printname[filename.rfind('\\')] = '/'
242
+ st.write('You selected`%s`' % ''.join(printname))
243
+ return filename
244
+
245
+ def file_upload(folder_path='.'):
246
+ filenames = os.listdir(folder_path)
247
+ folder_path = './test/test/images'
248
+ uploaded_file = st.file_uploader("Choose a file")
249
+ filename = os.path.join(folder_path, uploaded_file.name)
250
+ printname = list(filename)
251
+ printname[filename.rfind('\\')] = '/'
252
+ st.write('You selected`%s`' % ''.join(printname))
253
+ return filename
254
+
255
+
256
+ def main():
257
+ st.set_page_config(page_title=PAGE_TITLE, layout="wide")
258
+ st.title(PAGE_TITLE)
259
+ image_path = file_selector_ui()
260
+ # image_path = file_upload()
261
+ image_path = os.path.abspath(image_path)
262
+ to_infer = image_path[image_path.rfind("\\") + 1:]
263
+
264
+ if os.path.isfile(image_path) is True:
265
+ _, file_extension = os.path.splitext(image_path)
266
+ if file_extension == ".jpg":
267
+ image_vis, gt_mask, pr_mask = model_infer(to_infer)
268
+ visualize(
269
+ image=image_vis,
270
+ ground_truth_mask=gt_mask,
271
+ predicted_mask=pr_mask
272
+ )
273
+
274
+ if __name__ == "__main__":
275
+ main()
requirements.txt.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ catalyst==19.04rc1
2
+ git+https://github.com/albu/albumentations@bdd6a4e
3
+ git+https://github.com/qubvel/segmentation_models.pytorch
4
+ os
5
+ numpy
6
+ cv2
7
+ matplotlib
8
+ torch
9
+ albumentations
10
+ segmentation_models_pytorch
11
+ collections
12
+ splitfolders
13
+ gc
14
+ streamlit