Heartbeat / app.py
justus-tobias's picture
fixed issue with gradio df
b312806
raw
history blame
9.55 kB
from plotly.subplots import make_subplots
from scipy.signal import find_peaks, butter, filtfilt
import plotly.graph_objects as go
from io import StringIO
import pandas as pd
import gradio as gr
import numpy as np
import itertools
import tempfile
import librosa
import random
import mdpd
import os
from utils import getaudiodata, getBeats, plotBeattimes, find_s1s2
# Directory scanned once at import time for bundled example recordings.
example_dir = "Examples"

# Paths (relative to the working directory) of every audio file in the
# example directory with a supported extension.
example_files = [
    os.path.join(example_dir, entry)
    for entry in os.listdir(example_dir)
    if entry.endswith((".wav", ".mp3", ".ogg"))
]

# Every unordered pair of example files, shuffled in place; the first 25
# pairs are kept as two-element lists.
all_pairs = list(itertools.combinations(example_files, 2))
random.shuffle(all_pairs)
example_pairs = list(map(list, all_pairs[:25]))
def getHRV(beattimes: np.ndarray) -> np.ndarray:
    """Return the instantaneous heart rate between consecutive beats.

    Args:
        beattimes: Beat timestamps in ascending order. They are assumed to be
            in seconds (the factor 60 converts the rate to per-minute).

    Returns:
        One value per inter-beat interval (``len(beattimes) - 1`` entries),
        in beats per minute. Empty when fewer than two beats are given.
    """
    intervals = np.diff(np.asarray(beattimes, dtype=float))

    # Heart rate is the reciprocal of the beat-to-beat interval: a 0.5 s
    # interval corresponds to 120 bpm. Multiplying the interval by 60 (as this
    # function used to) scales the interval instead of inverting it.
    # Duplicate timestamps produce a zero interval; those entries come out as
    # inf, and the divide-by-zero warning is suppressed for them.
    with np.errstate(divide="ignore"):
        instantaneous_hr = 60.0 / intervals

    return instantaneous_hr
def create_average_heartbeat(audiodata, sr):
    """Average fixed-length windows that start at detected onset peaks.

    Args:
        audiodata: 1-D array of audio samples.
        sr: Sample rate of ``audiodata`` in Hz.

    Returns:
        A ``(fig, avg_beat)`` tuple: a Plotly figure of the averaged window
        and the averaged samples. ``avg_beat`` is an empty array when no
        complete one-second window could be extracted.
    """
    # 1. Detect individual heartbeats.
    # The onset envelope holds one value per analysis frame (hop_length
    # samples apart), not one per audio sample. The minimum peak spacing is
    # therefore expressed in frames, and the peak indices are converted back
    # to sample offsets before they are used to slice the audio.
    hop_length = 512
    onset_env = librosa.onset.onset_strength(y=audiodata, sr=sr, hop_length=hop_length)
    min_gap_frames = max(1, int(sr // 2) // hop_length)  # at least 0.5 s between beats
    peak_frames, _ = find_peaks(onset_env, distance=min_gap_frames)
    peak_samples = peak_frames * hop_length

    # 2. Extract individual heartbeats.
    beat_length = int(sr)  # assume 1 second for each beat
    # `<=` keeps a window that ends exactly on the last sample.
    beats = [
        audiodata[start:start + beat_length]
        for start in peak_samples
        if start + beat_length <= len(audiodata)
    ]

    # 3. Average the beats (all windows have the same length, so no further
    # alignment is performed).
    if beats:
        avg_beat = np.mean(beats, axis=0)
    else:
        avg_beat = np.array([])

    # 4. Create a Plotly figure of the average heartbeat.
    time = np.arange(len(avg_beat)) / sr
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=time, y=avg_beat, mode='lines', name='Average Beat'))
    fig.update_layout(
        title='Average Heartbeat',
        xaxis_title='Time (s)',
        yaxis_title='Amplitude'
    )

    return fig, avg_beat
