BNNPreprocess / ds1.py
Dennis Trujillo
good fix, len(dataset.patches) == len(dataset)
5802303
import os
import h5py
import torch
import numpy as np
from numba import jit
from skimage import measure
from torchvision import transforms
from torch.utils.data import Dataset
from skimage.measure import label, regionprops
def process_frame(frame, dark, thresh):
    """Dark-subtract a frame, zero everything below ``thresh`` and cast to int.

    Args:
        frame: NumPy array holding the raw frame.
        dark: Array (broadcast-compatible with ``frame``) subtracted from it.
        thresh: Values of ``frame - dark`` strictly below this are set to 0.

    Returns:
        A new integer array; the ``frame`` passed in is not modified.
    """
    # Plain NumPy on purpose: the body is already vectorised, and assigning
    # through a 2-D boolean mask is not supported by Numba's nopython mode,
    # so a bare ``@jit`` either fails to compile or falls back to object mode.
    corrected = frame - dark
    corrected[corrected < thresh] = 0
    return corrected.astype(int)
def normalize_patch(patch):
    """Min-max scale ``patch`` into the range [0, 1].

    Args:
        patch: NumPy array of pixel values (any numeric dtype).

    Returns:
        A float array of the same shape. The minimum maps to 0 and the
        maximum to 1. A patch whose values are all equal has no range to
        scale by and is returned as all zeros instead of NaN.
    """
    patch = patch.astype(float)
    # The extrema are rounded to float32 before use; kept so that results
    # stay numerically identical for patches that do have a range.
    lowest = patch.min().astype(np.float32)
    highest = patch.max().astype(np.float32)
    span = highest - lowest
    if span == 0:
        # Constant patch: dividing would produce 0/0 = NaN everywhere.
        return np.zeros_like(patch)
    return (patch - lowest) / span
class PatchDataset(Dataset):
    """Dataset of normalised square patches cut around bright regions of raw frames.

    Both input files are read as a ``fHead``-byte header followed by
    consecutive ``NrPixels`` x ``NrPixels`` frames of ``uint16`` pixels.
    All patches are extracted eagerly in ``__init__`` and kept in memory.

    Attributes set by ``__init__``:
        NrPixels: Side length of a frame in pixels.
        dark: Float array read from frame slot 1 of the dark file.
        patches: Normalised patches, each ``(2*window+1)`` pixels square.
        xy_positions: ``[start_y, start_x]`` (row first) of each patch.
        f_nums: Frame number each patch was cut from.
        length: ``len(patches)``.
        frames: Empty list; not populated by this class.
    """

    def __init__(self, ffile, dfile, nFrames, NrPixels=2048, thresh=100, fHead=8192, window=7):
        """Read the dark frame and frames ``1..nFrames``, then collect patches.

        Args:
            ffile: Path of the binary file holding the data frames.
            dfile: Path of the binary file holding the dark frame.
            nFrames: Number of frames to process; slots 1 to ``nFrames``
                inclusive are read (slot 0 is skipped).
            NrPixels: Side length of each square frame.
            thresh: Threshold passed to ``process_frame``.
            fHead: Header size in bytes preceding the first frame.
            window: Half-width of a patch around the region centroid.

        Raises:
            ValueError: If a requested frame cannot be read in full.
        """
        self.NrPixels = NrPixels

        # The dark image sits in frame slot 1 of the dark file.
        with open(dfile, 'rb') as darkf:
            self.dark = self._read_frame(darkf, 1, NrPixels, fHead, dfile)

        self.frames = []
        self.length = 0
        self.xy_positions = []
        self.patches = []
        self.f_nums = []

        with open(ffile, 'rb') as f:
            for fNr in range(1, nFrames + 1):
                raw = self._read_frame(f, fNr, NrPixels, fHead, ffile)
                frame = process_frame(raw, self.dark, thresh)

                # Binary copy (positive pixels -> 1) used only for labelling;
                # patches are cut from the thresholded intensities in `frame`.
                binary = np.copy(frame)
                binary[binary > 0] = 1

                for props in regionprops(label(binary)):
                    # Only regions of 4 to 150 pixels are kept.
                    if props.area < 4 or props.area > 150:
                        continue
                    y0, x0 = props.centroid
                    start_x = int(x0) - window
                    end_x = int(x0) + window + 1
                    start_y = int(y0) - window
                    end_y = int(y0) + window + 1
                    # NOTE(review): slice ends are exclusive, so this also
                    # drops windows ending exactly on the last row/column.
                    # Left unchanged to keep the extracted patch set stable.
                    if start_x < 0 or end_x > NrPixels - 1 or start_y < 0 or end_y > NrPixels - 1:
                        continue
                    sub_img = frame[start_y:end_y, start_x:end_x]
                    self.patches.append(normalize_patch(sub_img))
                    self.xy_positions.append([start_y, start_x])
                    self.f_nums.append(fNr)

        self.length = len(self.patches)

    @staticmethod
    def _read_frame(handle, frame_index, n_pixels, header_bytes, source):
        """Read the uint16 frame in slot ``frame_index`` as a float 2-D array."""
        count = n_pixels * n_pixels
        # Two bytes per uint16 pixel.
        handle.seek(header_bytes + frame_index * count * 2, os.SEEK_SET)
        data = np.fromfile(handle, dtype=np.uint16, count=count)
        if data.size != count:
            raise ValueError(
                "{!r}: frame {} is incomplete (expected {} pixels, got {})".format(
                    source, frame_index, count, data.size
                )
            )
        return np.reshape(data, (n_pixels, n_pixels)).astype(float)

    def __len__(self):
        """Return the number of extracted patches."""
        return self.length

    def __getitem__(self, index):
        """Return the normalised patch stored at ``index``.

        The matching entries of ``xy_positions`` and ``f_nums`` share the
        same index but are not returned here.
        """
        return self.patches[index]