Spaces:
Sleeping
Sleeping
import os | |
import tempfile | |
import traceback as tb | |
import gradio as gr | |
import pandas as pd | |
import numpy as np | |
import torch | |
import matplotlib.pyplot as plt | |
from model import HRNetV2Wrapper | |
# When True, images are regenerated on every request, even if previously
# rendered files already exist in the temp directory.
ALWAYS_RECREATE_IMAGE = os.getenv("ALWAYS_RECREATE_IMAGE", "False").lower() == "true"

# Columns of the public split CSVs that are shown in the UI tables.
selected_columns = ["subject_id", "no_p", "Rhythm", "Electric axis of the heart", "Etc"]


def _load_public_split(csv_path):
    """Read one public split CSV, keep the first row per subject, and project to ``selected_columns``."""
    frame = pd.read_csv(csv_path)
    frame = frame.drop_duplicates(subset=["subject_id"])
    return frame[selected_columns]


train_df = _load_public_split("./res/ludb/dataset/train_for_public.csv")
valid_df = _load_public_split("./res/ludb/dataset/valid_for_public.csv")
test_df = _load_public_split("./res/ludb/dataset/test_for_public.csv")

# One decision threshold per model output channel; gen_seg marks a sample as
# "inside the wave" when output[:, i, :] >= cutoffs[i]. Judging by how
# gen_each_image reads the masks, channel 0 is P, 1 is QRS and 2 is T.
cutoffs = [0.001163482666015625, 0.15087890625, -0.587890625]

# Display names for the 12 leads, in the order the record's leads are plotted.
lead_names = ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"]

# Built once at import time and shared by every request.
hrnetv2_wrapper = HRNetV2Wrapper()
def gen_seg(subject_id):
    """Run HRNetV2 on one LUDB record and threshold its output into 0/1 masks.

    Args:
        subject_id: Record identifier. The signal is loaded from
            ``./res/ludb/ecg_np/{subject_id}.npy``.

    Returns:
        A ``(input, seg)`` tuple.

        - ``input`` is the array loaded from disk, returned unchanged.
        - ``seg`` is an int array with ``seg[:, i, :] = output[:, i, :] >= cutoffs[i]``
          for each entry of the module-level ``cutoffs``. Axis 1 therefore holds
          one mask per cutoff.
    """
    signal = np.load(f"./res/ludb/ecg_np/{subject_id}.npy")
    # Inference only: no_grad stops autograd from recording the forward pass,
    # which otherwise costs memory and time on every request.
    with torch.no_grad():
        output: np.ndarray = (
            hrnetv2_wrapper.model(torch.from_numpy(signal)).detach().numpy()
        )
    # Broadcast one cutoff per channel over the batch and time axes. Only the
    # first len(cutoffs) channels are thresholded. The cutoffs take the output's
    # dtype, matching how NumPy compares an array with a Python float scalar.
    thresholds = np.asarray(cutoffs, dtype=output.dtype)[None, :, None]
    seg = (output[:, : len(cutoffs), :] >= thresholds).astype(int)
    return signal, seg
def concat_short_interval(seg, th):
    """Fill short gaps between segments: interior runs of 0 of at most ``th`` samples become 1.

    ``seg`` is split into maximal runs of equal values. For example,
    ``[0, 0, 1, 1, 0, 1, 1, 1, 1]`` has runs ``[0, 0]``, ``[1, 1]``, ``[0]`` and
    ``[1, 1, 1, 1]``.

    Only runs that are neither the first nor the last are eligible. A 0-run
    touching either end of the signal is not enclosed by two segments, so it is
    left unchanged. Returns a new array; ``seg`` itself is not modified.
    """
    filled = np.array(seg, copy=True)
    edges = [0, *(np.flatnonzero(np.diff(filled) != 0) + 1), len(filled)]
    runs = list(zip(edges[:-1], edges[1:]))
    for start, stop in runs[1:-1]:
        if stop - start <= th and np.all(filled[start:stop] == 0):
            filled[start:stop] = 1
    return filled
def remove_short_duration(seg, th):
    """Erase short segments: every run of 1 at most ``th`` samples long becomes 0.

    Unlike ``concat_short_interval``, runs at either end of ``seg`` are also
    eligible. Returns a new array; ``seg`` itself is not modified.
    """
    cleaned = np.array(seg, copy=True)
    edges = [0, *(np.flatnonzero(np.diff(cleaned) != 0) + 1), len(cleaned)]
    for start, stop in zip(edges[:-1], edges[1:]):
        if stop - start <= th and np.all(cleaned[start:stop] == 1):
            cleaned[start:stop] = 0
    return cleaned
def _postprocess_mask(mask, gap_th, min_len_th):
    """Fill interior 0-gaps of at most ``gap_th`` samples, then drop 1-runs of at most ``min_len_th`` samples."""
    return remove_short_duration(concat_short_interval(mask, gap_th), min_len_th)