# HELPER FUNCTIONS FOR SINGLE AUDIO ANALYSIS
def plotCombined(audiodata, sr, filename):
    """Build a three-row Plotly figure for one recording.

    Row 1 shows the waveform with the detected beats marked, row 2 a
    dB-scaled STFT spectrogram on a logarithmic frequency axis, and row 3 the
    series returned by ``getHRV`` for the detected beats.

    Args:
        audiodata: 1-D array of audio samples.
        sr: Sample rate of ``audiodata`` in Hz.
        filename: Text used as the figure title.

    Returns:
        The assembled Plotly figure.
    """
    # Get beat times. getBeats is unpacked as two values in some places of
    # this file and as three in getBeatsv2; indexing the result works with
    # either shape. Only the beat times (second element) are needed here.
    beattimes = getBeats(audiodata, sr)[1]

    # Create subplots
    fig = make_subplots(rows=3, cols=1, shared_xaxes=True, vertical_spacing=0.1,
                        subplot_titles=('Audio Waveform', 'Spectrogram', 'Heart Rate Variability'))

    # Time array for the full audio.
    # NOTE(review): the `* 2` stretches the waveform axis to twice
    # len(audiodata) / sr, while the spectrogram and the x-axis range below
    # use the unscaled duration. Kept as-is; confirm whether it compensates
    # for something in getaudiodata/getBeats or is a leftover.
    time = (np.arange(0, len(audiodata)) / sr) * 2

    # Waveform plot
    fig.add_trace(
        go.Scatter(x=time, y=audiodata, mode='lines', name='Waveform', line=dict(color='blue', width=1)),
        row=1, col=1
    )

    # Add beat markers; the marker height is the waveform value interpolated
    # at each beat time.
    beat_amplitudes = np.interp(beattimes, time, audiodata)
    fig.add_trace(
        go.Scatter(x=beattimes, y=beat_amplitudes, mode='markers', name='Beats',
                   marker=dict(color='red', size=8, symbol='circle')),
        row=1, col=1
    )

    # HRV plot: one value per interval, plotted at the later beat of each pair.
    hrv = getHRV(beattimes)
    hrv_time = beattimes[1:len(hrv)+1]
    fig.add_trace(
        go.Scatter(x=hrv_time, y=hrv, mode='lines', name='HRV', line=dict(color='green', width=1)),
        row=3, col=1
    )

    # Spectrogram plot
    n_fft = 2048  # You can adjust this value
    hop_length = n_fft // 4  # You can adjust this value
    D = librosa.stft(audiodata, n_fft=n_fft, hop_length=hop_length)
    S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)

    # Calculate the correct time array for the spectrogram
    spec_times = librosa.times_like(S_db, sr=sr, hop_length=hop_length)
    freqs = librosa.fft_frequencies(sr=sr, n_fft=n_fft)
    fig.add_trace(
        go.Heatmap(z=S_db, x=spec_times, y=freqs, colorscale='Viridis',
                   zmin=S_db.min(), zmax=S_db.max(), colorbar=dict(title='Magnitude (dB)')),
        row=2, col=1
    )

    # Update layout
    fig.update_layout(
        height=1000,
        title_text=filename,
        showlegend=False
    )
    fig.update_xaxes(title_text="Time (s)", row=2, col=1)
    fig.update_xaxes(range=[0, len(audiodata)/sr], row=2, col=1)
    fig.update_yaxes(title_text="Amplitude", row=1, col=1)
    fig.update_yaxes(title_text="HRV", row=3, col=1)
    fig.update_yaxes(title_text="Frequency (Hz)", type="log", row=2, col=1)

    return fig
