Dennis Trujillo committed
Commit f95931d · 1 Parent(s): 9f5ee14

simplified FrameReaderDataset into an actual torch Dataset

Files changed (1)
  1. dataset.py +84 -366

dataset.py CHANGED
@@ -1,376 +1,94 @@
-from torch.utils.data import Dataset
 import numpy as np
-import h5py, torch, random, logging
-from skimage.feature import peak_local_max
 from skimage import measure
-from skimage.measure import label, regionprops
-import os
-#from cc_torch import connected_components_labeling
 from torchvision import transforms
-from time import time
-import gc
-#from torch.utils.data import TensorDataset,Dataloader
-
-def normalize_patches(patches):
-    normalized_patches = []
-    for patch in patches:
-        # normalize the patch by subtracting the mean and dividing by the standard deviation
-        normalized_patch = (patch - patch.mean()) / patch.std()
-        normalized_patches.append(normalized_patch)
-    return normalized_patches
-
-
-def connected_components_torch(images, crop_size=15, NrPixels=2048):
-    window = int(crop_size/2)
-    start = torch.cuda.Event(enable_timing=True)
-    end = torch.cuda.Event(enable_timing=True)
-
-    ccs = []
-
-    start.record()
-
-    for image in images:
-        cc_out = connected_components_labeling(image).cpu().numpy()
-        ccs.append(cc_out)
-
-    end.record()
-    torch.cuda.synchronize()
-    print('cc_time: ', start.elapsed_time(end)/1000)
-
-    return ccs
-
-def region_props(ccs, images, crop_size=15, NrPixels=2048):
-    window = int(crop_size/2)
-
-    masks = []
-    centers = []
-
-    i = 0
-    start = time()
-    for cc in ccs:
-        for region_nr, region in enumerate(regionprops(cc)):
-            if region.area > 4 or region.area < 150:
-                x, y = region.centroid
-                start_x = int(x)-window
-                end_x = int(x)+window+1
-                start_y = int(y)-window
-                end_y = int(y)+window+1
-                if start_x < 0 or end_x > NrPixels - 1 or start_y < 0 or end_y > NrPixels - 1:
-                    continue
-                #sub_img = np.copy(images[i])
-                #sub_img[ccs != region_nr+1] = 0
-                #sub_img = sub_img[start_y:end_y,start_x:end_x]
-                #masks.append(sub_img)
-                centers.append((start_x,start_y))
-        i += 1
-    end = time()
-    print('get_regionprops_time: ', end-start)
-
-    #return masks,centers
-
-
-def connected_components_skimage(images, crop_size=15, NrPixels=2048):
-    window = int(crop_size/2)
-    masks = []
-    centers = []
-    ccs = []
-    start = time()
-    for image in images:
-        cc_out = measure.label(image.as_type(int))
-        ccs.append(cc_out)
-    end = time()
-    print('cc_time', end - start)
-    return ccs
-
-def clean_patch(p, center):
-    w, h = p.shape
-    cc = measure.label(p > 0)
-    if cc.max() == 1:
-        return p
-
-    # logging.warn(f"{cc.max()} peaks located in a patch")
-    lmin = np.inf
-    cc_lmin = None
-    for _c in range(1, cc.max()+1):
-        lmax = peak_local_max(p * (cc==_c), min_distance=1)
-        if lmax.shape[0] == 0: continue  # single pixel component
-        lc = lmax.mean(axis=0)
-        dist = ((lc - center)**2).sum()
-        if dist < lmin:
-            cc_lmin = _c
-            lmin = dist
-    return p * (cc == cc_lmin)
 
-class FrameReaderDataset(Dataset):
-    def __init__(self, ffile, dfile, NrPixels=2048, nFrames=1440, nrFiles=1, thresh=100, fHead=8192):
-        print("dark file:", dfile)
-        print("frames file:", ffile)
-        self.ffile = ffile
-        self.dark = np.zeros(NrPixels*NrPixels)
-        if os.path.exists(dfile):
-            darkf = open(dfile, 'rb')
-            nFramesDark = int((os.path.getsize(dfile) - 8192) / (2*NrPixels*NrPixels))
-            darkf.seek(8192, os.SEEK_SET)
-            for nr in range(nFramesDark):
-                self.dark += np.fromfile(darkf, dtype=np.uint16, count=(NrPixels*NrPixels))
-            self.dark = self.dark.astype(float)
-            self.dark /= nFramesDark
             self.dark = np.reshape(self.dark,(NrPixels,NrPixels))
-        else:
-            self.dark = np.zeros((NrPixels,NrPixels)).astype(float)
-
         self.frames = []
-        self.len = nFrames
-        for fnr in range(nrFiles):
-            startFrameNr = (nFrames//nrFiles)*fnr
-            endFrameNr = (nFrames//nrFiles)*(fnr+1)
-            f = open(ffile, 'rb')
-            f.seek(fHead, os.SEEK_SET)
-            for frameNr in range(startFrameNr, endFrameNr):
-                self.thisFrame = np.fromfile(f, dtype=np.uint16, count=(NrPixels*NrPixels))
-                self.thisFrame = np.reshape(self.thisFrame, (NrPixels,NrPixels))
-                self.thisFrame = self.thisFrame.astype(float)
-                self.thisFrame = self.thisFrame - self.dark
-                self.thisFrame[self.thisFrame < thresh] = 0
-                self.frames.append(self.thisFrame)
-
-    def get_frames(self):
-        return np.array(self.frames)
-
-    def write_frames_torch(self):
-        f_name = self.ffile.split('/')[-1]
-        torch.save(self.frames, 'frames_%s.pt' % f_name.split('.ge3')[0])
-
-    def write_frames_numpy(self):
-        f_name = self.ffile.split('/')[-1]
-        np.save('frames_%s.npy' % f_name.split('.ge3')[0], self.frames)
-
-    def get_peaks_torch(self, psz=15):
-        peaks = connected_components_torch(np.array(self.frames))
-        return peaks
-
-    def get_peaks_skimage(self, psz=15):
-        peaks = connected_components_skimage(self.frames)
-        return peaks
 
-    def write_peaks_torch(self):
-        f_name = self.ffile.split('/')[-1]
-        peaks = self.get_peaks_skimage()
-        torch.save(peaks, 'peaks_%s.pt' % f_name.split('.ge3')[0])
-
-    def write_peaks_numpy(self):
-        f_name = self.ffile.split('/')[-1]
-        peaks = self.get_peaks_skimage()
-        np.save('peaks_%s.npy' % f_name.split('.ge3')[0], peaks)
-
-class BraggNNDataset(Dataset):
-    def __init__(self, pfile=None, ffile=None, ge_dataset=False, ge_ffile=None, ge_dfile=None, psz=15, rnd_shift=0, use='train', train_frac=0.8):
-        self.psz = psz
-        self.rnd_shift = rnd_shift
-
-        with h5py.File(pfile, "r") as h5fd:
-            if use == 'train':
-                sti, edi = 0, int(train_frac * h5fd['peak_fidx'].shape[0])
-            elif use == 'validation':
-                sti, edi = int(train_frac * h5fd['peak_fidx'].shape[0]), None
-            else:
-                logging.error(f"unsupported use: {use}. This class is written for building either training or validation set")
-
-            mask = h5fd['npeaks'][sti:edi] == 1  # use only single-peak patches
-            mask = mask & ((h5fd['deviations'][sti:edi] >= 0) & (h5fd['deviations'][sti:edi] < 1))
-
-            self.peak_fidx = h5fd['peak_fidx'][sti:edi][mask]
-            self.peak_row = h5fd['peak_row'][sti:edi][mask]
-            self.peak_col = h5fd['peak_col'][sti:edi][mask]
-
-        self.fidx_base = self.peak_fidx.min()
-        # only loaded frames that will be used
-        if ge_dataset is True:
-            self.frames = FrameReaderDataset(ge_ffile, ge_dfile).get_frames()  #[self.peak_fidx.min():self.peak_fidx.max()+1]
-            self.len = self.peak_fidx.shape[0]
-            print(self.len)
-        else:
-            with h5py.File(ffile, "r") as h5fd:
-                self.frames = h5fd['frames'][self.peak_fidx.min():self.peak_fidx.max()+1]
-            self.len = self.peak_fidx.shape[0]
-
-    def __getitem__(self, idx):
-        _frame = self.frames[self.peak_fidx[idx] - self.fidx_base]
-        if self.rnd_shift > 0:
-            row_shift = np.random.randint(-self.rnd_shift, self.rnd_shift+1)
-            col_shift = np.random.randint(-self.rnd_shift, self.rnd_shift+1)
-        else:
-            row_shift, col_shift = 0, 0
-        prow_rnd = int(self.peak_row[idx]) + row_shift
-        pcol_rnd = int(self.peak_col[idx]) + col_shift
-
-        row_base = max(0, prow_rnd - self.psz//2)
-        col_base = max(0, pcol_rnd - self.psz//2)
-
-        crop_img = _frame[row_base:(prow_rnd + self.psz//2 + self.psz%2),
-                          col_base:(pcol_rnd + self.psz//2 + self.psz%2)]
-        # if((crop_img > 0).sum() == 1): continue  # ignore single non-zero peak
-        if crop_img.size != self.psz ** 2:
-            c_pad_l = (self.psz - crop_img.shape[1]) // 2
-            c_pad_r = self.psz - c_pad_l - crop_img.shape[1]
-
-            r_pad_t = (self.psz - crop_img.shape[0]) // 2
-            r_pad_b = self.psz - r_pad_t - crop_img.shape[0]
-
-            logging.warn(f"sample {idx} touched edge when crop the patch: {crop_img.shape}")
-            crop_img = np.pad(crop_img, ((r_pad_t, r_pad_b), (c_pad_l, c_pad_r)), mode='constant')
-        else:
-            c_pad_l, r_pad_t = 0, 0
-
-        _center = np.array([self.peak_row[idx] - row_base + r_pad_t, self.peak_col[idx] - col_base + c_pad_l])
-        crop_img = clean_patch(crop_img, _center)
-        if crop_img.max() != crop_img.min():
-            _min, _max = crop_img.min().astype(np.float32), crop_img.max().astype(np.float32)
-            feature = (crop_img - _min) / (_max - _min)
-        else:
-            logging.warn("sample %d has unique intensity sum of %d" % (idx, crop_img.sum()))
-            feature = crop_img
-
-        px = (self.peak_col[idx] - col_base + c_pad_l) / self.psz
-        py = (self.peak_row[idx] - row_base + r_pad_t) / self.psz
-
-        return feature[np.newaxis], np.array([px, py]).astype(np.float32)
-
-    def __len__(self):
-        return self.len
-
-
-class MidasDataset(Dataset):
-    def __init__(self, mfile, psz=15, rnd_shift=0, use='train', train_frac=0.8):
-        self.psz = psz
-        self.rnd_shift = rnd_shift
-        with h5py.File(mfile, "r") as h5fd:
-            if use == 'train':
-                sti, edi = 0, int(train_frac * len(h5fd['peakLoc']))
-            elif use == 'validation':
-                sti, edi = int(train_frac * len(h5fd['peakLoc'])), None
-            else:
-                logging.error(f"unsupported use: {use}. This class is written for building either training or validation set")
-
-            npeaks = []
-            mask = []
-            for loc in h5fd['peakLoc'][sti:edi]:
-                npeaks.append(len(loc))
-                mask.append(len(loc)==2)
-
-            #mask = npeaks[sti:edi] == 2  # use only single-peak patches
-            #mask = mask & ((h5fd['deviations'][sti:edi] >= 0) & (h5fd['deviations'][sti:edi] < 1))
-
-            self.npeaks = npeaks
-            self.peak_locs = h5fd["peakLoc"][sti:edi][mask]
-            self.peak_row = [loc[0] for loc in self.peak_locs]
-            self.peak_col = [loc[1] for loc in self.peak_locs]
-            self.deviations = np.zeros(shape=(len(self.peak_locs),))
-            self.diffY = h5fd["diffY"][sti:edi][mask]
-            self.diffZ = h5fd["diffZ"][sti:edi][mask]
-            self.peak_fidx = np.zeros(shape=(len(self.peak_locs),))
-
-            self.crop_img = h5fd['patch'][sti:edi][mask]
-            self.len = len(self.peak_locs)  # .shape[0]
-
-    def __getitem__(self, idx):
-        crop_img = self.crop_img[idx]
-
-        row_shift, col_shift = 0, 0
-        c_pad_l, r_pad_t = 0, 0
-        prow_rnd = int(self.peak_row[idx]) + row_shift
-        pcol_rnd = int(self.peak_col[idx]) + col_shift
-
-        row_base = max(0, prow_rnd - self.psz//2)
-        col_base = max(0, pcol_rnd - self.psz//2)
-
-        if crop_img.max() != crop_img.min():
-            _min, _max = crop_img.min().astype(np.float32), crop_img.max().astype(np.float32)
-            feature = (crop_img - _min) / (_max - _min)
         else:
-            #logging.warn("sample %d has unique intensity sum of %d" % (idx, crop_img.sum()))
-            feature = crop_img
-
-        px = (self.peak_col[idx] - col_base + c_pad_l) / self.psz
-        py = (self.peak_row[idx] - row_base + r_pad_t) / self.psz
-
-        return feature[np.newaxis], np.array([px, py]).astype(np.float32)
-
-    def __len__(self):
-        return self.len
-
-class PatchWiseDataset(Dataset):
-    def __init__(self, pfile=None, ffile=None, ge_dataset=False, ge_ffile=None, ge_dfile=None, psz=15, rnd_shift=0, use='train', train_frac=0.8):
-        self.ge_dataset = ge_dataset
-        self.psz = psz
-        self.rnd_shift = rnd_shift
-        if ge_dataset is True:
-            self.peaks = FrameReaderDataset(ge_ffile, ge_dfile).get_peaks_skimage()
-            self.len = len(self.peaks)
-            print(self.len)
-            if use == 'train':
-                sti, edi = 0, int(train_frac * self.len)
-            elif use == 'validation':
-                sti, edi = int(train_frac * self.len), None
-            else:
-                logging.error(f"unsupported use: {use}. This class is written for building either training or validation set")
-            self.crop_img = self.peaks[sti:edi]
-        else:
-            with h5py.File(pfile, "r") as h5fd:
-                if use == 'train':
-                    sti, edi = 0, int(train_frac * h5fd['peak_fidx'].shape[0])
-                elif use == 'validation':
-                    sti, edi = int(train_frac * h5fd['peak_fidx'].shape[0]), None
-                else:
-                    logging.error(f"unsupported use: {use}. This class is written for building either training or validation set")
-
-                mask = h5fd['npeaks'][sti:edi] == 1  # use only single-peak patches
-                mask = mask & ((h5fd['deviations'][sti:edi] >= 0) & (h5fd['deviations'][sti:edi] < 1))
-
-                self.peak_fidx = h5fd['peak_fidx'][sti:edi][mask]
-                self.peak_row = h5fd['peak_row'][sti:edi][mask]
-                self.peak_col = h5fd['peak_col'][sti:edi][mask]
-            self.fidx_base = self.peak_fidx.min()
-            with h5py.File(ffile, 'r') as h5fd:
-                if use == 'train':
-                    sti, edi = 0, int(train_frac * h5fd['frames'].shape[0])
-                elif use == 'validation':
-                    sti, edi = int(train_frac * h5fd['frames'].shape[0]), None
-                else:
-                    logging.error(f"unsupported use: {use}. This class is written for building either training or validation set")
-                self.crop_img = h5fd['frames'][sti:edi]
-            self.len = self.peak_fidx.shape[0]
 
-    def __getitem__(self, idx):
-        ge_dataset = self.ge_dataset
-        print(idx)
-        crop_img = self.crop_img[idx]
-
-        if ge_dataset is True:
-            return crop_img
-        else:
-            row_shift, col_shift = 0, 0
-            c_pad_l, r_pad_t = 0, 0
-            prow_rnd = int(self.peak_row[idx]) + row_shift
-            pcol_rnd = int(self.peak_col[idx]) + col_shift
-
-            row_base = max(0, prow_rnd - self.psz//2)
-            col_base = max(0, pcol_rnd - self.psz//2)
-
-            if crop_img.max() != crop_img.min():
-                _min, _max = crop_img.min().astype(np.float32), crop_img.max().astype(np.float32)
-                feature = (crop_img - _min) / (_max - _min)
-            else:
-                #logging.warn("sample %d has unique intensity sum of %d" % (idx, crop_img.sum()))
-                feature = crop_img
-
-            px = (self.peak_col[idx] - col_base + c_pad_l) / self.psz
-            py = (self.peak_row[idx] - row_base + r_pad_t) / self.psz
-
-            return feature[np.newaxis], np.array([px, py]).astype(np.float32)
-
     def __len__(self):
-        return self.len
-
+import os
+import h5py
+import torch
 import numpy as np
 from skimage import measure
 from torchvision import transforms
+from torch.utils.data import Dataset
+from skimage.measure import label, regionprops
+
+
+class FrameDataset(Dataset):
+    def __init__(self, ffile, dfile, NrPixels=2048, nFrames=1440, batch_size=100, thresh=100, fHead=8192):
+        self.NrPixels = NrPixels
+        self.batch_size = batch_size
+
+        # Read the dark frame: skip the file header plus the first frame,
+        # then load a single NrPixels x NrPixels uint16 frame.
+        with open(dfile, 'rb') as darkf:
+            darkf.seek(fHead + NrPixels*NrPixels*2, os.SEEK_SET)
+            self.dark = np.fromfile(darkf, dtype=np.uint16, count=(NrPixels*NrPixels))
         self.dark = np.reshape(self.dark,(NrPixels,NrPixels))
+        self.dark = self.dark.astype(float)
+
+        # Read frames: dark-correct, threshold, and store each one.
         self.frames = []
+        self.length = nFrames
+        with open(ffile, 'rb') as f:
+            for fNr in range(1, nFrames+1):  # start at 1 to skip the first frame
+                BytesToSkip = fHead + fNr*NrPixels*NrPixels*2
+                f.seek(BytesToSkip, os.SEEK_SET)
+                this_frame = np.fromfile(f, dtype=np.uint16, count=(NrPixels*NrPixels))
+                this_frame = np.reshape(this_frame, (NrPixels, NrPixels))
+                this_frame = this_frame.astype(float)
+                this_frame = this_frame - self.dark
+                this_frame[this_frame < thresh] = 0
+                this_frame = this_frame.astype(int)
+                self.frames.append(this_frame)
+
+    def __iter__(self):
+        self.batch_start = 0
+        self.batch_end = self.batch_size
+        return self
+
+    def __next__(self):
+        if self.batch_end > self.length:
+            # Reset the cursor so the dataset can be iterated again.
+            self.batch_start = 0
+            self.batch_end = self.batch_size
+            raise StopIteration
         else:
+            f_batch = self.frames[self.batch_start:self.batch_end]
+            self.batch_start += self.batch_size
+            self.batch_end += self.batch_size
+            return f_batch
 
     def __len__(self):
+        return self.length
+
+    def __getitem__(self, index):
+        # Index by batch number: each item is a batch_size-long list of frames.
+        f_batch = self.frames[index*self.batch_size:(index+1)*self.batch_size]
+        return f_batch
+
+    def get_peaks_skimage(self, frames, crop_size=15):
+        window = crop_size // 2
+        patches = []
+        xy_positions = []
+        for frame in frames:
+            # Binarize before labeling so each connected nonzero blob
+            # becomes a single component.
+            labels = measure.label(frame > 0)
+            for prop_nr, props in enumerate(regionprops(labels)):
+                if props.area < 4 or props.area > 150:
+                    continue
+                y0, x0 = props.centroid
+                start_x = int(x0) - window
+                end_x = int(x0) + window + 1
+                start_y = int(y0) - window
+                end_y = int(y0) + window + 1
+                # Skip peaks whose crop window would cross the detector edge.
+                if start_x < 0 or end_x > self.NrPixels - 1 or start_y < 0 or end_y > self.NrPixels - 1:
+                    continue
+                # Zero out all other components, then crop around the centroid.
+                sub_img = np.copy(frame)
+                sub_img[labels != prop_nr+1] = 0
+                sub_img = sub_img[start_y:end_y, start_x:end_x]
+                patches.append(sub_img)
+                xy_positions.append([start_y, start_x])
+        patches = np.array(patches)
+        xy_positions = np.array(xy_positions)
+        return patches, xy_positions
+
+    def normalize_patches(self, patches):
+        # Rescale each patch to 0-255 integer intensities.
+        normalized_patches = []
+        for patch in patches:
+            patch = patch.astype(float)
+            patch /= patch.max()
+            patch *= 255
+            patch = patch.astype(int)
+            normalized_patches.append(patch)
+        return normalized_patches
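
Since FrameDataset now implements __len__ and __getitem__, it plugs into the standard torch data pipeline. Below is a minimal usage sketch, not part of the commit itself: the scan.ge3/dark.ge3 paths are hypothetical placeholders, and it assumes the repaired get_peaks_skimage above, which returns the cropped patches together with their corner positions.

    from torch.utils.data import DataLoader

    from dataset import FrameDataset

    # Hypothetical GE-format scan/dark file pair.
    ds = FrameDataset(ffile='scan.ge3', dfile='dark.ge3',
                      NrPixels=2048, nFrames=1440, batch_size=100, thresh=100)

    print(len(ds))       # nFrames, i.e. 1440

    # __getitem__ is indexed by batch number: each item is batch_size frames.
    first_batch = ds[0]  # 100 dark-corrected, thresholded (2048, 2048) frames

    # Crop peak patches via skimage connected components, then rescale to 0-255.
    patches, xy_positions = ds.get_peaks_skimage(first_batch, crop_size=15)
    patches = ds.normalize_patches(patches)

    # Because __getitem__ already returns a batch, disable automatic batching
    # when wrapping the dataset in a DataLoader.
    loader = DataLoader(ds, batch_size=None, shuffle=False)

One caveat on the design: __len__ reports the number of frames while __getitem__ slices by batch, so only indices 0 through nFrames // batch_size - 1 yield full batches, and a DataLoader walking the whole index range will eventually see short or empty batches.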