Spaces:
Sleeping
Sleeping
import os | |
import h5py | |
import torch | |
import numpy as np | |
from skimage import measure | |
from torchvision import transforms | |
from torch.utils.data import Dataset | |
from skimage.measure import label, regionprops | |
class FrameDataset(Dataset):
    """Dark-subtracted, thresholded detector frames served in fixed-size batches.

    All frames are read eagerly in ``__init__`` from raw binary files of
    uint16 pixels and kept in ``self.frames`` as ``NrPixels x NrPixels``
    integer arrays.

    Args:
        ffile: Path to the binary file holding the frames.
        dfile: Path to the binary file holding the dark frame(s).
        NrPixels: Width and height of a (square) frame, in pixels.
        nFrames: Number of frames to load; file frames 1..nFrames are read
            (frame 0 is skipped).
        batch_size: Number of frames per batch for iteration and indexing.
        thresh: Dark-subtracted pixel values below this are set to 0.
        fHead: Size in bytes of the file header preceding the first frame.

    Raises:
        ValueError: If ``batch_size`` < 1 or a file is too short to hold the
            requested frames.
    """

    def __init__(self, ffile, dfile, NrPixels=2048, nFrames=1440, batch_size=100, thresh=100, fHead=8192):
        if batch_size < 1:
            # A zero/negative step would make __next__ loop forever.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.NrPixels = NrPixels
        self.batch_size = batch_size
        # Initialise the iterator cursor so __next__ works even without __iter__.
        self.batch_start = 0
        self.batch_end = batch_size
        frame_bytes = NrPixels * NrPixels * 2  # uint16 -> 2 bytes per pixel

        # Dark frame: the second frame (index 1) of the dark file.
        with open(dfile, 'rb') as darkf:
            self.dark = self._read_frame(darkf, fHead + frame_bytes, NrPixels)

        # Frames 1..nFrames of the frame file (frame 0 is skipped).
        self.frames = []
        self.length = nFrames
        with open(ffile, 'rb') as f:
            for fNr in range(1, nFrames + 1):
                this_frame = self._read_frame(f, fHead + fNr * frame_bytes, NrPixels)
                this_frame -= self.dark
                this_frame[this_frame < thresh] = 0
                # NOTE(review): a default-int array is typically 8 bytes/pixel,
                # i.e. ~48 GB for 1440 frames of 2048x2048 — consider a
                # narrower dtype if memory is a concern.
                self.frames.append(this_frame.astype(int))

    @staticmethod
    def _read_frame(fh, offset, NrPixels):
        """Read one ``NrPixels x NrPixels`` uint16 frame at byte ``offset`` as float."""
        count = NrPixels * NrPixels
        fh.seek(offset, os.SEEK_SET)
        raw = np.fromfile(fh, dtype=np.uint16, count=count)
        if raw.size != count:
            raise ValueError(
                f"expected {count} pixels at byte offset {offset}, got {raw.size}"
            )
        return raw.reshape(NrPixels, NrPixels).astype(float)

    def __iter__(self):
        self.batch_start = 0
        self.batch_end = self.batch_size
        return self

    def __next__(self):
        """Return the next list of up to ``batch_size`` frames.

        The final batch may be shorter than ``batch_size`` (matching
        ``__getitem__``); the cursor is reset when iteration ends.
        """
        if self.batch_start >= self.length:
            self.batch_start = 0
            self.batch_end = self.batch_size
            raise StopIteration
        f_batch = self.frames[self.batch_start:self.batch_end]
        self.batch_start += self.batch_size
        self.batch_end += self.batch_size
        return f_batch

    def __len__(self):
        # NOTE(review): this is the number of frames, while __getitem__ is
        # indexed by batch; indices >= ceil(nFrames / batch_size) return an
        # empty list. Confirm which count callers (e.g. a DataLoader) expect.
        return self.length

    def __getitem__(self, index):
        """Return batch ``index`` as a list of up to ``batch_size`` frames."""
        f_batch = self.frames[index * self.batch_size:(index + 1) * self.batch_size]
        return f_batch

    def get_peaks_skimage(self, frames, window=7, return_positions=False):
        """Cut a square patch around every mid-sized connected blob in ``frames``.

        Nonzero pixels are grouped into connected regions. Regions with an
        area between 4 and 150 pixels (inclusive) whose
        ``(2*window+1) x (2*window+1)`` window around the centroid lies fully
        inside the frame yield one patch. Pixels in the patch that do not
        belong to the region are zeroed.

        Args:
            frames: Iterable of frames, each a 2-D array (or a raw uint16
                byte buffer) with ``NrPixels * NrPixels`` values.
            window: Half-width of the extracted patch.
                NOTE(review): the original referenced an undefined global
                ``window``; 7 is a placeholder default (a disk of area 150
                has radius ~6.9) — confirm the intended value.
            return_positions: If True, also return the ``[start_y, start_x]``
                top-left corner of every patch.

        Returns:
            An array of patches, or ``(patches, positions)`` when
            ``return_positions`` is True.
        """
        n = self.NrPixels
        patches = []
        positions = []
        for frame in frames:
            if isinstance(frame, (bytes, bytearray, memoryview)):
                frame_array = np.frombuffer(frame, dtype=np.uint16)
            else:
                # Stored frames are already arrays; reinterpreting their bytes
                # as uint16 (as frombuffer would) corrupts non-uint16 data.
                frame_array = np.asarray(frame)
            frame_array = frame_array.reshape(n, n)
            # label() joins neighbours of *equal value*, so label the
            # foreground mask rather than the intensities.
            labels = measure.label(frame_array != 0)
            for props in regionprops(labels):
                if props.area < 4 or props.area > 150:
                    continue
                y0, x0 = props.centroid
                start_x = int(x0) - window
                end_x = int(x0) + window + 1
                start_y = int(y0) - window
                end_y = int(y0) + window + 1
                # Slice ends are exclusive, so end == n is still in bounds.
                if start_x < 0 or end_x > n or start_y < 0 or end_y > n:
                    continue
                # Crop first, then mask: avoids copying the whole frame per region.
                sub_img = frame_array[start_y:end_y, start_x:end_x].copy()
                sub_img[labels[start_y:end_y, start_x:end_x] != props.label] = 0
                patches.append(sub_img)
                positions.append([start_y, start_x])
        patches = np.array(patches)
        if return_positions:
            return patches, np.array(positions)
        return patches

    def normalize_patches(self, patches):
        """Scale each patch so its maximum becomes 255, returned as int arrays.

        Patches whose maximum is <= 0 (e.g. all zeros) are returned unscaled
        instead of producing NaNs from a division by zero. Inputs are not
        modified.
        """
        normalized_patches = []
        for patch in patches:
            patch = np.asarray(patch, dtype=float)
            peak = patch.max()
            if peak > 0:
                patch = patch / peak * 255
            normalized_patches.append(patch.astype(int))
        return normalized_patches