from plotly.subplots import make_subplots
from scipy.signal import find_peaks, butter, filtfilt
import plotly.graph_objects as go
from io import StringIO
import pandas as pd
import gradio as gr
import numpy as np
import itertools
import tempfile
import librosa
import random
import mdpd
import os
#from utils import getaudiodata, getBeats, plotBeattimes, find_s1s2
import utils as u
example_dir = "Examples"
# Audio recordings offered as clickable examples in the UI.
# os.listdir returns entries in arbitrary order, so sort for a stable listing.
# A missing directory yields no examples instead of crashing at import time.
if os.path.isdir(example_dir):
    example_files = [
        os.path.join(example_dir, f)
        for f in sorted(os.listdir(example_dir))
        if f.endswith(('.wav', '.mp3', '.ogg'))
    ]
else:
    example_files = []
# Up to 25 randomly chosen pairs of example files.
# NOTE(review): example_pairs is not referenced in the visible part of this file.
all_pairs = list(itertools.combinations(example_files, 2))
random.shuffle(all_pairs)
example_pairs = [list(pair) for pair in all_pairs[:25]]
#-----------------------------------------------
# PROCESSING FUNCTIONS
#-----------------------------------------------
def getHRV(beattimes: np.ndarray) -> np.ndarray:
    """Return the instantaneous heart rate (beats per minute) between beats.

    Args:
        beattimes: Beat times in seconds, in ascending order.

    Returns:
        Array of length ``len(beattimes) - 1`` (empty for fewer than two
        beats). Element ``i`` is the rate implied by the interval between
        beat ``i`` and beat ``i + 1``.
    """
    intervals = np.diff(np.asarray(beattimes, dtype=float))
    # Rate is the reciprocal of the interval: 60 s/min divided by s/beat.
    # Coincident beats (zero interval) map to inf without a warning.
    with np.errstate(divide="ignore"):
        return 60.0 / intervals
def create_average_heartbeat(audiodata, sr, hop_length=512):
    """Detect beats via onset-envelope peaks and plot the mean 1-second beat.

    Args:
        audiodata: 1-D audio signal.
        sr: Sample rate in Hz.
        hop_length: Hop size (in samples) of the onset-strength envelope.
            Peaks are found in frames at this hop and converted back to
            sample offsets before slicing ``audiodata``.

    Returns:
        ``(fig, avg_beat)``: a Plotly figure of the average beat and the
        averaged samples. ``avg_beat`` is empty if no complete 1-second
        window could be extracted.
    """
    # 1. Detect individual heartbeats on the onset envelope (one value per frame).
    onset_env = librosa.onset.onset_strength(y=audiodata, sr=sr, hop_length=hop_length)
    # Assume at least 0.5 s between beats. find_peaks' distance is in
    # envelope frames, not samples.
    min_distance_frames = max(1, int(round(0.5 * sr / hop_length)))
    peak_frames, _ = find_peaks(onset_env, distance=min_distance_frames)
    peak_samples = librosa.frames_to_samples(peak_frames, hop_length=hop_length)

    # 2. Take a fixed 1-second window starting at each beat. Keep only windows
    #    that fit entirely inside the signal.
    beat_length = int(sr)
    beats = [
        audiodata[start:start + beat_length]
        for start in peak_samples
        if start + beat_length <= len(audiodata)
    ]

    # 3. Average the aligned beats sample-by-sample.
    avg_beat = np.mean(beats, axis=0) if beats else np.array([])

    # 4. Create a Plotly figure of the average heartbeat.
    time = np.arange(len(avg_beat)) / sr
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=time, y=avg_beat, mode='lines', name='Average Beat'))
    fig.update_layout(
        title='Average Heartbeat',
        xaxis_title='Time (s)',
        yaxis_title='Amplitude'
    )
    return fig, avg_beat