def analyze_single(audio:gr.Audio):
    """Analyze one recording and return a text summary plus two figures.

    Args:
        audio: Path of the audio file to analyze; it is passed to
            ``getaudiodata`` and its last path component titles the plot.

    Returns:
        A ``(results, spectogram_wave, avg_beat_plot)`` tuple: the formatted
        summary text, the figure from ``plotCombined`` and the figure from
        ``create_average_heartbeat``.
    """
    # Extract audio data and sample rate
    filepath = audio
    filename = os.path.basename(filepath)
    sr, audiodata = getaudiodata(filepath)

    zcr = librosa.feature.zero_crossing_rate(audiodata)[0]

    # Calculate RMS Energy
    rms = librosa.feature.rms(y=audiodata)[0]

    # getBeats is unpacked as two values in some places of this file and as
    # three in getBeatsv2; indexing the result works with either shape.
    beat_result = getBeats(audiodata, sr)
    tempo, beattimes = beat_result[0], beat_result[1]
    # Accept both a scalar tempo and a one-element array.
    tempo_value = np.atleast_1d(tempo)[0]

    spectogram_wave = plotCombined(audiodata, sr, filename)

    # Add the new average heartbeat analysis
    avg_beat_plot, avg_beat = create_average_heartbeat(audiodata, sr)

    # Calculate some statistics about the average beat
    avg_beat_duration = len(avg_beat) / sr
    avg_beat_energy = np.sum(np.square(avg_beat))

    beat_durations = np.diff(beattimes)
    # With fewer than two beats there is no interval to average; report nan
    # directly instead of letting np.mean warn about an empty slice.
    mean_beat_duration = np.mean(beat_durations) if len(beat_durations) else float("nan")

    # Return your analysis results
    results = f"""
Average Heartbeat Analysis:
- Duration: {avg_beat_duration:.3f} seconds
- Energy: {avg_beat_energy:.3f}
- Audio length: {len(audiodata) / sr:.2f} seconds
- Sample rate: {sr} Hz
- Mean Zero Crossing Rate: {np.mean(zcr):.4f}
- Mean RMS Energy: {np.mean(rms):.4f}
- Tempo: {tempo_value:.4f}
- Beats: {beattimes}
- Beat durations: {beat_durations}
- Mean Beat Duration: {mean_beat_duration:.4f}
"""
    return results, spectogram_wave, avg_beat_plot
#-----------------------------------------------
#-----------------------------------------------
# HELPER FUNCTIONS FOR SINGLE AUDIO ANALYSIS V2
def getBeatsv2(audio:gr.Audio):
    """Detect beats in a recording and label them via ``find_s1s2``.

    Args:
        audio: Value handed to ``getaudiodata`` to load the recording.

    Returns:
        A ``(fig, table, (sr, audiodata))`` tuple: the figure from
        ``plotBeattimes``, the feature table rendered as markdown, and the
        sample rate together with the audio returned by ``getBeats``.
    """
    sr, loaded_audio = getaudiodata(audio)
    # getBeats also returns audio data, which replaces the loaded samples
    # for plotting and for the returned audio tuple.
    _, beattimes, audiodata = getBeats(loaded_audio, sr)

    feature_array = find_s1s2(pd.DataFrame(data={"Beattimes": beattimes}))

    column_names = ["Beattimes", "S1 to S2", "S2 to S1", "Label (S1=0/S2=1)"]
    featuredf = pd.DataFrame(data=feature_array, columns=column_names)

    # Column 0 holds the beat time, column 3 the label; split times by label.
    beat_positions = feature_array[:, 0]
    beat_labels = feature_array[:, 3]
    times_label_one = beat_positions[beat_labels == 1]
    times_label_zero = beat_positions[beat_labels == 0]

    # NOTE(review): label-1 times are passed first here, whereas
    # updateBeatsv2 passes label-0 (S1) times first — confirm which order
    # plotBeattimes expects.
    fig = plotBeattimes(times_label_one, audiodata, sr, times_label_zero)

    return fig, featuredf.to_markdown(), (sr, audiodata)
