Spaces:
Sleeping
Sleeping
from plotly.subplots import make_subplots | |
from scipy.signal import find_peaks, butter, filtfilt | |
import plotly.graph_objects as go | |
from io import StringIO | |
import pandas as pd | |
import gradio as gr | |
import numpy as np | |
import itertools | |
import tempfile | |
import librosa | |
import random | |
import mdpd | |
import os | |
#from utils import getaudiodata, getBeats, plotBeattimes, find_s1s2 | |
import utils as u | |
# Example recordings offered in the UI. The listing is sorted so the gallery
# order does not depend on os.listdir()'s arbitrary ordering, and the extension
# check is case-insensitive so files such as "beat.WAV" are picked up too.
example_dir = "Examples"
example_files = sorted(
    os.path.join(example_dir, f)
    for f in os.listdir(example_dir)
    if f.lower().endswith(('.wav', '.mp3', '.ogg'))
)
# Up to 25 distinct file pairs, chosen at random (unseeded) on every start.
all_pairs = list(itertools.combinations(example_files, 2))
random.shuffle(all_pairs)
example_pairs = [list(pair) for pair in all_pairs[:25]]
#----------------------------------------------- | |
# PROCESSING FUNCTIONS | |
#----------------------------------------------- | |
def getHRV(beattimes: np.ndarray) -> np.ndarray:
    """Return the instantaneous heart rate between consecutive beats.

    Args:
        beattimes: Ascending beat times (expected in seconds, which makes the
            result beats per minute).

    Returns:
        Array of length ``len(beattimes) - 1`` (empty for fewer than two
        beats) whose element ``i`` is ``60 / (beattimes[i+1] - beattimes[i])``.
        Intervals that are not strictly positive yield ``NaN`` rather than
        ``inf`` or a negative rate.
    """
    intervals = np.diff(np.asarray(beattimes, dtype=float))
    heart_rate = np.full_like(intervals, np.nan)
    # Heart rate is 60 s divided by the beat-to-beat interval; the previous
    # ``60 * interval`` had the units inverted.
    np.divide(60.0, intervals, out=heart_rate, where=intervals > 0)
    return heart_rate
def create_average_heartbeat(audiodata, sr, hop_length=512, min_beat_interval=0.5, beat_duration=1.0):
    """Detect beats, average a fixed-length window after each one and plot it.

    Args:
        audiodata: 1-D audio signal.
        sr: Sample rate in Hz.
        hop_length: Hop, in samples, of the onset-strength envelope. 512 is
            librosa's default, which the previous code relied on implicitly.
        min_beat_interval: Minimum spacing between detected beats, in seconds.
        beat_duration: Length of the window taken after each beat, in seconds.

    Returns:
        ``(fig, avg_beat)``: a Plotly figure of the averaged window and the
        window itself (an empty array when no complete window fits).
    """
    # 1. Detect beats on the onset-strength envelope. The envelope has one
    # value per frame (every hop_length samples), so the minimum peak distance
    # must be given in frames, not in audio samples.
    onset_env = librosa.onset.onset_strength(y=audiodata, sr=sr, hop_length=hop_length)
    min_distance_frames = max(1, int(np.ceil(min_beat_interval * sr / hop_length)))
    peak_frames, _ = find_peaks(onset_env, distance=min_distance_frames)
    # Peak positions are frame indices; convert them before slicing the waveform.
    peak_samples = librosa.frames_to_samples(peak_frames, hop_length=hop_length)

    # 2. Extract one window per beat; windows running past the end are skipped.
    beat_length = int(round(beat_duration * sr))
    beats = [
        audiodata[start:start + beat_length]
        for start in peak_samples
        if start + beat_length <= len(audiodata)
    ]

    # 3. Average the windows sample by sample.
    avg_beat = np.mean(beats, axis=0) if beats else np.array([])

    # 4. Plot the average window against time in seconds.
    time = np.arange(len(avg_beat)) / sr
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=time, y=avg_beat, mode='lines', name='Average Beat'))
    fig.update_layout(
        title='Average Heartbeat',
        xaxis_title='Time (s)',
        yaxis_title='Amplitude'
    )
    return fig, avg_beat