# HELPER FUNCTIONS FOR SINGLE AUDIO ANALYSIS
def plotCombined(audiodata, sr, filename):
    """Plot the waveform with beat markers, a spectrogram and a heart-rate panel.

    Args:
        audiodata: 1-D audio signal.
        sr: Sample rate in Hz.
        filename: Used as the figure title.

    Returns:
        A 3-row Plotly figure whose rows share one time axis in seconds.
    """
    # Only the beat times are used. Slice rather than unpack exactly two
    # values: getBeatsv2 in this file unpacks three values from u.getBeats.
    _tempo, beattimes = u.getBeats(audiodata, sr)[:2]
    beattimes = np.asarray(beattimes, dtype=float)

    fig = make_subplots(rows=3, cols=1, shared_xaxes=True, vertical_spacing=0.1,
                        subplot_titles=('Audio Waveform', 'Spectrogram', 'Heart Rate Variability'))

    # Time of every sample in seconds. This must use the same scale as the
    # beat times and spectrogram frame times, which share this x-axis.
    time = np.arange(len(audiodata)) / sr

    # Waveform
    fig.add_trace(
        go.Scatter(x=time, y=audiodata, mode='lines', name='Waveform', line=dict(color='blue', width=1)),
        row=1, col=1
    )

    # Beat markers sit on the waveform: look up the amplitude at each beat time.
    beat_amplitudes = np.interp(beattimes, time, audiodata)
    fig.add_trace(
        go.Scatter(x=beattimes, y=beat_amplitudes, mode='markers', name='Beats',
                   marker=dict(color='red', size=8, symbol='circle')),
        row=1, col=1
    )

    # getHRV yields one value per consecutive pair of beats. Plot each value
    # at the later beat of its pair.
    hrv = getHRV(beattimes)
    hrv_time = beattimes[1:len(hrv) + 1]
    fig.add_trace(
        go.Scatter(x=hrv_time, y=hrv, mode='lines', name='HRV', line=dict(color='green', width=1)),
        row=3, col=1
    )

    # Spectrogram
    n_fft = 2048
    hop_length = n_fft // 4
    D = librosa.stft(audiodata, n_fft=n_fft, hop_length=hop_length)
    S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)
    # Frame times/frequencies must use the same n_fft/hop_length as the STFT.
    spec_times = librosa.times_like(S_db, sr=sr, hop_length=hop_length)
    freqs = librosa.fft_frequencies(sr=sr, n_fft=n_fft)
    fig.add_trace(
        go.Heatmap(z=S_db, x=spec_times, y=freqs, colorscale='Viridis',
                   zmin=S_db.min(), zmax=S_db.max(), colorbar=dict(title='Magnitude (dB)')),
        row=2, col=1
    )

    # Layout
    fig.update_layout(
        height=1000,
        title_text=filename,
        showlegend=False
    )
    fig.update_xaxes(title_text="Time (s)", row=2, col=1)
    fig.update_xaxes(range=[0, len(audiodata)/sr], row=2, col=1)
    fig.update_yaxes(title_text="Amplitude", row=1, col=1)
    fig.update_yaxes(title_text="HRV", row=3, col=1)
    fig.update_yaxes(title_text="Frequency (Hz)", type="log", row=2, col=1)
    return fig
def analyze_single(audio:gr.Audio):
    """Analyze one heartbeat recording.

    Args:
        audio: Path to the audio file (it is used as a file path below).

    Returns:
        ``(results, spectogram_wave, avg_beat_plot)``. These are a plain-text
        summary, the combined figure from plotCombined, and the
        average-heartbeat figure from create_average_heartbeat.
    """
    filepath = audio
    # basename honours the platform's path separator, unlike split("/").
    filename = os.path.basename(filepath)
    sr, audiodata = u.getaudiodata(filepath)

    zcr = librosa.feature.zero_crossing_rate(audiodata)[0]
    rms = librosa.feature.rms(y=audiodata)[0]

    # Slice rather than unpack exactly two values: getBeatsv2 in this file
    # unpacks three values from u.getBeats.
    tempo, beattimes = u.getBeats(audiodata, sr)[:2]
    # Accept a scalar tempo as well as a 1-element array.
    tempo_value = float(np.atleast_1d(tempo)[0])
    beat_durations = np.diff(beattimes)
    # With fewer than two beats there are no durations. Report nan (what
    # np.mean returns) without triggering numpy's empty-mean warning.
    mean_beat_duration = np.mean(beat_durations) if beat_durations.size else np.nan

    spectogram_wave = plotCombined(audiodata, sr, filename)

    avg_beat_plot, avg_beat = create_average_heartbeat(audiodata, sr)
    avg_beat_duration = len(avg_beat) / sr
    avg_beat_energy = np.sum(np.square(avg_beat))

    summary_lines = [
        "Average Heartbeat Analysis:",
        f"- Duration: {avg_beat_duration:.3f} seconds",
        f"- Energy: {avg_beat_energy:.3f}",
        f"- Audio length: {len(audiodata) / sr:.2f} seconds",
        f"- Sample rate: {sr} Hz",
        f"- Mean Zero Crossing Rate: {np.mean(zcr):.4f}",
        f"- Mean RMS Energy: {np.mean(rms):.4f}",
        f"- Tempo: {tempo_value:.4f}",
        f"- Beats: {beattimes}",
        f"- Beat durations: {beat_durations}",
        f"- Mean Beat Duration: {mean_beat_duration:.4f}",
    ]
    results = "\n" + "\n".join(summary_lines) + "\n"
    return results, spectogram_wave, avg_beat_plot
