import os
import h5py
import torch
import numpy as np
from numba import jit
from skimage import measure
from torchvision import transforms
from torch.utils.data import Dataset
from skimage.measure import label, regionprops


@jit
def process_frame(frame, dark, thresh):
    """Dark-subtract `frame`, zero every pixel below `thresh`, and cast to int.

    Args:
        frame: 2-D float image.
        dark: dark image with the same shape as `frame`.
        thresh: pixels whose dark-subtracted value is below this are set to 0.

    Returns:
        The thresholded, dark-subtracted image as an integer array.
    """
    frame = frame - dark
    frame[frame < thresh] = 0
    frame = frame.astype(int)
    return frame


def normalize_patch(patch):
    """Min-max scale `patch` to the range [0, 1] as a float array.

    A constant patch (max == min) has no range to scale by; it is returned as
    all zeros instead of the NaNs a 0/0 division would produce.
    """
    patch = patch.astype(float)
    _min, _max = patch.min().astype(np.float32), patch.max().astype(np.float32)
    span = _max - _min
    if span == 0:
        return np.zeros_like(patch)
    feature = (patch - _min) / span
    return feature


def _read_frame(fileobj, frame_nr, nr_pixels, header_bytes):
    """Read frame number `frame_nr` from an open binary frame file.

    Assumes the layout implied by the seek arithmetic: a `header_bytes`-byte
    header followed by back-to-back frames of nr_pixels x nr_pixels uint16
    values. Returns the frame as a (nr_pixels, nr_pixels) float array.
    """
    frame_bytes = nr_pixels * nr_pixels * 2  # 2 bytes per uint16 pixel
    fileobj.seek(header_bytes + frame_nr * frame_bytes, os.SEEK_SET)
    data = np.fromfile(fileobj, dtype=np.uint16, count=nr_pixels * nr_pixels)
    return np.reshape(data, (nr_pixels, nr_pixels)).astype(float)


class FrameDataset(Dataset):
    """Fixed-size, normalized patches around spots found in a raw frame stack.

    Each data frame is dark-subtracted and thresholded. Connected regions of
    positive pixels whose area lies in [min_area, max_area] are then cut out
    as (2*window+1) x (2*window+1) patches centred on the region centroid.
    Patches that would extend past the image edge are skipped.

    `self.patches[i]` is a normalized patch. `self.xy_positions[i]` is the
    matching [row, col] of that patch's top-left corner in the frame.
    Indexing (`ds[i]`) and iteration both yield (patches, positions) batches
    of up to `batch_size` items; the final batch may be smaller.
    """

    def __init__(self, ffile, dfile, nFrames, batch_size, NrPixels=2048,
                 thresh=100, fHead=8192, window=7, min_area=4, max_area=150):
        """Load frames 1..nFrames from `ffile` and extract spot patches.

        Args:
            ffile: path of the raw data-frame file.
            dfile: path of the raw dark-frame file; its frame 1 is used.
            nFrames: number of data frames to read (frames 1..nFrames).
            batch_size: number of patches per batch.
            NrPixels: frame side length in pixels (frames are square).
            thresh: threshold passed to `process_frame`.
            fHead: header size in bytes preceding the first frame.
            window: half-width of the square patch.
            min_area, max_area: inclusive pixel-area bounds for accepted regions.
        """
        self.NrPixels = NrPixels
        self.batch_size = batch_size
        self.nFrames = nFrames

        # Frame index 0 is skipped in both files (dark uses frame 1, data
        # frames start at 1) -- presumably the first frame is unusable; confirm.
        with open(dfile, 'rb') as darkf:
            self.dark = _read_frame(darkf, 1, NrPixels, fHead)

        self.frames = []  # never populated; kept so existing attribute access still works
        self.xy_positions = []
        self.patches = []
        with open(ffile, 'rb') as f:
            for fNr in range(1, nFrames + 1):
                frame = _read_frame(f, fNr, NrPixels, fHead)
                frame = process_frame(frame, self.dark, thresh)
                self._extract_patches(frame, window, min_area, max_area)
        print(np.shape(self.patches))

        # __getitem__ is indexed by batch, so the dataset length is the batch count.
        self.length = (len(self.patches) + batch_size - 1) // batch_size
        # Initialised here so next() works even without a prior iter().
        self.batch_start = 0
        self.batch_end = batch_size

    def _extract_patches(self, frame, window, min_area, max_area):
        """Append a normalized patch and its top-left corner for each accepted region."""
        mask = np.copy(frame)
        mask[mask > 0] = 1
        for props in regionprops(label(mask)):
            if props.area < min_area or props.area > max_area:
                continue
            y0, x0 = props.centroid  # (row, col)
            start_x = int(x0) - window
            end_x = int(x0) + window + 1
            start_y = int(y0) - window
            end_y = int(y0) + window + 1
            # end_x / end_y are exclusive slice bounds, so NrPixels itself is allowed.
            if start_x < 0 or start_y < 0 or end_x > self.NrPixels or end_y > self.NrPixels:
                continue
            self.patches.append(normalize_patch(frame[start_y:end_y, start_x:end_x]))
            self.xy_positions.append([start_y, start_x])

    def __iter__(self):
        """Restart batch iteration from the first patch."""
        self.batch_start = 0
        self.batch_end = self.batch_size
        return self

    def __next__(self):
        """Return the next (patches, positions) batch; the last one may be short."""
        if self.batch_start >= len(self.patches):
            self.batch_start = 0
            self.batch_end = self.batch_size
            raise StopIteration
        p_batch = self.patches[self.batch_start:self.batch_end]
        xy_batch = self.xy_positions[self.batch_start:self.batch_end]
        self.batch_start += self.batch_size
        self.batch_end += self.batch_size
        return p_batch, xy_batch

    def __len__(self):
        """Number of batches, i.e. the valid range of `__getitem__` indices."""
        return self.length

    def __getitem__(self, index):
        """Return batch number `index` as (patches, positions) lists."""
        p_batch = self.patches[index * self.batch_size:(index + 1) * self.batch_size]
        xy_batch = self.xy_positions[index * self.batch_size:(index + 1) * self.batch_size]
        return p_batch, xy_batch