def updateBeatsv2(beattimes_table:gr.Markdown, audio:gr.Audio, uploadeddf:gr.File=None)-> tuple:
    """Re-plot the beat markers from an edited or uploaded beat table.

    Args:
        beattimes_table: Markdown text of the current beat table.
        audio: Value handed to ``getaudiodata`` to load the recording.
        uploadeddf: Optional CSV file; when given it replaces the markdown
            table as the source of beat times and labels.

    Returns:
        A ``(fig, table)`` tuple: the figure from ``plotBeattimes`` and the
        table in use rendered as markdown.

    Raises:
        KeyError: If the table lacks the "Beattimes" or
            "Label (S1=0/S2=1)" column.
    """
    # The original parsed an undefined local (`df`) and raised
    # UnboundLocalError on every call; parse the markdown argument instead,
    # and only when no CSV upload overrides it.
    if uploadeddf is not None:
        beattimes_table = pd.read_csv(uploadeddf)
    else:
        beattimes_table = mdpd.from_md(beattimes_table)

    # Cells parsed from markdown may come back as text, so coerce to numbers
    # before comparing labels against 0/1; unparsable cells become NaN and
    # match neither label.
    labels = pd.to_numeric(beattimes_table["Label (S1=0/S2=1)"], errors="coerce")
    times = pd.to_numeric(beattimes_table["Beattimes"], errors="coerce")
    s1_times = times[labels == 0].to_numpy()
    s2_times = times[labels == 1].to_numpy()

    sr, audiodata = getaudiodata(audio)
    fig = plotBeattimes(s1_times, audiodata, sr, s2_times)

    return fig, beattimes_table.to_markdown()
def download_df (df: str):
    """Write a markdown table to a CSV file in the system temp directory.

    Args:
        df: Markdown text of the table to export.

    Returns:
        Path of the written file, ``feature_data.csv`` inside the directory
        reported by ``tempfile.gettempdir()``. The same path is reused on
        every call, so an earlier export is overwritten.
    """
    frame = mdpd.from_md(df)
    print(frame)

    temp_path = os.path.join(tempfile.gettempdir(), "feature_data.csv")
    frame.to_csv(temp_path, index=False)
    return temp_path
# UI definition. Components are created in display order inside nested
# layout contexts; event handlers are wired after the components they use.
# NOTE(review): nesting below is reconstructed from statement order — confirm
# against the deployed layout.
with gr.Blocks() as app:
    gr.Markdown("# Heartbeat")
    gr.Markdown("This App helps to analyze and extract Information from Heartbeat Audios")
    # Shared input for the tabs below; type="filepath" means callbacks
    # receive a path rather than raw samples.
    audiofile = gr.Audio(
        type="filepath",
        label="Upload the Audio of a Heartbeat",
        sources="upload")
    with gr.Tab("Preprocessing"):
        getBeatsbtn = gr.Button("get Beats")
        cleanedaudio = gr.Audio(label="Cleaned Audio",show_download_button=True)
        beats_wave_plot = gr.Plot()
        with gr.Row():
            with gr.Column():
                # Beat table rendered as markdown; also used as callback input.
                beattimes_table = gr.Markdown()
            with gr.Column():
                csv_download = gr.DownloadButton()
                updateBeatsbtn = gr.Button("update Beats")
                uploadDF = gr.File(
                    file_count="single",
                    file_types=[".csv"],
                    label="upload a csv",
                    height=25
                )
        # Export the current table as CSV via the download button itself.
        csv_download.click(download_df, inputs=[beattimes_table], outputs=[csv_download])
        # Detect beats: fills the plot, the table and the cleaned-audio player.
        getBeatsbtn.click(getBeatsv2, inputs=audiofile, outputs=[beats_wave_plot, beattimes_table, cleanedaudio])
        # Re-plot from the current table, or from an uploaded CSV if present.
        updateBeatsbtn.click(updateBeatsv2, inputs=[beattimes_table, audiofile, uploadDF], outputs=[beats_wave_plot, beattimes_table])
        # Bundled example recordings; selecting one populates the audio input.
        gr.Examples(
            examples=example_files,
            inputs=audiofile,
            fn=getBeatsv2,
            cache_examples=False
        )
    with gr.Tab("Analysis"):
        gr.Markdown("🚨 Please make sure to first run the 'Preprocessing'")
# Start the app when this module is executed or imported.
app.launch()