#-----------------------------------------------
# ANALYSIS FUNCTIONS
#-----------------------------------------------
# - [ ] Berechnungen pro Segment:
# - [ ] RMS Energy
# - [ ] Frequenzen
# - [ ] Dauer
# - [ ] S2 - wenn möglich
# - [ ] Dauer S1 bis S2 (S1)
# - [ ] Dauer S2 bis S1 (S2)
# - [ ] Visualisierungen pro Datei:
# - [ ] Waveform
# - [ ] Spectrogram
# - [ ] HRV
# - [ ] Avg. Heartbeat Waveform (fixe y-Achse)
# - [ ] Alle Segmente als Waveform übereinanderlegen (fixe y-Achse ±0.05)
# - [ ] Daten exportierbar machen
# - [ ] Einheiten für (RMS Energy, Energy)
# - [ ] wichtige Einheiten (Energy, RMS Energy, Sample Rate, Audio length, Beats, Beat durations)
def get_visualizations(beattimes_table: str, cleanedaudio: gr.Audio):
    """Build the five-panel analysis figure for the preprocessed recording.

    Args:
        beattimes_table: Markdown table with "Beattimes" and
            "Label (S1=1/S2=0)" columns (the table from the Preprocessing tab).
        cleanedaudio: ``(sr, samples)`` pair from the "Cleaned Audio"
            component. Samples are divided by 32768, i.e. assumed to be
            int16 PCM. TODO confirm.

    Returns:
        A Plotly figure with these rows: waveform, spectrogram (<= 1000 Hz),
        HRV, average segment waveform, and overlaid segments. The last two
        rows stay empty when no segments were found.
    """
    df = mdpd.from_md(beattimes_table)
    df['Beattimes'] = df['Beattimes'].astype(float)
    df['Label (S1=1/S2=0)'] = df['Label (S1=1/S2=0)'].astype(int)
    sr, audiodata = cleanedaudio
    # Segment metrics are computed on the raw (unscaled) samples.
    segment_metrics = u.compute_segment_metrics(df, sr, audiodata)
    audiodata = audiodata.astype(np.float32) / 32768.0

    # Five stacked subplots of equal height.
    fig = make_subplots(
        rows=5, cols=1,
        subplot_titles=('Waveform', 'Spectrogram', 'Heart Rate Variability',
                        'Average Heartbeat Waveform', 'Overlaid Segments'),
        vertical_spacing=0.1,
        row_heights=[0.2, 0.2, 0.2, 0.2, 0.2]
    )

    # 1. Waveform
    sample_times = np.arange(len(audiodata)) / sr
    fig.add_trace(
        go.Scatter(x=sample_times, y=audiodata, name='Waveform', line=dict(color='blue', width=1)),
        row=1, col=1
    )

    # 2. Spectrogram, cropped to frequencies <= 1000 Hz.
    #    stft, fft_frequencies and times_like all use librosa's defaults
    #    (same n_fft/hop), so rows/columns line up with the frequency/time axes.
    D = librosa.stft(audiodata)
    frequencies = librosa.fft_frequencies(sr=sr)
    S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)
    times = librosa.times_like(S_db, sr=sr)
    freq_mask = frequencies <= 1000
    fig.add_trace(
        go.Heatmap(
            z=S_db[freq_mask],
            x=times,
            y=frequencies[freq_mask],
            colorscale='Viridis',
            name='Spectrogram'
        ),
        row=2, col=1
    )

    # 3. HRV: collect the S1->S2 and S2->S1 durations of all segments.
    s1_durations = []
    s2_durations = []
    for segment in segment_metrics:
        if segment['s1_to_s2_duration']:
            s1_durations.extend(segment['s1_to_s2_duration'])
        if segment['s2_to_s1_duration']:
            s2_durations.extend(segment['s2_to_s1_duration'])
    hrv_times, hrv_values, _ = u.compute_hrv(s1_durations, s2_durations, sr)
    fig.add_trace(
        go.Scatter(
            x=hrv_times,
            y=hrv_values,
            name='HRV (RMSSD)',
            line=dict(color='blue', width=1.5),
            # Plotly hover text uses HTML-style <br> for line breaks.
            hovertemplate='Time: %{x:.1f}s<br>HRV: %{y:.1f}ms'
        ),
        row=3, col=1
    )

    # Rescale every segment the same way as the full signal.
    segments = [metric['segment'].astype(np.float32) / 32768.0 for metric in segment_metrics]

    # 4. Average heartbeat: zero-pad segments at the end to a common length,
    #    then average sample-by-sample. Skipped when there are no segments
    #    (max() of an empty sequence would raise).
    if segments:
        max_len = max(len(seg) for seg in segments)
        aligned_segments = [np.pad(seg, (0, max_len - len(seg))) for seg in segments]
        avg_waveform = np.mean(aligned_segments, axis=0)
        fig.add_trace(
            go.Scatter(x=np.arange(len(avg_waveform)) / sr, y=avg_waveform,
                       name='Average Heartbeat', line=dict(color='green', width=1)),
            row=4, col=1
        )

    # 5. Overlaid segments, cycling through a 10-colour palette.
    colors = [
        '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
        '#fdb462', '#b3de69', '#fccde5', '#d9d9d9', '#bc80bd'
    ]
    for i, seg in enumerate(segments):
        fig.add_trace(
            go.Scatter(
                x=np.arange(len(seg)) / sr,
                y=seg,
                name=f'Segment {i+1}',
                opacity=0.3,
                line=dict(color=colors[i % len(colors)], width=1)
            ),
            row=5, col=1
        )

    fig.update_layout(
        height=1500,
        showlegend=False,
        title_text="",
        plot_bgcolor='white',
        paper_bgcolor='white'
    )

    fig.update_yaxes(title_text="Amplitude", row=1, col=1, gridcolor='lightgray')
    fig.update_yaxes(title_text="Frequency (Hz)", row=2, col=1)
    # NOTE(review): the HRV trace is named RMSSD with a hover unit of ms, but
    # this axis says "Duration (s)". Confirm the units u.compute_hrv returns.
    fig.update_yaxes(title_text="Duration (s)", row=3, col=1, gridcolor='lightgray')
    fig.update_yaxes(title_text="Amplitude", row=4, col=1, gridcolor='lightgray')
    fig.update_yaxes(title_text="Amplitude", row=5, col=1, gridcolor='lightgray')

    fig.update_xaxes(title_text="Time (s)", row=1, col=1, gridcolor='lightgray')
    fig.update_xaxes(title_text="Time (s)", row=2, col=1)
    fig.update_xaxes(title_text="Time (s)", row=3, col=1, gridcolor='lightgray')
    fig.update_xaxes(title_text="Time (s)", row=4, col=1, gridcolor='lightgray')
    fig.update_xaxes(title_text="Time (s)", row=5, col=1, gridcolor='lightgray')
    return fig
