dtrujillo committed
Commit a0dd22b · 1 Parent(s): 182caf4

added app and dataset libs

Files changed (2)
  1. app.py +10 -0
  2. dataset.py +367 -0
app.py ADDED
@@ -0,0 +1,10 @@
+ import gradio as gr
+ import numpy as np
+ from dataset import FrameReaderDataset
+
+ def generate_frames(ffile, dfile):
+     # Read the raw frames, subtract the dark frame, and write the result to
+     # disk: a File output component expects a file path, not an array. The
+     # legacy gr.inputs.File passes a tempfile object, so .name gives its path.
+     frd = FrameReaderDataset(ffile=ffile.name, dfile=dfile.name)
+     frames = frd.get_frames()
+     np.save('peaks.npy', frames)
+     return 'peaks.npy'
+
+ interface = gr.Interface(fn=generate_frames,
+                          inputs=[gr.inputs.File(label="frame file"), gr.inputs.File(label="dark file")],
+                          outputs=gr.outputs.File('peaks.npy'))
+ interface.launch()
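Note: the gr.inputs / gr.outputs namespaces are the legacy Gradio API, deprecated in Gradio 3 and removed in Gradio 4. If the Space runs a newer release, a minimal equivalent (a sketch reusing the same generate_frames function) would be:

    import gradio as gr

    # gr.File(type="filepath") hands the function a path string directly,
    # so the .name indirection in generate_frames would not be needed.
    interface = gr.Interface(fn=generate_frames,
                             inputs=[gr.File(label="frame file", type="filepath"),
                                     gr.File(label="dark file", type="filepath")],
                             outputs=gr.File(label="peaks.npy"))
    interface.launch()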
dataset.py ADDED
@@ -0,0 +1,367 @@
+ from torch.utils.data import Dataset
+ import numpy as np
+ import h5py, torch, random, logging
+ from skimage.feature import peak_local_max
+ from skimage import measure
+ from skimage.measure import label, regionprops
+ import os
+ #from cc_torch import connected_components_labeling
+ from torchvision import transforms
+ from time import time
+ import gc
+ #from torch.utils.data import TensorDataset, DataLoader
+
+ def connected_components_torch(images, crop_size=15, NrPixels=2048):
+     # Label connected components on the GPU; requires the optional cc_torch
+     # package (import commented out above).
+     window = int(crop_size/2)
+     start = torch.cuda.Event(enable_timing=True)
+     end = torch.cuda.Event(enable_timing=True)
+
+     ccs = []
+
+     start.record()
+     for image in images:
+         cc_out = connected_components_labeling(image).cpu().numpy()
+         ccs.append(cc_out)
+     end.record()
+     torch.cuda.synchronize()
+     print('cc_time: ', start.elapsed_time(end)/1000)
+
+     return ccs
+
+ def region_props(ccs, images, crop_size=15, NrPixels=2048):
+     # Collect the top-left corner of a crop window around each labeled region
+     # whose area lies between 4 and 150 pixels.
+     window = int(crop_size/2)
+
+     masks = []
+     centers = []
+
+     i = 0
+     start = time()
+     for cc in ccs:
+         for region_nr, region in enumerate(regionprops(cc)):
+             if region.area > 4 and region.area < 150:
+                 x, y = region.centroid
+                 start_x = int(x) - window
+                 end_x = int(x) + window + 1
+                 start_y = int(y) - window
+                 end_y = int(y) + window + 1
+                 if start_x < 0 or end_x > NrPixels - 1 or start_y < 0 or end_y > NrPixels - 1:
+                     continue
+                 #sub_img = np.copy(images[i])
+                 #sub_img[ccs != region_nr+1] = 0
+                 #sub_img = sub_img[start_y:end_y,start_x:end_x]
+                 #masks.append(sub_img)
+                 centers.append((start_x, start_y))
+         i += 1
+     end = time()
+     print('get_regionprops_time: ', end - start)
+
+     #return masks,centers
+
+ def connected_components_skimage(images, crop_size=15, NrPixels=2048):
+     # Label connected components on the CPU with scikit-image.
+     window = int(crop_size/2)
+     masks = []
+     centers = []
+     ccs = []
+     start = time()
+     for image in images:
+         cc_out = measure.label(image.astype(int))
+         ccs.append(cc_out)
+     end = time()
+     print('cc_time', end - start)
+     return ccs
+
+ def clean_patch(p, center):
+     # Keep only the connected component whose peak lies closest to `center`;
+     # return the patch unchanged if it contains a single component.
+     w, h = p.shape
+     cc = measure.label(p > 0)
+     if cc.max() == 1:
+         return p
+
+     # logging.warning(f"{cc.max()} peaks located in a patch")
+     lmin = np.inf
+     cc_lmin = None
+     for _c in range(1, cc.max()+1):
+         lmax = peak_local_max(p * (cc == _c), min_distance=1)
+         if lmax.shape[0] == 0: continue  # single-pixel component
+         lc = lmax.mean(axis=0)
+         dist = ((lc - center)**2).sum()
+         if dist < lmin:
+             cc_lmin = _c
+             lmin = dist
+     return p * (cc == cc_lmin)
+
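+ # Example of clean_patch (illustrative only; the values below are made up):
+ # a patch with a centered blob and a stray corner blob keeps only the former.
+ #   p = np.zeros((15, 15)); p[6:9, 6:9] = 5.0; p[0:2, 0:2] = 3.0
+ #   cleaned = clean_patch(p, np.array([7.0, 7.0]))  # corner blob zeroed out
+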
+ class FrameReaderDataset(Dataset):
+     def __init__(self, ffile, dfile, NrPixels=2048, nFrames=1440, nrFiles=1, thresh=100, fHead=8192):
+         print("dark file:", dfile)
+         print("frames file:", ffile)
+         self.ffile = ffile
+         # Average all dark frames into a single dark image for background subtraction.
+         self.dark = np.zeros(NrPixels*NrPixels)
+         if os.path.exists(dfile):
+             darkf = open(dfile, 'rb')
+             nFramesDark = int((os.path.getsize(dfile) - 8192) / (2*NrPixels*NrPixels))
+             darkf.seek(8192, os.SEEK_SET)
+             for nr in range(nFramesDark):
+                 self.dark += np.fromfile(darkf, dtype=np.uint16, count=(NrPixels*NrPixels))
+             self.dark = self.dark.astype(float)
+             self.dark /= nFramesDark
+             self.dark = np.reshape(self.dark, (NrPixels, NrPixels))
+         else:
+             self.dark = np.zeros((NrPixels, NrPixels)).astype(float)
+
+         # Read every frame, subtract the dark image, and zero sub-threshold pixels.
+         self.frames = []
+         self.len = nFrames
+         for fnr in range(nrFiles):
+             startFrameNr = (nFrames//nrFiles)*fnr
+             endFrameNr = (nFrames//nrFiles)*(fnr+1)
+             f = open(ffile, 'rb')
+             f.seek(fHead, os.SEEK_SET)
+             for frameNr in range(startFrameNr, endFrameNr):
+                 self.thisFrame = np.fromfile(f, dtype=np.uint16, count=(NrPixels*NrPixels))
+                 self.thisFrame = np.reshape(self.thisFrame, (NrPixels, NrPixels))
+                 self.thisFrame = self.thisFrame.astype(float)
+                 self.thisFrame = self.thisFrame - self.dark
+                 self.thisFrame[self.thisFrame < thresh] = 0
+                 self.frames.append(self.thisFrame)
+
+     def get_frames(self):
+         return np.array(self.frames)
+
+     def write_frames_torch(self):
+         f_name = self.ffile.split('/')[-1]
+         torch.save(self.frames, 'frames_%s.pt' % f_name.split('.ge3')[0])
+
+     def write_frames_numpy(self):
+         f_name = self.ffile.split('/')[-1]
+         np.save('frames_%s.npy' % f_name.split('.ge3')[0], self.frames)
+
+     def get_peaks_torch(self, psz=15):
+         peaks = connected_components_torch(np.array(self.frames))
+         return peaks
+
+     def get_peaks_skimage(self, psz=15):
+         peaks = connected_components_skimage(self.frames)
+         return peaks
+
+     def write_peaks_torch(self):
+         f_name = self.ffile.split('/')[-1]
+         peaks = self.get_peaks_skimage()
+         torch.save(peaks, 'peaks_%s.pt' % f_name.split('.ge3')[0])
+
+     def write_peaks_numpy(self):
+         f_name = self.ffile.split('/')[-1]
+         peaks = self.get_peaks_skimage()
+         np.save('peaks_%s.npy' % f_name.split('.ge3')[0], peaks)
+
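+ # Example usage (illustrative; 'scan.ge3' and 'dark.ge3' are placeholder paths):
+ #   frd = FrameReaderDataset(ffile='scan.ge3', dfile='dark.ge3')
+ #   frames = frd.get_frames()   # (nFrames, 2048, 2048) dark-subtracted floats
+ #   frd.write_peaks_numpy()     # labels connected components, saves peaks_scan.npy
+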
+ class BraggNNDataset(Dataset):
+     def __init__(self, pfile=None, ffile=None, ge_dataset=False, ge_ffile=None, ge_dfile=None, psz=15, rnd_shift=0, use='train', train_frac=0.8):
+         self.psz = psz
+         self.rnd_shift = rnd_shift
+
+         with h5py.File(pfile, "r") as h5fd:
+             if use == 'train':
+                 sti, edi = 0, int(train_frac * h5fd['peak_fidx'].shape[0])
+             elif use == 'validation':
+                 sti, edi = int(train_frac * h5fd['peak_fidx'].shape[0]), None
+             else:
+                 logging.error(f"unsupported use: {use}. This class is written for building either a training or a validation set")
+
+             mask = h5fd['npeaks'][sti:edi] == 1  # use only single-peak patches
+             mask = mask & ((h5fd['deviations'][sti:edi] >= 0) & (h5fd['deviations'][sti:edi] < 1))
+
+             self.peak_fidx = h5fd['peak_fidx'][sti:edi][mask]
+             self.peak_row = h5fd['peak_row'][sti:edi][mask]
+             self.peak_col = h5fd['peak_col'][sti:edi][mask]
+
+         self.fidx_base = self.peak_fidx.min()
+         # only load the frames that will be used
+         if ge_dataset is True:
+             self.frames = FrameReaderDataset(ge_ffile, ge_dfile).get_frames()  #[self.peak_fidx.min():self.peak_fidx.max()+1]
+             self.len = self.peak_fidx.shape[0]
+             print(self.len)
+         else:
+             with h5py.File(ffile, "r") as h5fd:
+                 self.frames = h5fd['frames'][self.peak_fidx.min():self.peak_fidx.max()+1]
+             self.len = self.peak_fidx.shape[0]
+
+     def __getitem__(self, idx):
+         _frame = self.frames[self.peak_fidx[idx] - self.fidx_base]
+         if self.rnd_shift > 0:
+             row_shift = np.random.randint(-self.rnd_shift, self.rnd_shift+1)
+             col_shift = np.random.randint(-self.rnd_shift, self.rnd_shift+1)
+         else:
+             row_shift, col_shift = 0, 0
+         prow_rnd = int(self.peak_row[idx]) + row_shift
+         pcol_rnd = int(self.peak_col[idx]) + col_shift
+
+         row_base = max(0, prow_rnd - self.psz//2)
+         col_base = max(0, pcol_rnd - self.psz//2)
+
+         crop_img = _frame[row_base:(prow_rnd + self.psz//2 + self.psz%2), \
+                           col_base:(pcol_rnd + self.psz//2 + self.psz%2)]
+         # if((crop_img > 0).sum() == 1): continue # ignore single non-zero peak
+         if crop_img.size != self.psz ** 2:
+             # The patch touched a frame edge; pad it back to psz x psz.
+             c_pad_l = (self.psz - crop_img.shape[1]) // 2
+             c_pad_r = self.psz - c_pad_l - crop_img.shape[1]
+
+             r_pad_t = (self.psz - crop_img.shape[0]) // 2
+             r_pad_b = self.psz - r_pad_t - crop_img.shape[0]
+
+             logging.warning(f"sample {idx} touched an edge when cropping the patch: {crop_img.shape}")
+             crop_img = np.pad(crop_img, ((r_pad_t, r_pad_b), (c_pad_l, c_pad_r)), mode='constant')
+         else:
+             c_pad_l, r_pad_t = 0, 0
+
+         _center = np.array([self.peak_row[idx] - row_base + r_pad_t, self.peak_col[idx] - col_base + c_pad_l])
+         crop_img = clean_patch(crop_img, _center)
+         if crop_img.max() != crop_img.min():
+             # Min-max normalize the patch to [0, 1].
+             _min, _max = crop_img.min().astype(np.float32), crop_img.max().astype(np.float32)
+             feature = (crop_img - _min) / (_max - _min)
+         else:
+             logging.warning("sample %d has a single unique intensity (sum %d)" % (idx, crop_img.sum()))
+             feature = crop_img
+
+         # Peak location normalized by the patch size; this is the regression target.
+         px = (self.peak_col[idx] - col_base + c_pad_l) / self.psz
+         py = (self.peak_row[idx] - row_base + r_pad_t) / self.psz
+
+         return feature[np.newaxis], np.array([px, py]).astype(np.float32)
+
+     def __len__(self):
+         return self.len
+
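+ # Example usage (illustrative; the HDF5 paths are placeholders):
+ #   ds = BraggNNDataset(pfile='peaks.h5', ffile='frames.h5', use='train')
+ #   dl = torch.utils.data.DataLoader(ds, batch_size=64, shuffle=True)
+ #   patch, loc = next(iter(dl))  # patch: (64, 1, 15, 15), loc: (64, 2) in [0, 1]
+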
+ class MidasDataset(Dataset):
+     def __init__(self, mfile, psz=15, rnd_shift=0, use='train', train_frac=0.8):
+         self.psz = psz
+         self.rnd_shift = rnd_shift
+         with h5py.File(mfile, "r") as h5fd:
+             if use == 'train':
+                 sti, edi = 0, int(train_frac * len(h5fd['peakLoc']))
+             elif use == 'validation':
+                 sti, edi = int(train_frac * len(h5fd['peakLoc'])), None
+             else:
+                 logging.error(f"unsupported use: {use}. This class is written for building either a training or a validation set")
+
+             # A location entry of length 2 (one row/col pair) marks a single-peak patch.
+             npeaks = []
+             mask = []
+             for loc in h5fd['peakLoc'][sti:edi]:
+                 npeaks.append(len(loc))
+                 mask.append(len(loc) == 2)
+
+             #mask = npeaks[sti:edi] == 2 # use only single-peak patches
+             #mask = mask & ((h5fd['deviations'][sti:edi] >= 0) & (h5fd['deviations'][sti:edi] < 1))
+
+             self.npeaks = npeaks
+             self.peak_locs = h5fd["peakLoc"][sti:edi][mask]
+             self.peak_row = [loc[0] for loc in self.peak_locs]
+             self.peak_col = [loc[1] for loc in self.peak_locs]
+             self.deviations = np.zeros(shape=(len(self.peak_locs),))
+             self.diffY = h5fd["diffY"][sti:edi][mask]
+             self.diffZ = h5fd["diffZ"][sti:edi][mask]
+             self.peak_fidx = np.zeros(shape=(len(self.peak_locs),))
+
+             self.crop_img = h5fd['patch'][sti:edi][mask]
+             self.len = len(self.peak_locs)
+
+     def __getitem__(self, idx):
+         crop_img = self.crop_img[idx]
+
+         row_shift, col_shift = 0, 0
+         c_pad_l, r_pad_t = 0, 0
+         prow_rnd = int(self.peak_row[idx]) + row_shift
+         pcol_rnd = int(self.peak_col[idx]) + col_shift
+
+         row_base = max(0, prow_rnd - self.psz//2)
+         col_base = max(0, pcol_rnd - self.psz//2)
+
+         if crop_img.max() != crop_img.min():
+             _min, _max = crop_img.min().astype(np.float32), crop_img.max().astype(np.float32)
+             feature = (crop_img - _min) / (_max - _min)
+         else:
+             #logging.warning("sample %d has a single unique intensity (sum %d)" % (idx, crop_img.sum()))
+             feature = crop_img
+
+         px = (self.peak_col[idx] - col_base + c_pad_l) / self.psz
+         py = (self.peak_row[idx] - row_base + r_pad_t) / self.psz
+
+         return feature[np.newaxis], np.array([px, py]).astype(np.float32)
+
+     def __len__(self):
+         return self.len
+
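+ # Example usage (illustrative; 'midas.h5' is a placeholder path):
+ #   ds = MidasDataset(mfile='midas.h5', use='validation')
+ #   patch, loc = ds[0]  # patch: (1, psz, psz), loc: normalized (col, row)
+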
+ class PatchWiseDataset(Dataset):
+     def __init__(self, pfile=None, ffile=None, ge_dataset=False, ge_ffile=None, ge_dfile=None, psz=15, rnd_shift=0, use='train', train_frac=0.8):
+         self.ge_dataset = ge_dataset
+         self.psz = psz
+         self.rnd_shift = rnd_shift
+         if ge_dataset is True:
+             self.peaks = FrameReaderDataset(ge_ffile, ge_dfile).get_peaks_skimage()
+             self.len = len(self.peaks)
+             print(self.len)
+             if use == 'train':
+                 sti, edi = 0, int(train_frac * self.len)
+             elif use == 'validation':
+                 sti, edi = int(train_frac * self.len), None
+             else:
+                 logging.error(f"unsupported use: {use}. This class is written for building either a training or a validation set")
+             self.crop_img = self.peaks[sti:edi]
+         else:
+             with h5py.File(pfile, "r") as h5fd:
+                 if use == 'train':
+                     sti, edi = 0, int(train_frac * h5fd['peak_fidx'].shape[0])
+                 elif use == 'validation':
+                     sti, edi = int(train_frac * h5fd['peak_fidx'].shape[0]), None
+                 else:
+                     logging.error(f"unsupported use: {use}. This class is written for building either a training or a validation set")
+
+                 mask = h5fd['npeaks'][sti:edi] == 1  # use only single-peak patches
+                 mask = mask & ((h5fd['deviations'][sti:edi] >= 0) & (h5fd['deviations'][sti:edi] < 1))
+
+                 self.peak_fidx = h5fd['peak_fidx'][sti:edi][mask]
+                 self.peak_row = h5fd['peak_row'][sti:edi][mask]
+                 self.peak_col = h5fd['peak_col'][sti:edi][mask]
+             self.fidx_base = self.peak_fidx.min()
+             with h5py.File(ffile, 'r') as h5fd:
+                 if use == 'train':
+                     sti, edi = 0, int(train_frac * h5fd['frames'].shape[0])
+                 elif use == 'validation':
+                     sti, edi = int(train_frac * h5fd['frames'].shape[0]), None
+                 else:
+                     logging.error(f"unsupported use: {use}. This class is written for building either a training or a validation set")
+                 self.crop_img = h5fd['frames'][sti:edi]
+             self.len = self.peak_fidx.shape[0]
+
+     def __getitem__(self, idx):
+         ge_dataset = self.ge_dataset
+         print(idx)
+         crop_img = self.crop_img[idx]
+
+         if ge_dataset is True:
+             return crop_img
+         else:
+             row_shift, col_shift = 0, 0
+             c_pad_l, r_pad_t = 0, 0
+             prow_rnd = int(self.peak_row[idx]) + row_shift
+             pcol_rnd = int(self.peak_col[idx]) + col_shift
+
+             row_base = max(0, prow_rnd - self.psz//2)
+             col_base = max(0, pcol_rnd - self.psz//2)
+
+             if crop_img.max() != crop_img.min():
+                 _min, _max = crop_img.min().astype(np.float32), crop_img.max().astype(np.float32)
+                 feature = (crop_img - _min) / (_max - _min)
+             else:
+                 #logging.warning("sample %d has a single unique intensity (sum %d)" % (idx, crop_img.sum()))
+                 feature = crop_img
+
+             px = (self.peak_col[idx] - col_base + c_pad_l) / self.psz
+             py = (self.peak_row[idx] - row_base + r_pad_t) / self.psz
+
+             return feature[np.newaxis], np.array([px, py]).astype(np.float32)
+
+     def __len__(self):
+         return self.len