# HELPER FUNCTIONS FOR SINGLE AUDIO ANALYSIS | |
def plotCombined(audiodata, sr, filename):
    """Plot waveform with beat markers, spectrogram and beat-to-beat rate.

    Args:
        audiodata: 1-D audio signal.
        sr: Sample rate in Hz.
        filename: Used as the figure title.

    Returns:
        A three-row Plotly figure (waveform, spectrogram, rate) whose rows
        share one x axis in seconds.
    """
    # NOTE(review): getBeatsv2 unpacks three values from u.getBeats while this
    # function unpacked two; indexing element 1 works with either shape.
    beattimes = u.getBeats(audiodata, sr)[1]

    fig = make_subplots(rows=3, cols=1, shared_xaxes=True, vertical_spacing=0.1,
                        subplot_titles=('Audio Waveform', 'Spectrogram', 'Heart Rate Variability'))

    # Time of every sample in seconds. The former "* 2" stretched the waveform
    # to twice the recording length, so it no longer matched the spectrogram
    # (seconds) or the [0, duration] x-range set below, and the beat markers
    # were interpolated at the wrong samples.
    time = np.arange(len(audiodata)) / sr

    # Row 1: waveform plus a marker at each beat.
    fig.add_trace(
        go.Scatter(x=time, y=audiodata, mode='lines', name='Waveform', line=dict(color='blue', width=1)),
        row=1, col=1
    )
    beat_amplitudes = np.interp(beattimes, time, audiodata)
    fig.add_trace(
        go.Scatter(x=beattimes, y=beat_amplitudes, mode='markers', name='Beats',
                   marker=dict(color='red', size=8, symbol='circle')),
        row=1, col=1
    )

    # Row 3: beat-to-beat rate, plotted at the later beat of each interval.
    hrv = getHRV(beattimes)
    hrv_time = beattimes[1:len(hrv) + 1]
    fig.add_trace(
        go.Scatter(x=hrv_time, y=hrv, mode='lines', name='HRV', line=dict(color='green', width=1)),
        row=3, col=1
    )

    # Row 2: dB spectrogram; time and frequency grids derive from the same STFT settings.
    n_fft = 2048
    hop_length = n_fft // 4
    D = librosa.stft(audiodata, n_fft=n_fft, hop_length=hop_length)
    S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)
    spec_times = librosa.times_like(S_db, sr=sr, hop_length=hop_length)
    freqs = librosa.fft_frequencies(sr=sr, n_fft=n_fft)
    fig.add_trace(
        go.Heatmap(z=S_db, x=spec_times, y=freqs, colorscale='Viridis',
                   zmin=S_db.min(), zmax=S_db.max(), colorbar=dict(title='Magnitude (dB)')),
        row=2, col=1
    )

    fig.update_layout(
        height=1000,
        title_text=filename,
        showlegend=False
    )
    fig.update_xaxes(title_text="Time (s)", row=2, col=1)
    fig.update_xaxes(range=[0, len(audiodata)/sr], row=2, col=1)
    fig.update_yaxes(title_text="Amplitude", row=1, col=1)
    fig.update_yaxes(title_text="HRV", row=3, col=1)
    fig.update_yaxes(title_text="Frequency (Hz)", type="log", row=2, col=1)
    return fig
def analyze_single(audio: gr.Audio):
    """Analyse one recording and summarise it as text plus two figures.

    Args:
        audio: Path of the audio file; its final path component is used as
            the title of the combined plot.

    Returns:
        ``(results, spectogram_wave, avg_beat_plot)``: a text summary, the
        figure from ``plotCombined`` and the figure from
        ``create_average_heartbeat``.
    """
    filepath = audio
    # basename also understands the platform's own path separator.
    filename = os.path.basename(filepath)
    sr, audiodata = u.getaudiodata(filepath)

    # Frame-wise features; only their means go into the report.
    zcr = librosa.feature.zero_crossing_rate(audiodata)[0]
    rms = librosa.feature.rms(y=audiodata)[0]

    # NOTE(review): getBeatsv2 unpacks three values from u.getBeats; taking
    # the first two by index works with either return shape.
    beat_info = u.getBeats(audiodata, sr)
    tempo, beattimes = beat_info[0], beat_info[1]
    # The report previously used tempo[0]; atleast_1d also accepts a scalar.
    tempo_value = np.atleast_1d(tempo)[0]
    beat_durations = np.diff(beattimes)

    spectogram_wave = plotCombined(audiodata, sr, filename)
    avg_beat_plot, avg_beat = create_average_heartbeat(audiodata, sr)
    avg_beat_duration = len(avg_beat) / sr
    avg_beat_energy = np.sum(np.square(avg_beat))

    results = f"""
Average Heartbeat Analysis:
- Duration: {avg_beat_duration:.3f} seconds
- Energy: {avg_beat_energy:.3f}
- Audio length: {len(audiodata) / sr:.2f} seconds
- Sample rate: {sr} Hz
- Mean Zero Crossing Rate: {np.mean(zcr):.4f}
- Mean RMS Energy: {np.mean(rms):.4f}
- Tempo: {tempo_value:.4f}
- Beats: {beattimes}
- Beat durations: {beat_durations}
- Mean Beat Duration: {np.mean(beat_durations):.4f}
"""
    return results, spectogram_wave, avg_beat_plot
