vlbthambawita
commited on
Commit
•
b5953e8
1
Parent(s):
66eabd8
ecg plot
Browse files- app.py +23 -6
- requirements.txt +4 -1
app.py
CHANGED
@@ -1,17 +1,34 @@
|
|
1 |
import gradio as gr
|
2 |
#from transformers import pipeline
|
3 |
from transformers import AutoModel
|
4 |
-
|
|
|
|
|
5 |
#pipeline = pipeline(task="image-classification", model="julien-c/hotdog-not-hotdog")
|
6 |
model = AutoModel.from_pretrained("deepsynthbody/deepfake_ecg", trust_remote_code=True)
|
7 |
|
8 |
def predict(num_ecgs):
|
9 |
-
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
gr.Interface(
|
13 |
predict,
|
14 |
-
inputs=
|
15 |
-
outputs="
|
16 |
title="Generating ECGs",
|
17 |
-
).launch()
|
|
|
1 |
import gradio as gr
|
2 |
#from transformers import pipeline
|
3 |
from transformers import AutoModel
|
4 |
+
import ecg_plot
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
from PIL import Image
|
7 |
#pipeline = pipeline(task="image-classification", model="julien-c/hotdog-not-hotdog")
|
8 |
model = AutoModel.from_pretrained("deepsynthbody/deepfake_ecg", trust_remote_code=True)
|
9 |
|
10 |
def predict(num_ecgs):
|
11 |
+
prediction = (model(1)[0].t()/1000) # to micro volte
|
12 |
+
|
13 |
+
|
14 |
+
lead_III = (prediction[1] - prediction[0]).unsqueeze(dim=0)
|
15 |
+
lead_aVR = ((prediction[0] + prediction[1])*(-0.5)).unsqueeze(dim=0)
|
16 |
+
lead_aVL = (prediction[0] - prediction[1]* 0.5).unsqueeze(dim=0)
|
17 |
+
lead_aVF = (prediction[1] - prediction[0]* 0.5).unsqueeze(dim=0)
|
18 |
+
all = torch.cat((prediction, lead_III, lead_aVR, lead_aVL, lead_aVF), dim=0)
|
19 |
+
all_corrected = all[torch.tensor([0,1,8, 9, 10, 11, 2,3,4,5,6,7])]
|
20 |
+
|
21 |
+
ecg_plot.plot(all_corrected, sample_rate = 500, title = 'ECG 12')
|
22 |
+
|
23 |
+
#ecg_plot.show()
|
24 |
+
buf = io.BytesIO()
|
25 |
+
plt.savefig(buf, format="png")
|
26 |
+
img = Image.open(buf)
|
27 |
+
return img
|
28 |
|
29 |
gr.Interface(
|
30 |
predict,
|
31 |
+
inputs=None,
|
32 |
+
outputs="image",
|
33 |
title="Generating ECGs",
|
34 |
+
).launch(share=True)
|
requirements.txt
CHANGED
@@ -1,2 +1,5 @@
|
|
1 |
transformers
|
2 |
-
torch
|
|
|
|
|
|
|
|
1 |
transformers
|
2 |
+
torch
|
3 |
+
ecg-plot
|
4 |
+
matplotlib
|
5 |
+
PIL
|