import os
import numpy as np
from numba import jit
from torch.utils.data import Dataset
from skimage.measure import label, regionprops

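@jit
def process_frame(frame, dark, thresh):
    # Subtract the dark frame, zero every pixel below the threshold,
    # and return the result as integer counts.
    frame = frame - dark
    frame[frame < thresh] = 0
    frame = frame.astype(int)
    return frame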

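def normalize_patch(patch):
    """Min-max scale a patch to [0, 1]; a constant patch maps to all zeros."""
    patch = patch.astype(float)
    _min, _max = patch.min(), patch.max()
    if _max == _min:
        # A flat patch has zero range; avoid dividing by zero.
        return np.zeros_like(patch)
    return (patch - _min) / (_max - _min)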


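class FrameDataset(Dataset):
    """Normalized square patches cut around thresholded regions of raw frames.

    ffile and dfile are raw binary files: an fHead-byte header followed by
    NrPixels x NrPixels uint16 frames. The first frame after the header is
    skipped in both files. Each patch is (2*window+1) pixels on a side.
    """

    def __init__(self, ffile, dfile, nFrames, batch_size, NrPixels=2048, thresh=100, fHead=8192, window=7):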
        self.NrPixels = NrPixels
        self.batch_size = batch_size
        
        # Read dark frame
        with open(dfile, 'rb') as darkf:
            darkf.seek(fHead+NrPixels*NrPixels*2, os.SEEK_SET)
            self.dark = np.fromfile(darkf, dtype=np.uint16, count=(NrPixels*NrPixels))
            self.dark = np.reshape(self.dark,(NrPixels,NrPixels))
            self.dark = self.dark.astype(float)
        
        # Read frames
        self.frames = []
        self.length = nFrames
        self.xy_positions = []
        self.patches=[]
        with open(ffile, 'rb') as f:
            for fNr in range(1,nFrames+1):
                BytesToSkip = fHead + fNr*NrPixels*NrPixels*2
                f.seek(BytesToSkip,os.SEEK_SET)
                thisFrame = np.fromfile(f,dtype=np.uint16,count=(NrPixels*NrPixels))
                thisFrame = np.reshape(thisFrame,(NrPixels,NrPixels))
                thisFrame = thisFrame.astype(float)
                thisFrame = process_frame(thisFrame, self.dark, thresh)
                thisFrame2 = np.copy(thisFrame)
                thisFrame2[thisFrame2>0] = 1
                labels = label(thisFrame2)
                regions = regionprops(labels)
                # self.patches is not reset here: patches accumulate across
                # frames so they stay aligned with self.xy_positions.
                for prop_nr,props in enumerate(regions):
                    if props.area < 4 or props.area > 150:
                        continue
                    y0,x0   = props.centroid
                    start_x = int(x0)-window
                    end_x   = int(x0)+window+1
                    start_y = int(y0)-window
                    end_y   = int(y0)+window+1
                    if start_x < 0 or end_x > NrPixels - 1 or start_y < 0 or end_y > NrPixels - 1:
                        continue
                    sub_img = thisFrame[start_y:end_y,start_x:end_x]
                    self.patches.append(normalize_patch(sub_img))
                    self.xy_positions.append([start_y,start_x])
                print(np.shape(self.patches))

    def __iter__(self):
        self.batch_start = 0
        self.batch_end = self.batch_size
        return self
    
    def __next__(self):
        # Stop once a full batch is no longer available; a trailing partial
        # batch is dropped, and the counters are reset for the next pass.
        if self.batch_end > len(self.patches):
            self.batch_start = 0
            self.batch_end = self.batch_size
            raise StopIteration
        p_batch  = self.patches[self.batch_start:self.batch_end]
        xy_batch = self.xy_positions[self.batch_start:self.batch_end]
        self.batch_start += self.batch_size
        self.batch_end += self.batch_size
        return p_batch, xy_batch
    
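    def __len__(self):
        # Number of batches __getitem__ can serve (the last one may be partial).
        return (len(self.patches) + self.batch_size - 1) // self.batch_size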
       
    def __getitem__(self, index):
        p_batch  = self.patches[index*self.batch_size:(index+1)*self.batch_size]
        xy_batch = self.xy_positions[index*self.batch_size:(index+1)*self.batch_size]
        return p_batch, xy_batch