#----------------------------------------------- | |
# ANALYSIS FUNCTIONS | |
#----------------------------------------------- | |
# - [ ] Berechnungen pro Segment:
# - [ ] RMS Energy
# - [ ] Frequenzen
# - [ ] Dauer
# - [ ] S2 - wenn möglich
# - [ ] Dauer S1 bis S2 (S1)
# - [ ] Dauer S2 bis S1 (S2)
# - [ ] Visualisierungen pro Datei:
# - [ ] Waveform
# - [ ] Spectrogram
# - [ ] HRV
# - [ ] Avg. Heartbeat Waveform (fixe y-Achse)
# - [ ] Alle Segmente als Waveform übereinanderlegen (fixe y-Achse +-0.05)
# - [ ] Daten exportierbar machen
# - [ ] Einheiten für (RMS Energy, Energy)
# - [ ] wichtige Einheiten (Energy, RMS Energy, Sample Rate, Audio length, Beats, Beat durations)
def _to_unit_float(samples) -> np.ndarray:
    """Return *samples* as float32 scaled to roughly [-1, 1).

    Signed integer PCM is divided by its full-scale value (32768 for int16,
    the divisor previously hard-coded). Floating-point input is assumed to be
    on that scale already and is not rescaled. Any other dtype keeps the
    historical int16 divisor.
    """
    samples = np.asarray(samples)
    if np.issubdtype(samples.dtype, np.floating):
        return samples.astype(np.float32)
    if np.issubdtype(samples.dtype, np.signedinteger):
        return samples.astype(np.float32) / float(-np.iinfo(samples.dtype).min)
    return samples.astype(np.float32) / 32768.0


