File size: 2,984 Bytes
55ab5b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6459c2a
55ab5b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6459c2a
5802303
 
55ab5b2
 
 
 
 
 
5802303
55ab5b2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import os 
import h5py 
import torch
import numpy as np
from numba import jit
from skimage import measure
from torchvision import transforms
from torch.utils.data import Dataset
from skimage.measure import label, regionprops

def process_frame(frame, dark, thresh):
    """Dark-subtract a frame, zero pixels below ``thresh``, and cast to int.

    Args:
        frame: Raw frame array (any numeric dtype).
        dark: Dark-frame array, broadcastable against ``frame``.
        thresh: Dark-corrected pixels strictly below this value are set
            to 0; pixels equal to ``thresh`` are kept.

    Returns:
        A new integer array (``astype(int)``, i.e. truncated toward zero)
        shaped like ``frame - dark``. The inputs are not modified.
    """
    # Plain NumPy is already vectorized here; numba's @jit added only compile
    # overhead and version-dependent nopython typing (2-D boolean-mask
    # assignment, astype(int)).
    # Subtract in float64 so unsigned inputs (e.g. raw uint16 frames) cannot
    # wrap around to huge positive values where dark > frame.
    corrected = np.asarray(frame, dtype=float) - np.asarray(dark, dtype=float)
    corrected[corrected < thresh] = 0
    return corrected.astype(int)

def normalize_patch(patch):
    """Min-max scale ``patch`` to the range [0, 1].

    Args:
        patch: Numeric array-like (any shape, non-empty).

    Returns:
        A new float64 array of the same shape, where the minimum maps to 0
        and the maximum to 1. A constant patch has no range to scale, so
        an all-zero array is returned instead of the NaNs that 0/0 would give.

    Raises:
        ValueError: If ``patch`` is empty (from ``ndarray.min``).
    """
    patch = np.asarray(patch, dtype=float)
    lo = patch.min()
    # Keep the extrema in float64: rounding them to float32 could push the
    # scaled values slightly outside [0, 1].
    span = patch.max() - lo
    if span == 0:
        return np.zeros_like(patch)
    return (patch - lo) / span


class PatchDataset(Dataset):
    """Dataset of normalized square patches cut around connected spots in raw frames.

    Both input files are read as a header of ``fHead`` bytes followed by
    consecutive ``NrPixels`` x ``NrPixels`` uint16 frames (native byte order).
    Frame index 1 of ``dfile`` is used as the dark frame, and frames
    1..``nFrames`` of ``ffile`` are processed. Frame 0 is skipped in both
    files — presumably deliberate; confirm with the data format.

    Each frame is processed in four steps:

    1. It is dark-subtracted and thresholded with ``process_frame``.
    2. Its non-zero pixels are grouped into connected regions.
    3. Regions are kept only if their area is 4..150 pixels (inclusive) and
       the ``(2*window + 1)``-square window around the truncated centroid
       lies fully inside the frame.
    4. The window of each kept region is normalized with ``normalize_patch``
       and stored.

    Attributes:
        NrPixels: Frame side length in pixels.
        dark: The dark frame as a float64 array.
        patches: Normalized patch arrays, in extraction order.
        xy_positions: ``[row, col]`` of each patch's top-left corner
            (stored as (y, x) despite the attribute name).
        f_nums: Frame index each patch was cut from.
        length: Number of patches.
        frames: Never populated; kept for backward compatibility.
    """

    def __init__(self, ffile, dfile, nFrames, NrPixels=2048, thresh=100, fHead=8192, window=7):
        self.NrPixels = NrPixels

        with open(dfile, 'rb') as darkf:
            self.dark = self._read_frame(darkf, 1, NrPixels, fHead)

        self.frames = []
        self.xy_positions = []
        self.patches = []
        self.f_nums = []
        with open(ffile, 'rb') as f:
            for fNr in range(1, nFrames + 1):
                frame = self._read_frame(f, fNr, NrPixels, fHead)
                frame = process_frame(frame, self.dark, thresh)
                self._extract_patches(frame, fNr, window)
        self.length = len(self.patches)

    @staticmethod
    def _read_frame(fh, index, NrPixels, fHead):
        """Read frame ``index`` from an open binary file as a float64 square array.

        Raises:
            ValueError: If the file ends before the whole frame is read.
        """
        count = NrPixels * NrPixels
        # Each pixel is a 2-byte uint16.
        fh.seek(fHead + index * count * 2, os.SEEK_SET)
        raw = np.fromfile(fh, dtype=np.uint16, count=count)
        if raw.size != count:
            # np.fromfile silently returns fewer items on a short file.
            raise ValueError(
                f"frame {index} of {getattr(fh, 'name', fh)!r} is truncated: "
                f"expected {count} pixels, got {raw.size}"
            )
        return raw.reshape(NrPixels, NrPixels).astype(float)

    def _extract_patches(self, frame, fNr, window):
        """Append a normalized patch for every spot-sized region of ``frame``."""
        size = self.NrPixels
        for props in regionprops(label(frame > 0)):
            if props.area < 4 or props.area > 150:
                continue
            y0, x0 = props.centroid
            start_x, end_x = int(x0) - window, int(x0) + window + 1
            start_y, end_y = int(y0) - window, int(y0) + window + 1
            # Slice ends are exclusive, so an end equal to ``size`` is still
            # in bounds (the old ``> size - 1`` check dropped edge patches).
            if start_x < 0 or end_x > size or start_y < 0 or end_y > size:
                continue
            sub_img = frame[start_y:end_y, start_x:end_x]
            self.patches.append(normalize_patch(sub_img))
            self.xy_positions.append([start_y, start_x])
            self.f_nums.append(fNr)

    def __len__(self):
        """Return the number of extracted patches."""
        return self.length

    def __getitem__(self, index):
        """Return the normalized patch at ``index``.

        Its position and frame number are available in
        ``xy_positions[index]`` and ``f_nums[index]``.
        """
        return self.patches[index]