Agusbs98's picture
Upload 37 files
bb18256
raw
history blame
No virus
4.74 kB
from libs import *
import configVars
import ecg_plot
def remove_baseline_filter(sample_rate):
fc = 0.8 # [Hz], cutoff frequency
fst = 0.2 # [Hz], rejection band
rp = 0.5 # [dB], ripple in passband
rs = 40 # [dB], attenuation in rejection band
wn = fc / (sample_rate / 2)
wst = fst / (sample_rate / 2)
filterorder, aux = sgn.ellipord(wn, wst, rp, rs)
sos = sgn.iirfilter(filterorder, wn, rp, rs, btype='high', ftype='ellip', output='sos')
return sos
reduced_leads = ['DI', 'DII', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
all_leads = ['DI', 'DII', 'DIII', 'AVR', 'AVL', 'AVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
def preprocess_ecg(ecg, sample_rate, leads, scale=1,
use_all_leads=True, remove_baseline=False):
# Remove baseline
if remove_baseline:
sos = remove_baseline_filter(sample_rate)
ecg_nobaseline = sgn.sosfiltfilt(sos, ecg, padtype='constant', axis=-1)
else:
ecg_nobaseline = ecg
# Rescale
ecg_rescaled = scale * ecg_nobaseline
# Resample
if sample_rate != 500:
ecg_resampled = sgn.resample_poly(ecg_rescaled, up=500, down=sample_rate, axis=-1)
else:
ecg_resampled = ecg_rescaled
length = len(ecg_resampled[0])
# Add leads if needed
target_leads = all_leads if use_all_leads else reduced_leads
n_leads_target = len(target_leads)
l2p = dict(zip(target_leads, range(n_leads_target)))
ecg_targetleads = np.zeros([n_leads_target, length])
ecg_targetleads = ecg_rescaled
if n_leads_target >= leads and use_all_leads:
ecg_targetleads[l2p['DIII'], :] = ecg_targetleads[l2p['DII'], :] - ecg_targetleads[l2p['DI'], :]
ecg_targetleads[l2p['AVR'], :] = -(ecg_targetleads[l2p['DI'], :] + ecg_targetleads[l2p['DII'], :]) / 2
ecg_targetleads[l2p['AVL'], :] = (ecg_targetleads[l2p['DI'], :] - ecg_targetleads[l2p['DIII'], :]) / 2
ecg_targetleads[l2p['AVF'], :] = (ecg_targetleads[l2p['DII'], :] + ecg_targetleads[l2p['DIII'], :]) / 2
return ecg_targetleads
def generateH5(input_file,out_file,new_freq=None,new_len=None,scale=1,sample_rate=None):
n = len(input_file) # Get length
try:
h5f = h5py.File(f"{configVars.pathCasos}{out_file}", 'r+')
h5f.clear()
except:
h5f = h5py.File(f"{configVars.pathCasos}{out_file}", 'w')
# Resample
if new_freq is not None:
ecg_resampled = sgn.resample_poly(input_file, up=new_freq, down=sample_rate, axis=-1)
else:
ecg_resampled = input_file
new_freq = sample_rate
n_leads, length = ecg_resampled.shape
# Rescale
ecg_rescaled = scale * ecg_resampled
# Reshape
if new_len is None or new_len == length:
ecg_reshaped = ecg_rescaled
elif new_len > length:
ecg_reshaped = np.zeros([n_leads, new_len])
pad = (new_len - length) // 2
ecg_reshaped[..., pad:length+pad] = ecg_rescaled
else:
extra = (length - new_len) // 2
ecg_reshaped = ecg_rescaled[:, extra:new_len + extra]
n_leads, n_samples = ecg_reshaped.shape
x = h5f.create_dataset('tracings', (1, n_samples, n_leads), dtype='f8')
x[0, :, :] = ecg_reshaped.T
h5f.close()
def LightX3ECG(
train_loaders,
config,
save_ckp_dir,
):
model = torch.load(f"{save_ckp_dir}/best.ptl", map_location='cpu')
#model = torch.load(f"{save_ckp_dir}/best.ptl", map_location = "cuda")
model.to(torch.device('cpu'))
with torch.no_grad():
model.eval()
running_preds = []
for ecgs in train_loaders["pred"]:
ecgs = ecgs.cpu()
logits = model(ecgs)
preds = list(torch.max(logits, 1)[1].detach().cpu().numpy()) if not config["is_multilabel"] else list(torch.sigmoid(logits).detach().cpu().numpy())
running_preds.extend(preds)
if config["is_multilabel"]:
running_preds = np.array(running_preds)
optimal_thresholds = pd.read_csv(f"{configVars.pathThresholds}CPSC-2018/optimal_thresholds_best.csv")
preds = optimal_thresholds[optimal_thresholds["Threshold"]<=running_preds[0]]
preds = preds['Pred'].values.tolist()
else:
enfermedades = ['AFIB','GSVT','SB','SR']
running_preds = np.array(running_preds)
#running_preds=np.reshape(running_preds, (len(running_preds),-1))
preds = enfermedades[running_preds[0]]
return preds
def ecgPlot(source,sample):
data = np.load(source)
#print(data)
xml_leads = ['DI', 'DII', 'DIII', 'AVR', 'AVL', 'AVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
ecg_plot.plot_12(data, sample_rate= sample,lead_index=xml_leads, title="Muestra")
ecg_plot.save_as_png("ecg")