# BNNPreprocess / dataset.py
# Dennis Trujillo
# simplified FrameReaderDataset into an actual torch Dataset
# f95931d
# raw
# history blame
# 3.71 kB
import os
import h5py
import torch
import numpy as np
from skimage import measure
from torchvision import transforms
from torch.utils.data import Dataset
from skimage.measure import label, regionprops
class FrameDataset(Dataset):
    """Dark-subtracted, thresholded detector frames loaded from a raw binary file.

    Every frame is read eagerly in ``__init__`` and kept in ``self.frames`` as a
    ``(NrPixels, NrPixels)`` integer array.  The object can be consumed either
    through the iterator protocol (fixed-size batches) or through
    ``__getitem__`` (batch addressed by batch index).

    NOTE(review): ``__len__`` reports the number of *frames* while
    ``__getitem__`` addresses *batches*; indices at or beyond
    ``ceil(nFrames / batch_size)`` therefore yield an empty list.  Left as-is to
    keep the existing interface -- confirm which unit callers expect.
    """

    def __init__(self, ffile, dfile, NrPixels=2048, nFrames=1440, batch_size=100, thresh=100, fHead=8192):
        """Load the dark frame and ``nFrames`` data frames into memory.

        Args:
            ffile: Path of the raw file holding the data frames.
            dfile: Path of the raw file holding the dark frame.
            NrPixels: Edge length of the square frames, in pixels.
            nFrames: Number of frames to read from ``ffile``.
            batch_size: Number of frames per batch for iteration/indexing.
            thresh: Dark-subtracted values below this are set to 0.
            fHead: Size in bytes of the header preceding the first frame.

        Raises:
            ValueError: If either file ends before a requested frame is complete.
        """
        self.NrPixels = NrPixels
        self.batch_size = batch_size
        self.length = nFrames
        # Initialise the batch cursor here so next() also works before iter().
        self.batch_start = 0
        self.batch_end = batch_size

        # Read dark frame (frame index 1, i.e. the first frame is skipped).
        with open(dfile, 'rb') as darkf:
            self.dark = self._read_frame(darkf, 1, NrPixels, fHead)

        # Read frames 1..nFrames; frame 0 is skipped, as for the dark file.
        # NOTE(review): with the default arguments this keeps 1440 frames of
        # 2048 x 2048 platform-int values resident in memory.
        self.frames = []
        with open(ffile, 'rb') as f:
            for fNr in range(1, nFrames + 1):
                this_frame = self._read_frame(f, fNr, NrPixels, fHead) - self.dark
                this_frame[this_frame < thresh] = 0
                self.frames.append(this_frame.astype(int))

    @staticmethod
    def _read_frame(handle, frame_nr, nr_pixels, header_bytes):
        """Return frame ``frame_nr`` of an open raw uint16 file as a float array."""
        n_values = nr_pixels * nr_pixels
        # Each pixel is a 2-byte uint16, hence the factor of 2 in the offset.
        handle.seek(header_bytes + frame_nr * n_values * 2, os.SEEK_SET)
        raw = np.fromfile(handle, dtype=np.uint16, count=n_values)
        if raw.size != n_values:
            raise ValueError(
                "{}: frame {} is truncated ({} of {} values read)".format(
                    getattr(handle, 'name', '<file>'), frame_nr, raw.size, n_values))
        return raw.reshape(nr_pixels, nr_pixels).astype(float)

    def __iter__(self):
        """Reset the batch cursor and return ``self`` as the iterator."""
        self.batch_start = 0
        self.batch_end = self.batch_size
        return self

    def __next__(self):
        """Return the next list of ``batch_size`` frames.

        A trailing partial batch is not returned: iteration stops as soon as a
        full batch would run past the last frame.
        """
        if self.batch_end > self.length:
            # Rewind so the object can be iterated again.
            self.batch_start = 0
            self.batch_end = self.batch_size
            raise StopIteration
        f_batch = self.frames[self.batch_start:self.batch_end]
        self.batch_start += self.batch_size
        self.batch_end += self.batch_size
        return f_batch

    def __len__(self):
        """Return the number of frames requested at construction time."""
        return self.length

    def __getitem__(self, index):
        """Return batch number ``index`` as a list of frames.

        The last batch may be shorter than ``batch_size``; an index past the
        last batch returns an empty list.
        """
        first = index * self.batch_size
        return self.frames[first:first + self.batch_size]

    def get_peaks_skimage(self, frames, window=5):
        """Cut a square patch around every connected blob found in ``frames``.

        Args:
            frames: Iterable of frames.  Each is either a NumPy array holding
                ``NrPixels * NrPixels`` values or a bytes-like buffer of raw
                uint16 pixels.
            window: Half-width of the patch; patches are
                ``(2 * window + 1)`` pixels on a side.  The default of 5
                (11 x 11 patches) is an assumption -- TODO confirm against the
                consumer of these patches.

        Returns:
            ``np.ndarray`` stacking all accepted patches, in frame order.
            Pixels that do not belong to the blob are zeroed in each patch.
        """
        n_pix = self.NrPixels
        patches = []
        for frame in frames:
            if isinstance(frame, np.ndarray):
                # Arrays produced by __init__ are not uint16, so reinterpreting
                # their buffer would yield the wrong number of pixels.
                frame_array = frame.reshape(n_pix, n_pix)
            else:
                frame_array = np.frombuffer(frame, dtype=np.uint16).reshape(n_pix, n_pix)
            # Label the non-zero mask: on an integer image measure.label groups
            # neighbouring pixels of *equal value*, which would split one blob
            # into many regions.
            labels = measure.label(frame_array > 0)
            for props in regionprops(labels):
                if props.area < 4 or props.area > 150:
                    continue
                y0, x0 = props.centroid
                start_x = int(x0) - window
                end_x = int(x0) + window + 1
                start_y = int(y0) - window
                end_y = int(y0) + window + 1
                # Skip blobs whose window would touch or cross the frame edge.
                if start_x < 0 or end_x > n_pix - 1 or start_y < 0 or end_y > n_pix - 1:
                    continue
                # Crop first, then mask, to avoid copying the whole frame per blob.
                sub_img = frame_array[start_y:end_y, start_x:end_x].copy()
                sub_img[labels[start_y:end_y, start_x:end_x] != props.label] = 0
                patches.append(sub_img)
        return np.array(patches)

    def normalize_patches(self, patches):
        """Rescale each patch so that its maximum becomes 255.

        Args:
            patches: Iterable of NumPy arrays.

        Returns:
            List of integer arrays, one per input patch.  A patch that is empty
            or whose maximum is not positive is returned as all zeros instead of
            dividing by zero.
        """
        normalized_patches = []
        for patch in patches:
            scaled = patch.astype(float)  # astype copies, the input is untouched
            peak = scaled.max() if scaled.size else 0.0
            if peak > 0:
                scaled /= peak
                scaled *= 255
            else:
                scaled[...] = 0
            normalized_patches.append(scaled.astype(int))
        return normalized_patches