# BNNPreprocess / ds.py
# Dennis Trujillo
# best and latest PatchDataset
# 55ab5b2
import os
import h5py
import torch
import numpy as np
from numba import jit
from skimage import measure
from torchvision import transforms
from torch.utils.data import Dataset
from skimage.measure import label, regionprops
@jit
def process_frame(frame, dark, thresh):
    """Dark-subtract a frame, zero out sub-threshold pixels and cast to int.

    Args:
        frame: Array holding the frame to correct.
        dark: Array subtracted element-wise from ``frame``.
        thresh: Pixels whose dark-subtracted value is below this are set to 0.

    Returns:
        A new integer array with the thresholded, dark-subtracted values.
    """
    corrected = frame - dark
    # Suppress everything below the threshold before the integer cast.
    below = corrected < thresh
    corrected[below] = 0
    return corrected.astype(int)
def normalize_patch(patch):
    """Min-max scale ``patch`` into the range [0, 1].

    Args:
        patch: Non-empty NumPy array of numeric values.

    Returns:
        Float array with the same shape as ``patch`` where the smallest
        input value maps to 0 and the largest to 1. A patch whose minimum
        and maximum coincide has no contrast to rescale and is returned as
        all zeros instead of dividing by zero (which would yield NaN).

    Raises:
        ValueError: If ``patch`` is empty (raised by ``min()``/``max()``).
    """
    patch = patch.astype(float)
    # The bounds are rounded to float32 exactly as before so that results for
    # non-constant patches are numerically unchanged.
    lo = patch.min().astype(np.float32)
    hi = patch.max().astype(np.float32)
    span = hi - lo
    if span == 0:
        # Flat patch: avoid 0/0 -> NaN.
        return np.zeros_like(patch)
    return (patch - lo) / span
class FrameDataset(Dataset):
    """Normalized square patches cut around bright regions of raw frames.

    Frames are read from a binary file of ``uint16`` values laid out as a
    header of ``fHead`` bytes followed by consecutive ``NrPixels x NrPixels``
    frames. Each frame is dark-subtracted and thresholded with
    ``process_frame``, connected regions are labelled, and a
    ``(2*window+1)``-pixel square patch around every accepted region centroid
    is normalized with ``normalize_patch`` and stored.

    ``self.patches[i]`` and ``self.xy_positions[i]`` always describe the same
    patch; ``xy_positions`` holds the ``[row, col]`` of the patch's top-left
    corner in its frame. Patches from all frames are accumulated in order.

    The object can be consumed in two ways:

    * ``ds[k]`` returns the k-th batch (slices of ``batch_size`` entries; the
      last batch may be short, and out-of-range ``k`` yields empty lists).
    * ``for patches, xy in ds`` yields full batches only (a trailing partial
      batch is dropped).

    NOTE: ``len(ds)`` returns ``nFrames`` as passed to the constructor, not
    the number of patches or batches.
    """

    def __init__(self, ffile, dfile, nFrames, batch_size, NrPixels=2048, thresh=100, fHead=8192, window=7):
        """Read the dark frame and ``nFrames`` data frames and extract patches.

        Args:
            ffile: Path of the binary file containing the data frames.
            dfile: Path of the binary file containing the dark frame.
            nFrames: Number of frames to read from ``ffile``.
            batch_size: Number of patches per batch.
            NrPixels: Edge length of a (square) frame in pixels.
            thresh: Threshold passed to ``process_frame``.
            fHead: Size in bytes of the file header preceding the frames.
            window: Half-width of a patch; patches are ``2*window+1`` wide.
        """
        self.NrPixels = NrPixels
        self.batch_size = batch_size
        self.length = nFrames
        self.frames = []
        self.patches = []
        self.xy_positions = []
        # Iteration cursor; (re)initialised by __iter__ as well.
        self.batch_start = 0
        self.batch_end = batch_size

        pixels_per_frame = NrPixels * NrPixels
        frame_bytes = pixels_per_frame * 2  # two bytes per uint16 pixel

        # Read dark frame. As in the data file below, the first frame-sized
        # slot after the header is skipped.
        with open(dfile, 'rb') as darkf:
            darkf.seek(fHead + frame_bytes, os.SEEK_SET)
            dark = np.fromfile(darkf, dtype=np.uint16, count=pixels_per_frame)
        self.dark = np.reshape(dark, (NrPixels, NrPixels)).astype(float)

        # Read frames
        with open(ffile, 'rb') as f:
            for fNr in range(1, nFrames + 1):
                f.seek(fHead + fNr * frame_bytes, os.SEEK_SET)
                raw = np.fromfile(f, dtype=np.uint16, count=pixels_per_frame)
                frame = np.reshape(raw, (NrPixels, NrPixels)).astype(float)
                frame = process_frame(frame, self.dark, thresh)

                # Binary copy of the frame used only for region labelling.
                binary = np.copy(frame)
                binary[binary > 0] = 1
                regions = regionprops(label(binary))

                # Patches are accumulated across frames (not reset per frame)
                # so they stay index-aligned with xy_positions.
                for props in regions:
                    if props.area < 4 or props.area > 150:
                        continue
                    y0, x0 = props.centroid
                    start_x = int(x0) - window
                    end_x = int(x0) + window + 1
                    start_y = int(y0) - window
                    end_y = int(y0) + window + 1
                    # NOTE(review): end_* are exclusive slice bounds, so this
                    # test also rejects patches that would still fit exactly
                    # against the last row/column. Kept as-is so the set of
                    # extracted patches does not change.
                    if start_x < 0 or end_x > NrPixels - 1 or start_y < 0 or end_y > NrPixels - 1:
                        continue
                    sub_img = frame[start_y:end_y, start_x:end_x]
                    self.patches.append(normalize_patch(sub_img))
                    self.xy_positions.append([start_y, start_x])
                print(np.shape(self.patches))

    def __iter__(self):
        """Restart batch iteration from the first patch."""
        self.batch_start = 0
        self.batch_end = self.batch_size
        return self

    def __next__(self):
        """Return the next full ``(patches, xy_positions)`` batch.

        Raises:
            StopIteration: When fewer than ``batch_size`` patches remain. The
                cursor is rewound first so iteration can be restarted.
        """
        if self.batch_end > len(self.patches):
            self.batch_start = 0
            self.batch_end = self.batch_size
            raise StopIteration
        current = slice(self.batch_start, self.batch_end)
        p_batch = self.patches[current]
        xy_batch = self.xy_positions[current]
        self.batch_start += self.batch_size
        self.batch_end += self.batch_size
        return p_batch, xy_batch

    def __len__(self):
        """Return the number of frames requested at construction time."""
        return self.length

    def __getitem__(self, index):
        """Return batch number ``index`` as ``(patches, xy_positions)`` lists."""
        first = index * self.batch_size
        last = (index + 1) * self.batch_size
        p_batch = self.patches[first:last]
        xy_batch = self.xy_positions[first:last]
        return p_batch, xy_batch