File size: 4,737 Bytes
bb18256
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from libs import *
import configVars
import ecg_plot
def remove_baseline_filter(sample_rate):
    """Design an elliptic high-pass filter used to remove ECG baseline wander.

    Specification: passband edge 0.8 Hz, stopband edge 0.2 Hz, at most
    0.5 dB of passband ripple and at least 40 dB of stopband attenuation.
    The minimum order meeting that spec is chosen with ``ellipord``.

    Args:
        sample_rate: sampling frequency of the signal, in Hz.

    Returns:
        Filter coefficients as second-order sections (``output='sos'``),
        as consumed by ``sgn.sosfiltfilt`` in ``preprocess_ecg``.
    """
    nyquist = sample_rate / 2
    # Edges are normalized to the Nyquist frequency, as scipy expects.
    passband_edge = 0.8 / nyquist
    stopband_edge = 0.2 / nyquist
    max_ripple_db = 0.5
    min_attenuation_db = 40

    order, _ = sgn.ellipord(passband_edge, stopband_edge,
                            max_ripple_db, min_attenuation_db)
    return sgn.iirfilter(order, passband_edge, max_ripple_db, min_attenuation_db,
                         btype='high', ftype='ellip', output='sos')

# Lead names, in the row order used for ECG arrays in this module.
# The reduced set omits DIII, AVR, AVL and AVF, which preprocess_ecg
# computes from DI and DII.
reduced_leads = ['DI', 'DII',
                 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
all_leads = ['DI', 'DII', 'DIII', 'AVR', 'AVL', 'AVF',
             'V1', 'V2', 'V3', 'V4', 'V5', 'V6']

def preprocess_ecg(ecg, sample_rate, leads, scale=1,
                   use_all_leads=True, remove_baseline=False):
    """Filter, rescale and resample an ECG to 500 Hz and arrange its leads.

    Args:
        ecg: array-like of shape (n_input_leads, n_samples), one row per lead.
            When the row count matches ``all_leads`` (12) or ``reduced_leads``
            (8) the rows are interpreted in that order and moved to their
            named slot in the target layout; otherwise they are kept as given.
        sample_rate: sampling frequency of ``ecg`` in Hz. When it differs from
            500 it is passed as ``down`` to ``sgn.resample_poly``, so it must
            be an integer.
        leads: number of leads in the input. DIII, AVR, AVL and AVF are
            (re)computed from DI and DII only when ``use_all_leads`` is True
            and ``leads`` does not exceed the number of target leads.
        scale: multiplicative factor applied to the signal.
        use_all_leads: if True the output follows ``all_leads``; otherwise it
            follows ``reduced_leads``.
        remove_baseline: if True, first apply the high-pass filter from
            ``remove_baseline_filter`` with zero-phase ``sosfiltfilt``.

    Returns:
        Floating-point ndarray with one row per target lead (or per input
        lead when the input layout is not recognized), sampled at 500 Hz.
    """
    ecg = np.asarray(ecg)

    # Remove baseline
    if remove_baseline:
        sos = remove_baseline_filter(sample_rate)
        ecg_nobaseline = sgn.sosfiltfilt(sos, ecg, padtype='constant', axis=-1)
    else:
        ecg_nobaseline = ecg

    # Rescale. The multiplication allocates a new array, so the in-place
    # writes below never modify the caller's data.
    ecg_rescaled = scale * ecg_nobaseline

    # Resample to 500 Hz
    if sample_rate != 500:
        ecg_resampled = sgn.resample_poly(ecg_rescaled, up=500, down=sample_rate, axis=-1)
    else:
        ecg_resampled = ecg_rescaled
    n_input_leads, length = ecg_resampled.shape

    # Arrange the rows in the target lead order.
    target_leads = all_leads if use_all_leads else reduced_leads
    n_leads_target = len(target_leads)
    l2p = dict(zip(target_leads, range(n_leads_target)))
    # Floating output so the derived leads (which divide by 2) are not
    # truncated when the input is integer-valued.
    out_dtype = np.result_type(ecg_resampled.dtype, np.float32)

    if n_input_leads == len(all_leads):
        source_leads = all_leads
    elif n_input_leads == len(reduced_leads):
        source_leads = reduced_leads
    else:
        source_leads = None

    if source_leads is None or source_leads == target_leads:
        # Same (or unrecognized) layout: use the rows as they are.
        ecg_targetleads = ecg_resampled.astype(out_dtype, copy=False)
    else:
        # Different known layout: copy each lead into its named slot. Leads
        # missing from the input stay zero until derived below.
        ecg_targetleads = np.zeros([n_leads_target, length], dtype=out_dtype)
        for row, name in enumerate(source_leads):
            if name in l2p:
                ecg_targetleads[l2p[name], :] = ecg_resampled[row, :]

    # Derive the remaining limb leads from DI and DII. DIII is computed first
    # because AVL and AVF use it.
    if n_leads_target >= leads and use_all_leads:
        ecg_targetleads[l2p['DIII'], :] = ecg_targetleads[l2p['DII'], :] - ecg_targetleads[l2p['DI'], :]
        ecg_targetleads[l2p['AVR'], :] = -(ecg_targetleads[l2p['DI'], :] + ecg_targetleads[l2p['DII'], :]) / 2
        ecg_targetleads[l2p['AVL'], :] = (ecg_targetleads[l2p['DI'], :] - ecg_targetleads[l2p['DIII'], :]) / 2
        ecg_targetleads[l2p['AVF'], :] = (ecg_targetleads[l2p['DII'], :] + ecg_targetleads[l2p['DIII'], :]) / 2

    return ecg_targetleads


def generateH5(input_file, out_file, new_freq=None, new_len=None, scale=1, sample_rate=None):
    """Resample, rescale and center-crop/pad an ECG, then store it as HDF5.

    The file ``f"{configVars.pathCasos}{out_file}"`` is created, or truncated
    if it already exists. It receives a single float64 dataset
    ``'tracings'`` of shape ``(1, n_samples, n_leads)``: the processed
    signal transposed so that time is the middle axis.

    Args:
        input_file: signal array of shape (n_leads, n_samples). Despite the
            name this is the data itself, not a path.
        out_file: file name appended to ``configVars.pathCasos``.
        new_freq: target sampling rate. When given, the signal is resampled
            from ``sample_rate`` with ``sgn.resample_poly``.
        new_len: target number of samples. Shorter signals are zero-padded
            and longer ones cropped, both symmetrically around the center.
        scale: multiplicative factor applied after resampling.
        sample_rate: original sampling rate; required when ``new_freq`` is set.
    """
    ecg = np.asarray(input_file)

    # Resample
    if new_freq is not None:
        ecg_resampled = sgn.resample_poly(ecg, up=new_freq, down=sample_rate, axis=-1)
    else:
        ecg_resampled = ecg
    n_leads, length = ecg_resampled.shape

    # Rescale
    ecg_rescaled = scale * ecg_resampled

    # Center the signal inside a window of new_len samples.
    if new_len is None or new_len == length:
        ecg_reshaped = ecg_rescaled
    elif new_len > length:
        ecg_reshaped = np.zeros([n_leads, new_len])
        pad = (new_len - length) // 2
        ecg_reshaped[:, pad:pad + length] = ecg_rescaled
    else:
        extra = (length - new_len) // 2
        ecg_reshaped = ecg_rescaled[:, extra:extra + new_len]

    n_leads, n_samples = ecg_reshaped.shape
    out_path = f"{configVars.pathCasos}{out_file}"
    # The file is opened only after all processing has succeeded, so an error
    # above does not leave an emptied file behind. Mode 'w' creates the file
    # or truncates an existing one. The context manager closes the handle even
    # if writing fails.
    with h5py.File(out_path, 'w') as h5f:
        tracings = h5f.create_dataset('tracings', (1, n_samples, n_leads), dtype='f8')
        tracings[0, :, :] = ecg_reshaped.T
    
def LightX3ECG(
    train_loaders,
    config,
    save_ckp_dir,
):
    """Load ``best.ptl`` from ``save_ckp_dir`` and classify the ``"pred"`` batches.

    Inference runs on the CPU under ``torch.no_grad``. All batches are
    collected, but only the FIRST sample's prediction is decoded and returned.

    Args:
        train_loaders: mapping whose ``"pred"`` entry yields ECG tensors
            (no other entry is read).
        config: mapping with an ``"is_multilabel"`` flag.
        save_ckp_dir: directory that contains ``best.ptl``.

    Returns:
        Multilabel: list of ``'Pred'`` values from the CPSC-2018
        optimal-thresholds CSV whose ``'Threshold'`` is <= the first sample's
        sigmoid scores. The comparison is row-by-row; this presumably assumes
        one CSV row per output class, in model output order (TODO confirm).
        Otherwise: the entry of ['AFIB', 'GSVT', 'SB', 'SR'] at the first
        sample's arg-max class index.
    """
    class_names = ['AFIB', 'GSVT', 'SB', 'SR']

    net = torch.load(f"{save_ckp_dir}/best.ptl", map_location='cpu')
    net.to(torch.device('cpu'))
    multilabel = config["is_multilabel"]

    collected = []
    with torch.no_grad():
        net.eval()
        for batch in train_loaders["pred"]:
            scores = net(batch.cpu())
            if multilabel:
                batch_out = torch.sigmoid(scores).detach().cpu().numpy()
            else:
                batch_out = torch.max(scores, 1)[1].detach().cpu().numpy()
            collected.extend(batch_out)

    stacked = np.array(collected)
    if not multilabel:
        return class_names[stacked[0]]

    thresholds = pd.read_csv(f"{configVars.pathThresholds}CPSC-2018/optimal_thresholds_best.csv")
    selected = thresholds[thresholds["Threshold"] <= stacked[0]]
    return selected['Pred'].values.tolist()

def ecgPlot(source, sample):
    """Plot a 12-lead ECG loaded with ``np.load`` and save it through ecg_plot.

    Args:
        source: path or file-like object accepted by ``np.load``. Presumably
            it holds an array with one row per lead in the order listed
            below (TODO confirm).
        sample: sampling rate forwarded to ``ecg_plot.plot_12`` as
            ``sample_rate``.

    The figure is titled "Muestra" and handed to ``ecg_plot.save_as_png``
    under the name "ecg". The exact output path is decided by ecg_plot.
    """
    lead_names = ['DI', 'DII', 'DIII', 'AVR', 'AVL', 'AVF',
                  'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
    signals = np.load(source)
    ecg_plot.plot_12(signals, sample_rate=sample, lead_index=lead_names, title="Muestra")
    ecg_plot.save_as_png("ecg")