def get_visualizations(beattimes_table: str, cleanedaudio: gr.Audio):
    """Build the five-panel analysis figure for the preprocessed recording.

    Args:
        beattimes_table: Markdown table with "Beattimes" and
            "Label (S1=1/S2=0)" columns, as shown in the Preprocessing tab.
        cleanedaudio: ``(sample_rate, samples)`` tuple from the
            "Cleaned Audio" component.

    Returns:
        Plotly figure with the rows: waveform, spectrogram (<= 1000 Hz), HRV,
        average segment waveform, and all segments overlaid. The last two
        rows stay empty when no segments are found.
    """
    df = mdpd.from_md(beattimes_table)
    df['Beattimes'] = df['Beattimes'].astype(float)
    df['Label (S1=1/S2=0)'] = df['Label (S1=1/S2=0)'].astype(int)
    sr, raw_audio = cleanedaudio
    # Segment metrics are computed from the unscaled samples.
    segment_metrics = u.compute_segment_metrics(df, sr, raw_audio)
    audiodata = _to_unit_float(raw_audio)

    # One column of five equally tall panels.
    fig = make_subplots(
        rows=5, cols=1,
        subplot_titles=('Waveform', 'Spectrogram', 'Heart Rate Variability',
                        'Average Heartbeat Waveform', 'Overlaid Segments'),
        vertical_spacing=0.1,
        row_heights=[0.2, 0.2, 0.2, 0.2, 0.2]
    )

    # 1. Waveform
    time = np.arange(len(audiodata)) / sr
    fig.add_trace(
        go.Scatter(x=time, y=audiodata, name='Waveform', line=dict(color='blue', width=1)),
        row=1, col=1
    )

    # 2. Spectrogram, cropped to frequencies up to max_freq_hz
    max_freq_hz = 1000
    D = librosa.stft(audiodata)
    # Neither call passes n_fft, so both use librosa's default and the rows line up.
    frequencies = librosa.fft_frequencies(sr=sr)
    S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)
    times = librosa.times_like(S_db, sr=sr)
    freq_mask = frequencies <= max_freq_hz
    fig.add_trace(
        go.Heatmap(
            z=S_db[freq_mask],
            x=times,
            y=frequencies[freq_mask],
            colorscale='Viridis',
            name='Spectrogram'
        ),
        row=2, col=1
    )

    # 3. HRV from the S1->S2 and S2->S1 durations collected over all segments
    s1_durations = []
    s2_durations = []
    for segment in segment_metrics:
        if segment['s1_to_s2_duration']:
            s1_durations.extend(segment['s1_to_s2_duration'])
        if segment['s2_to_s1_duration']:
            s2_durations.extend(segment['s2_to_s1_duration'])
    hrv_time, hrv_values, _ = u.compute_hrv(s1_durations, s2_durations, sr)
    fig.add_trace(
        go.Scatter(
            x=hrv_time,
            y=hrv_values,
            name='HRV (RMSSD)',
            line=dict(color='blue', width=1.5),
            hovertemplate='Time: %{x:.1f}s<br>HRV: %{y:.1f}ms<extra></extra>'
        ),
        row=3, col=1
    )

    # 4. Average heartbeat waveform; segments are scaled like the full waveform.
    segments = [_to_unit_float(metric['segment']) for metric in segment_metrics]
    if segments:  # max() would raise on an empty sequence
        max_len = max(len(segment) for segment in segments)
        # Zero-pad shorter segments at the end so they can be averaged sample-wise.
        aligned_segments = [np.pad(segment, (0, max_len - len(segment))) for segment in segments]
        avg_waveform = np.mean(aligned_segments, axis=0)
        fig.add_trace(
            go.Scatter(x=np.arange(len(avg_waveform)) / sr, y=avg_waveform,
                       name='Average Heartbeat', line=dict(color='green', width=1)),
            row=4, col=1
        )

    # 5. Overlaid segments, cycling through a fixed palette
    colors = [
        '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
        '#fdb462', '#b3de69', '#fccde5', '#d9d9d9', '#bc80bd'
    ]
    for i, segment in enumerate(segments):
        fig.add_trace(
            go.Scatter(
                x=np.arange(len(segment)) / sr,
                y=segment,
                name=f'Segment {i+1}',
                opacity=0.3,
                line=dict(color=colors[i % len(colors)], width=1)
            ),
            row=5, col=1
        )

    fig.update_layout(
        height=1500,
        showlegend=False,
        title_text="",
        plot_bgcolor='white',
        paper_bgcolor='white'
    )
    # TODO: the task list asks for a fixed y-range on the overlay panel:
    # fig.update_yaxes(range=[-0.05, 0.05], row=5, col=1)
    fig.update_yaxes(title_text="Amplitude", row=1, col=1, gridcolor='lightgray')
    fig.update_yaxes(title_text="Frequency (Hz)", row=2, col=1)
    # NOTE(review): the HRV hover text says ms while this axis says seconds;
    # confirm the unit returned by u.compute_hrv.
    fig.update_yaxes(title_text="Duration (s)", row=3, col=1, gridcolor='lightgray')
    fig.update_yaxes(title_text="Amplitude", row=4, col=1, gridcolor='lightgray')
    fig.update_yaxes(title_text="Amplitude", row=5, col=1, gridcolor='lightgray')
    fig.update_xaxes(title_text="Time (s)", row=1, col=1, gridcolor='lightgray')
    fig.update_xaxes(title_text="Time (s)", row=2, col=1)
    fig.update_xaxes(title_text="Time (s)", row=3, col=1, gridcolor='lightgray')
    fig.update_xaxes(title_text="Time (s)", row=4, col=1, gridcolor='lightgray')
    fig.update_xaxes(title_text="Time (s)", row=5, col=1, gridcolor='lightgray')
    return fig
def download_all(beattimes_table: str, cleanedaudio: gr.Audio):
    """Export the per-segment metrics of the current beat table as CSV.

    Args:
        beattimes_table: Markdown table with "Beattimes" and
            "Label (S1=1/S2=0)" columns.
        cleanedaudio: ``(sample_rate, samples)`` tuple from the
            "Cleaned Audio" component.

    Returns:
        Path of ``segment_metrics.csv`` inside a newly created temporary
        directory.
    """
    df = mdpd.from_md(beattimes_table)
    df['Beattimes'] = df['Beattimes'].astype(float)
    df['Label (S1=1/S2=0)'] = df['Label (S1=1/S2=0)'].astype(int)
    sr, audiodata = cleanedaudio
    segment_metrics = u.compute_segment_metrics(df, sr, audiodata)
    downloaddf = pd.DataFrame(segment_metrics)
    # Store numpy scalar columns as plain floats; skip columns that are absent
    # (e.g. when no segments were produced) instead of raising KeyError.
    for column in ('rms_energy', 'mean_frequency'):
        if column in downloaddf.columns:
            downloaddf[column] = downloaddf[column].astype(float)
    # A fresh directory per call keeps concurrent sessions from overwriting each
    # other's export (the old fixed path in the shared temp dir raced) while the
    # user still receives a file named segment_metrics.csv.
    export_dir = tempfile.mkdtemp(prefix="heartbeat_")
    temp_path = os.path.join(export_dir, "segment_metrics.csv")
    downloaddf.to_csv(temp_path, index=False)
    return temp_path
