Spaces:
Runtime error
Runtime error
from libs import * | |
import configVars | |
import ecg_plot | |
def remove_baseline_filter(sample_rate):
    """Design the high-pass elliptic filter used to strip baseline wander.

    Specification: passband edge 0.8 Hz with at most 0.5 dB ripple, stopband
    edge 0.2 Hz with at least 40 dB attenuation. The order is the one
    ``sgn.ellipord`` reports for that specification.

    Args:
        sample_rate: sampling frequency of the signal to be filtered, in Hz.

    Returns:
        The filter as second-order sections (``output='sos'``), suitable for
        ``sgn.sosfiltfilt``.
    """
    nyquist = sample_rate / 2
    passband_edge = 0.8 / nyquist   # cutoff, normalised to Nyquist
    stopband_edge = 0.2 / nyquist   # rejection-band edge, normalised to Nyquist
    max_ripple_db, min_atten_db = 0.5, 40
    order, _natural_freq = sgn.ellipord(passband_edge, stopband_edge,
                                        max_ripple_db, min_atten_db)
    return sgn.iirfilter(order, passband_edge, max_ripple_db, min_atten_db,
                         btype='high', ftype='ellip', output='sos')
# Lead layouts understood by preprocess_ecg. A lead's position in the list is
# the row index it occupies in the array preprocess_ecg returns.
reduced_leads = ['DI', 'DII', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
all_leads = ['DI', 'DII', 'DIII', 'AVR', 'AVL', 'AVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']


def preprocess_ecg(ecg, sample_rate, leads, scale=1,
                   use_all_leads=True, remove_baseline=False):
    """Optionally filter, then rescale and resample an ECG, and derive limb leads.

    Args:
        ecg: 2-D array-like of shape ``(n_rows, n_samples)``. Rows are taken
            to already be in the order of ``all_leads`` (or ``reduced_leads``
            when ``use_all_leads`` is False). NOTE(review): confirm with callers.
        sample_rate: sampling rate of ``ecg``. Anything other than 500 is
            resampled to 500 with ``sgn.resample_poly``.
        leads: number of leads in the input. DIII, AVR, AVL and AVF are
            (re)computed from DI and DII only when ``use_all_leads`` is True
            and ``leads`` does not exceed the number of target leads.
        scale: multiplicative factor applied before resampling.
        use_all_leads: target the 12-lead ``all_leads`` layout if True,
            otherwise the 8-lead ``reduced_leads`` layout.
        remove_baseline: if True, first apply ``remove_baseline_filter``
            forwards and backwards with ``sgn.sosfiltfilt``.

    Returns:
        A new floating-point array with ``max(len(target_leads), n_rows)`` rows
        and the resampled number of samples. Target rows not present in the
        input are zero. Integer inputs become float64; float inputs keep at
        least float32 precision.
    """
    ecg = np.asarray(ecg)
    # Remove baseline
    if remove_baseline:
        sos = remove_baseline_filter(sample_rate)
        ecg_nobaseline = sgn.sosfiltfilt(sos, ecg, padtype='constant', axis=-1)
    else:
        ecg_nobaseline = ecg
    # Rescale
    ecg_rescaled = scale * ecg_nobaseline
    # Resample to 500 Hz
    if sample_rate != 500:
        ecg_resampled = sgn.resample_poly(ecg_rescaled, up=500, down=sample_rate, axis=-1)
    else:
        ecg_resampled = ecg_rescaled
    n_rows, length = ecg_resampled.shape

    # Add leads if needed. Copy the *resampled* signal into a fresh float buffer
    # with at least one row per target lead. This keeps the resampling result,
    # stops derived leads (e.g. -0.5) from being truncated by an integer input
    # dtype, and leaves absent leads at zero instead of raising IndexError.
    target_leads = all_leads if use_all_leads else reduced_leads
    n_leads_target = len(target_leads)
    l2p = {lead: pos for pos, lead in enumerate(target_leads)}
    out_dtype = np.result_type(ecg_resampled.dtype, np.float32)
    ecg_targetleads = np.zeros([max(n_leads_target, n_rows), length], dtype=out_dtype)
    ecg_targetleads[:n_rows, :] = ecg_resampled

    # Derive the remaining limb leads from DI and DII (Einthoven / Goldberger
    # relations). Order matters: AVL and AVF use the freshly computed DIII.
    if n_leads_target >= leads and use_all_leads:
        ecg_targetleads[l2p['DIII'], :] = ecg_targetleads[l2p['DII'], :] - ecg_targetleads[l2p['DI'], :]
        ecg_targetleads[l2p['AVR'], :] = -(ecg_targetleads[l2p['DI'], :] + ecg_targetleads[l2p['DII'], :]) / 2
        ecg_targetleads[l2p['AVL'], :] = (ecg_targetleads[l2p['DI'], :] - ecg_targetleads[l2p['DIII'], :]) / 2
        ecg_targetleads[l2p['AVF'], :] = (ecg_targetleads[l2p['DII'], :] + ecg_targetleads[l2p['DIII'], :]) / 2
    return ecg_targetleads
def _fit_length(signal, new_len):
    """Return 2-D *signal* padded or cropped to ``new_len`` samples, centred.

    Shorter signals are zero-padded with ``(new_len - length) // 2`` samples in
    front and the remainder at the end. Longer signals are cropped
    symmetrically the same way. ``None`` or an equal length returns *signal*
    unchanged.
    """
    n_leads, length = signal.shape
    if new_len is None or new_len == length:
        return signal
    if new_len > length:
        padded = np.zeros([n_leads, new_len])
        start = (new_len - length) // 2
        padded[:, start:start + length] = signal
        return padded
    start = (length - new_len) // 2
    return signal[:, start:start + new_len]


def generateH5(input_file, out_file, new_freq=None, new_len=None, scale=1, sample_rate=None):
    """Resample, rescale and centre-fit an ECG, then write it to an HDF5 file.

    The file ``f"{configVars.pathCasos}{out_file}"`` is created, or truncated
    if it already exists. It ends up holding a single dataset ``'tracings'``
    of dtype ``f8`` and shape ``(1, n_samples, n_leads)``, i.e. the signal
    transposed to time-major order.

    Args:
        input_file: 2-D array-like of shape ``(n_leads, length)`` with the
            signal itself (despite the name, not a path).
        out_file: file name appended to ``configVars.pathCasos``.
        new_freq: target sampling rate. When given, ``sample_rate`` must also
            be given, since both are passed to ``sgn.resample_poly``.
        new_len: target number of samples (see ``_fit_length``). ``None``
            keeps the current length.
        scale: multiplicative factor applied after resampling.
        sample_rate: sampling rate of ``input_file``.
    """
    # All processing happens before the file is touched, so a failure here
    # neither truncates an existing file nor leaks an open handle.
    signal = np.asarray(input_file)
    # Resample
    if new_freq is not None:
        signal = sgn.resample_poly(signal, up=new_freq, down=sample_rate, axis=-1)
    # Rescale
    signal = scale * signal
    # Pad / crop to the requested length
    signal = _fit_length(signal, new_len)
    n_leads, n_samples = signal.shape

    # Mode 'w' creates the file or truncates an existing one. This replaces the
    # former open('r+') + clear() with a bare except fallback, which leaked the
    # handle when clear() failed and let the file grow on every call (HDF5
    # does not reclaim space freed by deleting objects).
    with h5py.File(f"{configVars.pathCasos}{out_file}", 'w') as h5f:
        tracings = h5f.create_dataset('tracings', (1, n_samples, n_leads), dtype='f8')
        tracings[0, :, :] = signal.T
def LightX3ECG(
    train_loaders,
    config,
    save_ckp_dir,
):
    """Run the saved checkpoint over ``train_loaders["pred"]`` and decode sample 0.

    Args:
        train_loaders: mapping whose ``"pred"`` entry yields input tensors.
        config: mapping with an ``"is_multilabel"`` flag.
        save_ckp_dir: directory containing the ``best.ptl`` checkpoint.

    Returns:
        When multilabel: the ``Pred`` values of the rows of the CPSC-2018
        threshold CSV whose ``Threshold`` is <= the first sample's sigmoid
        output. The comparison is row by row, which presumably assumes the CSV
        rows follow the model's output order (TODO confirm).
        Otherwise: the entry of ``['AFIB', 'GSVT', 'SB', 'SR']`` at the first
        sample's highest-logit index.

    Only the first prediction is decoded. With an empty loader, indexing that
    prediction raises IndexError.
    """
    # NOTE(review): torch.load unpickles the whole model object, so the
    # checkpoint must come from a trusted location.
    model = torch.load(f"{save_ckp_dir}/best.ptl", map_location='cpu')
    model.to(torch.device('cpu'))
    collected = []
    with torch.no_grad():
        model.eval()
        for batch in train_loaders["pred"]:
            logits = model(batch.cpu())
            if config["is_multilabel"]:
                outputs = torch.sigmoid(logits)
            else:
                # Index of the largest logit along dim 1 = predicted class.
                outputs = torch.max(logits, 1)[1]
            collected.extend(list(outputs.detach().cpu().numpy()))
    stacked = np.array(collected)
    if config["is_multilabel"]:
        thresholds = pd.read_csv(f"{configVars.pathThresholds}CPSC-2018/optimal_thresholds_best.csv")
        passing = thresholds[thresholds["Threshold"] <= stacked[0]]
        return passing['Pred'].values.tolist()
    class_names = ['AFIB', 'GSVT', 'SB', 'SR']
    return class_names[stacked[0]]
def ecgPlot(source, sample):
    """Plot a 12-lead ECG loaded with ``np.load`` and save it via ``ecg_plot``.

    Args:
        source: anything ``np.load`` accepts (e.g. a ``.npy`` path).
        sample: sampling rate forwarded to ``ecg_plot.plot_12``.

    The figure is titled "Muestra", labelled with the 12 lead names, and
    written with ``ecg_plot.save_as_png("ecg")``.
    """
    lead_names = ['DI', 'DII', 'DIII', 'AVR', 'AVL', 'AVF',
                  'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
    signals = np.load(source)
    ecg_plot.plot_12(signals, sample_rate=sample, lead_index=lead_names, title="Muestra")
    ecg_plot.save_as_png("ecg")