def gen_each_image(input, seg, image_path, ths=(5, 25, 25, 25, 15, 25), pp=False):
    """Plot each lead's trace with its P/QRS/T masks overlaid and save the figure.

    Args:
        input: Per-lead signals. For lead ``i``, ``input[i][0]`` is drawn as the
            trace (presumably shape ``(n_leads, 1, length)`` — confirm).
        seg: Per-lead masks aligned with ``input``. ``seg[i][0]``, ``seg[i][1]``
            and ``seg[i][2]`` are drawn as the P, QRS and T masks.
        image_path: Destination file passed to ``Figure.savefig`` (dpi=150).
        ths: Post-processing thresholds, used only when ``pp`` is True, in the
            order ``(P gap, P min length, QRS gap, QRS min length, T gap,
            T min length)``. Gaps first get filled, then short segments get dropped.
        pp: If True, post-process each mask before plotting.
    """
    fig = plt.figure(figsize=(15, 18))
    try:
        # Use the explicit figure instead of pyplot's "current figure", which
        # another concurrent request could have replaced.
        fig.subplots_adjust(left=0.02, right=0.98, top=0.98, bottom=0.02, hspace=0.2)
        for idx, (in_by_lead, seg_by_lead) in enumerate(zip(input, seg)):
            ax = fig.add_subplot(12, 1, idx + 1)
            # Lead name, written vertically near the left edge (axes coordinates).
            ax.text(
                0.02,
                0.5,
                f"{lead_names[idx]}",
                fontsize=9,
                fontweight="bold",
                ha="center",
                va="center",
                rotation=90,
                transform=ax.transAxes,
            )
            ax.set_xticks([])
            ax.set_yticks([])
            ax.plot(in_by_lead[0], color="black", linewidth=1.0)

            p_seg, qrs_seg, t_seg = seg_by_lead[0], seg_by_lead[1], seg_by_lead[2]
            if pp:
                p_seg = _postprocess_mask(p_seg, ths[0], ths[1])
                qrs_seg = _postprocess_mask(qrs_seg, ths[2], ths[3])
                t_seg = _postprocess_mask(t_seg, ths[4], ths[5])
            # P and T masks are drawn at half the height of the QRS mask.
            ax.plot(p_seg / 2, label="P", color="red", linewidth=0.7)
            ax.plot(qrs_seg, label="QRS", color="green", linewidth=0.7)
            ax.plot(t_seg / 2, label="T", color="blue", linewidth=0.7)
        fig.savefig(image_path, dpi=150)
    finally:
        # Always release the figure. Failed renders are swallowed by the caller,
        # so figures left open here would otherwise accumulate in the server.
        plt.close(fig)
def gen_image(subject_id, image_path, pp_image_path):
    """Render the raw and post-processed segmentation images for one subject.

    Returns True once both images have been written. If any step raises, the
    traceback is printed to stdout and False is returned instead of propagating.
    """
    try:
        signal, masks = gen_seg(subject_id)
        for target_path, post_process in ((image_path, False), (pp_image_path, True)):
            gen_each_image(signal, masks, target_path, pp=post_process)
    except Exception:
        print(tb.format_exc())
        return False
    return True
# Informational text shown at the top of the App tab.
_INFO_TEXT = (
    "Welcome to visit ECG Delineation space.\n"
    "The following three tables represent the train, validation, and test datasets, which have been meticulously stratified from the LUDB dataset. These datasets were used for training and evaluating the models.\n"
    "Usage: By clicking on the desired record in one of the tables below, the P, QRS, and T wave segments will be inferred by HRNetV2 and displayed as an image at the bottom. Additionally, the post-processed results based on predefined thresholds will also be displayed alongside."
)

# (DataFrame, table label) for each public split, rendered top to bottom.
_SPLIT_TABLES = (
    (train_df, "our train dataset. (source: ./res/ludb/dataset/train_for_public.csv)"),
    (valid_df, "our valid dataset. (source: ./res/ludb/dataset/valid_for_public.csv)"),
    (test_df, "our test dataset. (source: ./res/ludb/dataset/test_for_public.csv)"),
)

with gr.Blocks() as demo:
    with gr.Tab("App"):
        with gr.Row():
            gr.Textbox(_INFO_TEXT, label="Information", lines=3)

        gr_dfs = []
        for split_df, split_label in _SPLIT_TABLES:
            with gr.Row():
                gr_dfs.append(
                    gr.Dataframe(
                        value=split_df,
                        interactive=False,
                        max_height=250,
                        label=split_label,
                    )
                )

        with gr.Row():
            gr_image = gr.Image(type="filepath", label="Output")
            gr_pp_image = gr.Image(type="filepath", label="PostProcessed Output")

        def show_image(df: pd.DataFrame, evt: gr.SelectData):
            """Return ``[raw_image_path, postprocessed_image_path]`` for the clicked row.

            The subject id is the first value of the selected row. Images are
            cached in the temp directory and reused unless ALWAYS_RECREATE_IMAGE
            is set. Raises ``gr.Error`` when the images cannot be generated, so
            the user sees a message instead of a missing-file failure.
            """
            # NOTE(review): row_value comes from the client event and is used
            # in file paths below; consider validating it against the loaded
            # subject ids.
            subject_id = evt.row_value[0]
            tmp_dir = tempfile.gettempdir()
            image_path = os.path.join(tmp_dir, f"ludb_{subject_id}.png")
            pp_image_path = os.path.join(tmp_dir, f"ludb_{subject_id}_pp.png")
            if not ALWAYS_RECREATE_IMAGE and (
                os.path.exists(image_path) and os.path.exists(pp_image_path)
            ):
                return [image_path, pp_image_path]
            if not gen_image(subject_id, image_path, pp_image_path):
                # gen_image has already printed the traceback server-side.
                raise gr.Error(
                    f"Failed to generate images for subject {subject_id}. "
                    "See the server logs for details."
                )
            return [image_path, pp_image_path]

        for gr_df in gr_dfs:
            gr_df.select(fn=show_image, inputs=[gr_df], outputs=[gr_image, gr_pp_image])

demo.launch()