#----------------------------------------------- | |
#----------------------------------------------- | |
# HELPER FUNCTIONS FOR SINGLE AUDIO ANALYSIS V2 | |
def getBeatsv2(audio: gr.Audio):
    """Detect the beats of the uploaded recording and label them S1/S2.

    Args:
        audio: File path of the recording (the upload component uses
            ``type="filepath"``).

    Returns:
        ``(figure, markdown_table, (sample_rate, samples))``: the beat plot,
        a markdown table with "Beattimes" and "Label (S1=1/S2=0)" columns, and
        the samples returned by ``u.getBeats`` together with the sample rate.
    """
    sample_rate, samples = u.getaudiodata(audio)
    # getBeats hands back its own samples; those replace the loaded ones and are
    # what gets plotted and returned.
    _, detected_times, samples = u.getBeats(samples, sample_rate)
    features = u.find_s1s2(pd.DataFrame(data={"Beattimes": detected_times}))

    column_names = ["Beattimes", "S1 to S2", "S2 to S1", "Label (S1=1/S2=0)"]
    feature_table = pd.DataFrame(data=features, columns=column_names)

    # Column 3 holds the label (1 = S1, 0 = S2), column 0 the beat time.
    label_column = features[:, 3]
    s1_times = features[label_column == 1, 0]
    s2_times = features[label_column == 0, 0]
    figure = u.plotBeattimes(s1_times, samples, sample_rate, s2_times)

    label_only = feature_table.drop(columns=["S1 to S2", "S2 to S1"])
    return figure, label_only.to_markdown(), (sample_rate, samples)
def updateBeatsv2(audio: gr.Audio, uploadeddf: gr.File = None):
    """Re-plot the beats using a user-supplied CSV of beat times and labels.

    The CSV format matches what ``download_df`` writes: ``;``-separated, ``,``
    as decimal mark, UTF-8 with BOM, and the columns "Beattimes" and
    "Label (S1=1/S2=0)". Files with the legacy header "Label (S1=0/S2=1)" are
    accepted too. Their label values are flipped so that 1 always means S1,
    in line with the header they are renamed to.

    Args:
        audio: File path of the recording.
        uploadeddf: Path of the uploaded CSV, or ``None`` if nothing was uploaded.

    Returns:
        ``(figure, markdown_table)``. S1 beats (label 1) are passed to
        ``u.plotBeattimes`` first, the same order ``getBeatsv2`` uses.

    Raises:
        FileNotFoundError: If no CSV was uploaded.
    """
    if uploadeddf is None:
        raise FileNotFoundError("No file uploaded")
    sr, audiodata = u.getaudiodata(audio)

    beattimes_table = pd.read_csv(
        filepath_or_buffer=uploadeddf,
        sep=";",
        decimal=",",
        encoding="utf-8-sig")
    # Drop rows containing any missing value (e.g. trailing empty spreadsheet
    # rows) and renumber the remaining rows.
    beattimes_table = beattimes_table.dropna().reset_index(drop=True)

    # Beat times can still arrive as strings (e.g. "1,25"); normalise to float.
    if beattimes_table['Beattimes'].dtype == 'object':
        beattimes_table['Beattimes'] = beattimes_table['Beattimes'].str.replace(',', '.').astype(float)
    else:
        beattimes_table['Beattimes'] = beattimes_table['Beattimes'].astype(float)

    label_column = "Label (S1=1/S2=0)"
    legacy_label_column = "Label (S1=0/S2=1)"
    if legacy_label_column in beattimes_table.columns:
        # The legacy header encodes S1 as 0; flip the values so they agree with
        # the header they are renamed to (previously only the name changed).
        beattimes_table = beattimes_table.rename(columns={legacy_label_column: label_column})
        beattimes_table[label_column] = 1 - beattimes_table[label_column].astype(int)
    # Integer labels render as 0/1 in the markdown table. dropna can leave the
    # column as float ("1.0"), which the .astype(int) parsing in
    # get_visualizations/download_df presumably cannot handle.
    beattimes_table[label_column] = beattimes_table[label_column].astype(int)

    s1_times = beattimes_table.loc[beattimes_table[label_column] == 1, "Beattimes"].to_numpy()
    s2_times = beattimes_table.loc[beattimes_table[label_column] == 0, "Beattimes"].to_numpy()
    fig = u.plotBeattimes(s1_times, audiodata, sr, s2_times)
    return fig, beattimes_table.to_markdown()
