import os import h5py import torch import numpy as np from skimage import measure from torchvision import transforms from torch.utils.data import Dataset from skimage.measure import label, regionprops class FrameDataset(Dataset): def __init__(self, ffile, dfile, NrPixels=2048, nFrames=1440, batch_size=100, thresh=100, fHead=8192): self.NrPixels = NrPixels self.batch_size = batch_size # Read dark frame with open(dfile, 'rb') as darkf: darkf.seek(fHead+NrPixels*NrPixels*2, os.SEEK_SET) self.dark = np.fromfile(darkf, dtype=np.uint16, count=(NrPixels*NrPixels)) self.dark = np.reshape(self.dark,(NrPixels,NrPixels)) self.dark = self.dark.astype(float) # Read frames self.frames = [] self.length = nFrames with open(ffile, 'rb') as f: for _ in range(1, nFrames+1): # Skip first frame BytesToSkip = fHead + fNr*NrPixels*NrPixels*2 f.seek(BytesToSkip, os.SEEK_SET) this_frame = np.fromfile(f, dtype=np.uint16, count=(NrPixels*NrPixels)) this_frame = np.reshape(this_frame, (NrPixels, NrPixels)) this_frame = this_frame.astype(float) this_frame = this_frame - self.dark this_frame[this_frame < thresh] = 0 thisFrame = thisFrame.astype(int) self.frames.append(this_frame) def __iter__(self): self.batch_start = 0 self.batch_end = self.batch_size return self def __next__(self): if self.batch_end > self.length: self.batch_start = 0 self.batch_end = self.batch_size raise StopIteration else: f_batch = self.f_data[self.batch_start:self.batch_end] d_batch = self.d_data[self.batch_start:self.batch_end] self.batch_start += self.batch_size self.batch_end += self.batch_size return f_batch def __len__(self): return self.length def __getitem__(self, index): f_batch = self.frames[index*self.batch_size:(index+1)*self.batch_size] return f_batch def get_peaks_skimage(self, frames): regions = [] for frame in frames: frame_array = np.frombuffer(frame, dtype=np.uint16).reshape(self.NrPixels, self.NrPixels) labels = measure.label(frame_array) regions = regionprops(labels) for prop_nr,props in enumerate(regions): if props.area < 4 or props.area > 150: continue y0,x0 = props.centroid start_x = int(x0)-window end_x = int(x0)+window+1 start_y = int(y0)-window end_y = int(y0)+window+1 if start_x < 0 or end_x > NrPixels - 1 or start_y < 0 or end_y > NrPixels - 1: continue sub_img = np.copy(thisFrame) sub_img[labels != prop_nr+1] = 0 sub_img = sub_img[start_y:end_y,start_x:end_x] patches.append(sub_img) xy_positions.append([start_y,start_x]) patches = np.array(patches) xy_positions = np.array(xy_positions) return patches def normalize_patches(self, patches): normalized_patches = [] for patch in patches: patch = patch.astype(float) patch /= patch.max() patch *= 255 patch = patch.astype(int) normalized_patches.append(patch) return normalized_patches