def download_all(beattimes_table:str, cleanedaudio:gr.Audio):
    """Export per-segment metrics to a CSV file and return its path.

    Args:
        beattimes_table: Markdown table with "Beattimes" and
            "Label (S1=1/S2=0)" columns.
        cleanedaudio: ``(sr, samples)`` pair from the "Cleaned Audio" component.

    Returns:
        Path of the written ``segment_metrics.csv``.
    """
    df = mdpd.from_md(beattimes_table)
    df['Beattimes'] = df['Beattimes'].astype(float)
    df['Label (S1=1/S2=0)'] = df['Label (S1=1/S2=0)'].astype(int)
    sr, audiodata = cleanedaudio
    segment_metrics = u.compute_segment_metrics(df, sr, audiodata)
    # NOTE(review): entries also carry a 'segment' sample array, which ends up
    # as its string repr in the CSV. Confirm this is intended.
    downloaddf = pd.DataFrame(segment_metrics)
    # Convert numpy float scalars to plain floats. Skip absent columns: no
    # segments gives an empty frame with no columns at all.
    for column in ('rms_energy', 'mean_frequency'):
        if column in downloaddf.columns:
            downloaddf[column] = downloaddf[column].astype(float)
    # A fresh directory per export keeps the user-facing file name, but avoids
    # a fixed, shared path in the system temp dir that exports could overwrite.
    temp_path = os.path.join(tempfile.mkdtemp(), "segment_metrics.csv")
    downloaddf.to_csv(temp_path, index=False)
    return temp_path