def download_df(beattimes_table: str):
    """Write the current beat table to a CSV that ``updateBeatsv2`` can read back.

    Args:
        beattimes_table: Markdown table with "Beattimes" and
            "Label (S1=1/S2=0)" columns.

    Returns:
        Path of ``beattimes.csv`` inside a newly created temporary directory.
        The file is ``;``-separated, uses ``,`` as decimal mark and is
        encoded as UTF-8 with BOM.
    """
    df = mdpd.from_md(beattimes_table)
    df['Beattimes'] = df['Beattimes'].astype(float)
    df['Label (S1=1/S2=0)'] = df['Label (S1=1/S2=0)'].astype(int)
    # A fresh directory per call avoids concurrent sessions overwriting one
    # shared, predictable temp file; the download keeps its familiar name.
    export_dir = tempfile.mkdtemp(prefix="heartbeat_")
    temp_path = os.path.join(export_dir, "beattimes.csv")
    df.to_csv(
        index=False,
        columns=["Beattimes", "Label (S1=1/S2=0)"],
        path_or_buf=temp_path,
        sep=";",
        decimal=",",
        encoding="utf-8-sig")
    return temp_path
# Gradio UI: one shared audio upload plus a "Preprocessing" and an "Analysis"
# tab. The statement order and nesting below determine the page layout.
# NOTE(review): indentation was lost in this copy of the file; the nesting of
# the Preprocessing tab (which column holds the update button and the CSV
# upload) was reconstructed and should be checked against the original.
with gr.Blocks() as app:
    gr.Markdown("# Heartbeat")
    gr.Markdown("This App helps to analyze and extract Information from Heartbeat Audios")
    # Handlers receive the upload as a file path (type="filepath").
    audiofile = gr.Audio(
        type="filepath",
        label="Upload the Audio of a Heartbeat",
        sources="upload")
    with gr.Tab("Preprocessing"):
        getBeatsbtn = gr.Button("get Beats")
        # Filled by getBeatsv2 (third output); read by the Analysis tab handlers.
        cleanedaudio = gr.Audio(label="Cleaned Audio",show_download_button=True)
        beats_wave_plot = gr.Plot()
        with gr.Row():
            with gr.Column():
                # Markdown beat table written by getBeatsv2/updateBeatsv2 and
                # parsed again by download_df, get_visualizations and download_all.
                beattimes_table = gr.Markdown()
            with gr.Column():
                csv_download = gr.DownloadButton()
                updateBeatsbtn = gr.Button("update Beats")
                uploadDF = gr.File(
                    file_count="single",
                    file_types=[".csv"],
                    label="upload a csv",
                    height=25
                )
        # Export the table, detect beats, or re-plot from an uploaded CSV.
        csv_download.click(download_df, inputs=[beattimes_table], outputs=[csv_download])
        getBeatsbtn.click(getBeatsv2, inputs=audiofile, outputs=[beats_wave_plot, beattimes_table, cleanedaudio])
        updateBeatsbtn.click(updateBeatsv2, inputs=[audiofile, uploadDF], outputs=[beats_wave_plot, beattimes_table])
        # NOTE(review): with cache_examples=False and no outputs given, Gradio
        # presumably only fills the audio input on click; confirm whether fn is needed.
        gr.Examples(
            examples=example_files,
            inputs=audiofile,
            fn=getBeatsv2,
            cache_examples=False
        )
    with gr.Tab("Analysis"):
        gr.Markdown("🚨 Please make sure to first run the 'Preprocessing'")
        analyzebtn = gr.Button("Analyze Audio")
        plot = gr.Plot()
        download_btn = gr.DownloadButton()
        # Both handlers consume the beat table and cleaned audio produced in Preprocessing.
        analyzebtn.click(get_visualizations, inputs=[beattimes_table, cleanedaudio], outputs=[plot])
        download_btn.click(download_all, inputs=[beattimes_table, cleanedaudio], outputs=[download_btn])
app.launch()