Dennis Trujillo commited on
Commit
55ab5b2
·
1 Parent(s): f95931d

best and latest PatchDataset

Browse files
Files changed (2) hide show
  1. ds.py +95 -0
  2. ds1.py +76 -0
ds.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import h5py
3
+ import torch
4
+ import numpy as np
5
+ from numba import jit
6
+ from skimage import measure
7
+ from torchvision import transforms
8
+ from torch.utils.data import Dataset
9
+ from skimage.measure import label, regionprops
10
+
11
+ @jit
12
+ def process_frame(frame, dark, thresh):
13
+ frame = frame - dark
14
+ frame[frame < thresh] = 0
15
+ frame = frame.astype(int)
16
+ return frame
17
+
18
+ def normalize_patch(patch):
19
+ patch = patch.astype(float)
20
+ _min,_max = patch.min().astype(np.float32), patch.max().astype(np.float32)
21
+ feature = (patch - _min) / (_max- _min)
22
+ return feature
23
+
24
+
25
+ class FrameDataset(Dataset):
26
+ def __init__(self, ffile, dfile, nFrames, batch_size, NrPixels=2048, thresh=100, fHead=8192, window=7):
27
+ self.NrPixels = NrPixels
28
+ self.batch_size = batch_size
29
+
30
+ # Read dark frame
31
+ with open(dfile, 'rb') as darkf:
32
+ darkf.seek(fHead+NrPixels*NrPixels*2, os.SEEK_SET)
33
+ self.dark = np.fromfile(darkf, dtype=np.uint16, count=(NrPixels*NrPixels))
34
+ self.dark = np.reshape(self.dark,(NrPixels,NrPixels))
35
+ self.dark = self.dark.astype(float)
36
+ darkf.close()
37
+
38
+ # Read frames
39
+ self.frames = []
40
+ self.length = nFrames
41
+ self.xy_positions = []
42
+ self.patches=[]
43
+ with open(ffile, 'rb') as f:
44
+ for fNr in range(1,nFrames+1):
45
+ BytesToSkip = fHead + fNr*NrPixels*NrPixels*2
46
+ f.seek(BytesToSkip,os.SEEK_SET)
47
+ thisFrame = np.fromfile(f,dtype=np.uint16,count=(NrPixels*NrPixels))
48
+ thisFrame = np.reshape(thisFrame,(NrPixels,NrPixels))
49
+ thisFrame = thisFrame.astype(float)
50
+ thisFrame = process_frame(thisFrame, self.dark, thresh)
51
+ thisFrame2 = np.copy(thisFrame)
52
+ thisFrame2[thisFrame2>0] = 1
53
+ labels = label(thisFrame2)
54
+ regions = regionprops(labels)
55
+ self.patches = []
56
+ for prop_nr,props in enumerate(regions):
57
+ if props.area < 4 or props.area > 150:
58
+ continue
59
+ y0,x0 = props.centroid
60
+ start_x = int(x0)-window
61
+ end_x = int(x0)+window+1
62
+ start_y = int(y0)-window
63
+ end_y = int(y0)+window+1
64
+ if start_x < 0 or end_x > NrPixels - 1 or start_y < 0 or end_y > NrPixels - 1:
65
+ continue
66
+ sub_img = thisFrame[start_y:end_y,start_x:end_x]
67
+ self.patches.append(normalize_patch(sub_img))
68
+ self.xy_positions.append([start_y,start_x])
69
+ print(np.shape(self.patches))
70
+ f.close()
71
+
72
+ def __iter__(self):
73
+ self.batch_start = 0
74
+ self.batch_end = self.batch_size
75
+ return self
76
+
77
+ def __next__(self):
78
+ if self.batch_end > self.length:
79
+ self.batch_start = 0
80
+ self.batch_end = self.batch_size
81
+ raise StopIteration
82
+ else:
83
+ p_batch = self.p_batch[self.batch_start:self.batch_end]
84
+ xy_batch = self.xy_batch[self.batch_start:self.batch_end]
85
+ self.batch_start += self.batch_size
86
+ self.batch_end += self.batch_size
87
+ return p_batch, xy_batch
88
+
89
+ def __len__(self):
90
+ return self.length
91
+
92
+ def __getitem__(self, index):
93
+ p_batch = self.patches[index*self.batch_size:(index+1)*self.batch_size]
94
+ xy_batch = self.xy_positions[index*self.batch_size:(index+1)*self.batch_size]
95
+ return p_batch, xy_batch
ds1.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import h5py
3
+ import torch
4
+ import numpy as np
5
+ from numba import jit
6
+ from skimage import measure
7
+ from torchvision import transforms
8
+ from torch.utils.data import Dataset
9
+ from skimage.measure import label, regionprops
10
+
11
+ @jit
12
+ def process_frame(frame, dark, thresh):
13
+ frame = frame - dark
14
+ frame[frame < thresh] = 0
15
+ frame = frame.astype(int)
16
+ return frame
17
+
18
+ def normalize_patch(patch):
19
+ patch = patch.astype(float)
20
+ _min,_max = patch.min().astype(np.float32), patch.max().astype(np.float32)
21
+ feature = (patch - _min) / (_max- _min)
22
+ return feature
23
+
24
+
25
+ class PatchDataset(Dataset):
26
+ def __init__(self, ffile, dfile, nFrames, NrPixels=2048, thresh=100, fHead=8192, window=7):
27
+ self.NrPixels = NrPixels
28
+
29
+ # Read dark frame
30
+ with open(dfile, 'rb') as darkf:
31
+ darkf.seek(fHead+NrPixels*NrPixels*2, os.SEEK_SET)
32
+ self.dark = np.fromfile(darkf, dtype=np.uint16, count=(NrPixels*NrPixels))
33
+ self.dark = np.reshape(self.dark,(NrPixels,NrPixels))
34
+ self.dark = self.dark.astype(float)
35
+ darkf.close()
36
+
37
+ # Read frames
38
+ self.frames = []
39
+ self.length = 0
40
+ self.xy_positions = []
41
+ self.patches=[]
42
+ with open(ffile, 'rb') as f:
43
+ for fNr in range(1,nFrames+1):
44
+ BytesToSkip = fHead + fNr*NrPixels*NrPixels*2
45
+ f.seek(BytesToSkip,os.SEEK_SET)
46
+ thisFrame = np.fromfile(f,dtype=np.uint16,count=(NrPixels*NrPixels))
47
+ thisFrame = np.reshape(thisFrame,(NrPixels,NrPixels))
48
+ thisFrame = thisFrame.astype(float)
49
+ thisFrame = process_frame(thisFrame, self.dark, thresh)
50
+ thisFrame2 = np.copy(thisFrame)
51
+ thisFrame2[thisFrame2>0] = 1
52
+ labels = label(thisFrame2)
53
+ regions = regionprops(labels)
54
+ self.patches = []
55
+ for prop_nr,props in enumerate(regions):
56
+ if props.area < 4 or props.area > 150:
57
+ continue
58
+ y0,x0 = props.centroid
59
+ start_x = int(x0)-window
60
+ end_x = int(x0)+window+1
61
+ start_y = int(y0)-window
62
+ end_y = int(y0)+window+1
63
+ if start_x < 0 or end_x > NrPixels - 1 or start_y < 0 or end_y > NrPixels - 1:
64
+ continue
65
+ sub_img = thisFrame[start_y:end_y,start_x:end_x]
66
+ self.patches.append(normalize_patch(sub_img))
67
+ self.xy_positions.append([start_y,start_x])
68
+ self.length += len(self.patches)
69
+ f.close()
70
+
71
+ def __len__(self):
72
+ return self.length
73
+
74
+ def __getitem__(self, index):
75
+ return self.patches[index], self.xy_positions[index]
76
+