#-----------------------------------------------
#-----------------------------------------------
# HELPER FUNCTIONS FOR SINGLE AUDIO ANALYSIS V2
def getBeatsv2(audio:gr.Audio):
    """Detect beats in the uploaded recording and label them S1/S2.

    Args:
        audio: Value of the upload component (a file path; the component is
            declared with type="filepath").

    Returns:
        ``(fig, markdown_table, (sr, audiodata))``. The table has the columns
        "Beattimes" and "Label (S1=1/S2=0)". The audio is the signal returned
        by ``u.getBeats``.
    """
    sr, raw_audio = u.getaudiodata(audio)
    _, beattimes, audiodata = u.getBeats(raw_audio, sr)

    features = u.find_s1s2(pd.DataFrame(data={"Beattimes": beattimes}))
    featuredf = pd.DataFrame(
        data=features,
        columns=["Beattimes", "S1 to S2", "S2 to S1", "Label (S1=1/S2=0)"],
    )

    # Column 0 holds the beat time, column 3 the label (1 = S1, 0 = S2).
    labels = features[:, 3]
    times = features[:, 0]
    fig = u.plotBeattimes(times[labels == 1], audiodata, sr, times[labels == 0])

    # Only beat times and labels are shown in the table.
    table = featuredf.drop(columns=["S1 to S2", "S2 to S1"])
    return fig, table.to_markdown(), (sr, audiodata)
def updateBeatsv2(audio:gr.Audio, uploadeddf:gr.File=None)-> go.Figure:
    """Re-plot beats from an uploaded CSV of beat times and S1/S2 labels.

    The CSV is read with the same format download_df writes: ';'-separated,
    ',' as decimal mark, UTF-8 with BOM, and the columns "Beattimes" and
    "Label (S1=1/S2=0)". Files with the older header "Label (S1=0/S2=1)"
    are converted to the current convention by renaming the column and
    flipping its values.

    Args:
        audio: Path of the uploaded recording.
        uploadeddf: Path of the uploaded CSV, or None if nothing was uploaded.

    Returns:
        ``(fig, markdown_table)``. This is a tuple, despite the ``go.Figure``
        annotation.

    Raises:
        FileNotFoundError: if no CSV was uploaded.
    """
    label_col = "Label (S1=1/S2=0)"
    legacy_label_col = "Label (S1=0/S2=1)"

    sr, audiodata = u.getaudiodata(audio)
    if uploadeddf is None:
        raise FileNotFoundError("No file uploaded")

    beattimes_table = pd.read_csv(
        filepath_or_buffer=uploadeddf,
        sep=";",
        decimal=",",
        encoding="utf-8-sig")
    # Drop rows with ANY missing value (e.g. empty trailing rows) and
    # renumber the remaining rows from 0.
    beattimes_table = beattimes_table.dropna().reset_index(drop=True)

    # Beat times may arrive as strings with a decimal comma or as numbers.
    if beattimes_table['Beattimes'].dtype == 'object':
        beattimes_table['Beattimes'] = beattimes_table['Beattimes'].str.replace(',', '.').astype(float)
    else:
        beattimes_table['Beattimes'] = beattimes_table['Beattimes'].astype(float)

    # Legacy files mark S1 as 0 and S2 as 1. Rename AND flip, so the values
    # agree with the "S1=1/S2=0" header used everywhere else.
    is_legacy = legacy_label_col in beattimes_table.columns
    if is_legacy:
        beattimes_table = beattimes_table.rename(columns={legacy_label_col: label_col})
    # Labels may be parsed as floats (e.g. if the file had empty cells).
    labels = beattimes_table[label_col].astype(int)
    beattimes_table[label_col] = 1 - labels if is_legacy else labels

    # Same convention as getBeatsv2: label 1 is S1, label 0 is S2.
    s1_times = beattimes_table.loc[beattimes_table[label_col] == 1, "Beattimes"].to_numpy()
    s2_times = beattimes_table.loc[beattimes_table[label_col] == 0, "Beattimes"].to_numpy()
    fig = u.plotBeattimes(s1_times, audiodata, sr, s2_times)
    return fig, beattimes_table.to_markdown()
def download_df(beattimes_table: str):
    """Write the beat-time table to a CSV file for download and return its path.

    The file is ';'-separated, uses ',' as decimal mark and is encoded as
    UTF-8 with BOM. Only the "Beattimes" and "Label (S1=1/S2=0)" columns are
    written.

    Args:
        beattimes_table: Markdown table shown in the Preprocessing tab.

    Returns:
        Path of the written ``beattimes.csv``.
    """
    df = mdpd.from_md(beattimes_table)
    # Parsed markdown cells are converted to numbers so that decimal=","
    # applies to the beat times on export.
    df['Beattimes'] = df['Beattimes'].astype(float)
    df['Label (S1=1/S2=0)'] = df['Label (S1=1/S2=0)'].astype(int)
    # A fresh directory per export keeps the user-facing file name, but avoids
    # a fixed, shared path in the system temp dir that exports could overwrite.
    temp_path = os.path.join(tempfile.mkdtemp(), "beattimes.csv")
    df.to_csv(
        index=False,
        columns=["Beattimes", "Label (S1=1/S2=0)"],
        path_or_buf=temp_path,
        sep=";",
        decimal=",",
        encoding="utf-8-sig")
    return temp_path
# Gradio user interface. The Preprocessing tab detects or loads beats and
# labels. The Analysis tab builds figures and exports from that result.
with gr.Blocks() as app:
    gr.Markdown("# Heartbeat")
    gr.Markdown("This App helps to analyze and extract Information from Heartbeat Audios")
    # Shared input for both tabs. type="filepath" means handlers receive a path string.
    audiofile = gr.Audio(
        type="filepath",
        label="Upload the Audio of a Heartbeat",
        sources="upload")
    with gr.Tab("Preprocessing"):
        getBeatsbtn = gr.Button("get Beats")
        # Filled with the (sr, samples) that getBeatsv2 returns. It is also an
        # input of the Analysis tab handlers.
        cleanedaudio = gr.Audio(label="Cleaned Audio",show_download_button=True)
        beats_wave_plot = gr.Plot()
        with gr.Row():
            with gr.Column():
                # Markdown table of beat times and labels. The handlers parse it
                # back with mdpd.from_md, so it also serves as shared state.
                beattimes_table = gr.Markdown()
            with gr.Column():
                csv_download = gr.DownloadButton()
                updateBeatsbtn = gr.Button("update Beats")
                uploadDF = gr.File(
                    file_count="single",
                    file_types=[".csv"],
                    label="upload a csv",
                    height=25
                )
        csv_download.click(download_df, inputs=[beattimes_table], outputs=[csv_download])
        getBeatsbtn.click(getBeatsv2, inputs=audiofile, outputs=[beats_wave_plot, beattimes_table, cleanedaudio])
        updateBeatsbtn.click(updateBeatsv2, inputs=[audiofile, uploadDF], outputs=[beats_wave_plot, beattimes_table])
        # NOTE(review): no outputs are given and caching is off, so fn is
        # presumably not run when an example is picked. Selecting an example
        # only fills audiofile. Confirm.
        gr.Examples(
            examples=example_files,
            inputs=audiofile,
            fn=getBeatsv2,
            cache_examples=False
        )
    with gr.Tab("Analysis"):
        # Both handlers read beattimes_table and cleanedaudio, which are only
        # filled after the Preprocessing step has run.
        gr.Markdown("🚨 Please make sure to first run the 'Preprocessing'")
        analyzebtn = gr.Button("Analyze Audio")
        plot = gr.Plot()
        download_btn = gr.DownloadButton()
        analyzebtn.click(get_visualizations, inputs=[beattimes_table, cleanedaudio], outputs=[plot])
        download_btn.click(download_all, inputs=[beattimes_table, cleanedaudio], outputs=[download_btn])
# Starts the web server when this module runs (at import time as well).
app.launch()