# Hugging Face Space (status: Running)
from plotly.subplots import make_subplots | |
from scipy.signal import find_peaks, butter, filtfilt | |
import plotly.graph_objects as go | |
from io import StringIO | |
import pandas as pd | |
import gradio as gr | |
import numpy as np | |
import itertools | |
import tempfile | |
import librosa | |
import random | |
import mdpd | |
import os | |
from utils import getaudiodata, getBeats, plotBeattimes, find_s1s2 | |
# Gather the bundled example recordings and pre-compute a random selection of
# up to 25 distinct two-file combinations.
example_dir = "Examples"
example_files = [
    os.path.join(example_dir, name)
    for name in os.listdir(example_dir)
    if name.endswith(('.wav', '.mp3', '.ogg'))
]

# Every unordered pair of example files, shuffled in place so that the slice
# below picks a different subset on each start-up.
all_pairs = list(itertools.combinations(example_files, 2))
random.shuffle(all_pairs)
example_pairs = list(map(list, all_pairs[:25]))
def getHRV(beattimes: np.ndarray) -> np.ndarray:
    """Return the instantaneous heart rate between consecutive beats.

    Despite its name, this function does not compute a variability measure;
    it returns one heart-rate value per beat-to-beat interval.

    Args:
        beattimes: Beat timestamps, expected in seconds and in ascending
            order. Any array-like is accepted.

    Returns:
        Array of length ``len(beattimes) - 1`` holding the heart rate in
        beats per minute for each interval. Empty when fewer than two beats
        are given. Two identical consecutive timestamps yield ``inf``.
    """
    intervals = np.diff(np.asarray(beattimes, dtype=float))
    # Beats per minute is 60 divided by the interval length in seconds;
    # multiplying would make the value grow as the heart slows down.
    instantaneous_hr = 60.0 / intervals
    return instantaneous_hr
def create_average_heartbeat(audiodata, sr, hop_length=512):
    """Average fixed-length windows cut at detected onsets and plot the result.

    Args:
        audiodata: Audio samples (indexed and sliced as a 1-D sequence).
        sr: Sample rate in Hz.
        hop_length: Hop size, in samples, used for the onset-strength
            envelope. It is also used to translate envelope frame indices
            back to sample positions.

    Returns:
        A ``(fig, avg_beat)`` tuple: the Plotly figure of the averaged window
        and the averaged samples themselves. ``avg_beat`` is an empty array
        when no complete window could be extracted.
    """
    # 1. Detect individual heartbeats on the onset-strength envelope.
    onset_env = librosa.onset.onset_strength(y=audiodata, sr=sr, hop_length=hop_length)

    # The envelope holds one value per hop, so the minimum spacing of 0.5 s
    # between beats has to be expressed in frames rather than in samples.
    min_gap_frames = max(1, int(sr // 2) // hop_length)
    peak_frames, _ = find_peaks(onset_env, distance=min_gap_frames)

    # Peaks are frame indices; convert them to sample offsets before slicing
    # the raw audio.
    peak_samples = peak_frames * hop_length

    # 2. Extract individual heartbeats (assume 1 second for each beat),
    # keeping only windows that fit entirely inside the recording.
    beat_length = int(sr)
    total_samples = len(audiodata)
    beats = [
        audiodata[start:start + beat_length]
        for start in peak_samples
        if start + beat_length < total_samples
    ]

    # 3. Average the windows sample by sample.
    avg_beat = np.mean(beats, axis=0) if beats else np.array([])

    # 4. Create a Plotly figure of the average heartbeat.
    time = np.arange(len(avg_beat)) / sr
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=time, y=avg_beat, mode='lines', name='Average Beat'))
    fig.update_layout(
        title='Average Heartbeat',
        xaxis_title='Time (s)',
        yaxis_title='Amplitude'
    )
    return fig, avg_beat
# HELPER FUNCTIONS FOR SINGLE AUDIO ANALYSIS
def plotCombined(audiodata, sr, filename):
    """Build a three-row figure: waveform with beat markers, spectrogram, HRV.

    Args:
        audiodata: Audio samples (presumably a 1-D numpy array, as described
            in ``analyze_single``).
        sr: Sample rate in Hz.
        filename: Text used as the figure title.

    Returns:
        The assembled Plotly figure.
    """
    # Only the beat times are needed here. getBeats is unpacked into two
    # values in some places of this file and into three in others; taking the
    # second item works with either form.
    # NOTE(review): confirm the actual return signature of utils.getBeats.
    beattimes = getBeats(audiodata, sr)[1]

    # Create subplots
    fig = make_subplots(rows=3, cols=1, shared_xaxes=True, vertical_spacing=0.1,
                        subplot_titles=('Audio Waveform', 'Spectrogram', 'Heart Rate Variability'))

    # Time of every sample in seconds. This must use the same scale as the
    # spectrogram times and the x-axis range set below, otherwise the rows of
    # the shared x-axis no longer line up.
    time = np.arange(len(audiodata)) / sr

    # Waveform plot
    fig.add_trace(
        go.Scatter(x=time, y=audiodata, mode='lines', name='Waveform', line=dict(color='blue', width=1)),
        row=1, col=1
    )

    # Add beat markers at the waveform amplitude found at each beat time
    beat_amplitudes = np.interp(beattimes, time, audiodata)
    fig.add_trace(
        go.Scatter(x=beattimes, y=beat_amplitudes, mode='markers', name='Beats',
                   marker=dict(color='red', size=8, symbol='circle')),
        row=1, col=1
    )

    # HRV plot: one value per interval, drawn at the time of the later beat
    hrv = getHRV(beattimes)
    hrv_time = beattimes[1:len(hrv)+1]
    fig.add_trace(
        go.Scatter(x=hrv_time, y=hrv, mode='lines', name='HRV', line=dict(color='green', width=1)),
        row=3, col=1
    )

    # Spectrogram plot
    n_fft = 2048  # You can adjust this value
    hop_length = n_fft // 4  # You can adjust this value
    D = librosa.stft(audiodata, n_fft=n_fft, hop_length=hop_length)
    S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)

    # Calculate the correct time array for the spectrogram
    spec_times = librosa.times_like(S_db, sr=sr, hop_length=hop_length)
    freqs = librosa.fft_frequencies(sr=sr, n_fft=n_fft)
    fig.add_trace(
        go.Heatmap(z=S_db, x=spec_times, y=freqs, colorscale='Viridis',
                   zmin=S_db.min(), zmax=S_db.max(), colorbar=dict(title='Magnitude (dB)')),
        row=2, col=1
    )

    # Update layout
    fig.update_layout(
        height=1000,
        title_text=filename,
        showlegend=False
    )
    fig.update_xaxes(title_text="Time (s)", row=2, col=1)
    fig.update_xaxes(range=[0, len(audiodata)/sr], row=2, col=1)
    fig.update_yaxes(title_text="Amplitude", row=1, col=1)
    fig.update_yaxes(title_text="HRV", row=3, col=1)
    fig.update_yaxes(title_text="Frequency (Hz)", type="log", row=2, col=1)
    return fig
def analyze_single(audio:gr.Audio):
    """Analyse one heartbeat recording and return a text summary plus two plots.

    Args:
        audio: Path of the audio file to analyse (used as a file path here).

    Returns:
        A ``(results, spectogram_wave, avg_beat_plot)`` tuple: the formatted
        statistics text, the combined waveform/spectrogram/HRV figure, and
        the figure of the averaged heartbeat.
    """
    # Extract audio data and sample rate
    filepath = audio
    # basename also copes with platform-specific path separators
    filename = os.path.basename(filepath)
    sr, audiodata = getaudiodata(filepath)

    zcr = librosa.feature.zero_crossing_rate(audiodata)[0]
    # Calculate RMS Energy
    rms = librosa.feature.rms(y=audiodata)[0]

    # getBeats is unpacked into two values in some places of this file and
    # into three in others; slicing keeps this call working with either form.
    # NOTE(review): confirm the actual return signature of utils.getBeats.
    tempo, beattimes = getBeats(audiodata, sr)[:2]
    # Accept tempo as either a scalar or an array
    tempo_value = np.atleast_1d(tempo)[0]

    beat_durations = np.diff(beattimes)
    # Avoid numpy's empty-mean warning when fewer than two beats were found
    mean_beat_duration = np.mean(beat_durations) if len(beat_durations) else float("nan")

    spectogram_wave = plotCombined(audiodata, sr, filename)

    # Add the new average heartbeat analysis
    avg_beat_plot, avg_beat = create_average_heartbeat(audiodata, sr)

    # Calculate some statistics about the average beat
    avg_beat_duration = len(avg_beat) / sr
    avg_beat_energy = np.sum(np.square(avg_beat))

    # Return your analysis results
    results = f"""
Average Heartbeat Analysis:
- Duration: {avg_beat_duration:.3f} seconds
- Energy: {avg_beat_energy:.3f}
- Audio length: {len(audiodata) / sr:.2f} seconds
- Sample rate: {sr} Hz
- Mean Zero Crossing Rate: {np.mean(zcr):.4f}
- Mean RMS Energy: {np.mean(rms):.4f}
- Tempo: {tempo_value:.4f}
- Beats: {beattimes}
- Beat durations: {beat_durations}
- Mean Beat Duration: {mean_beat_duration:.4f}
"""
    return results, spectogram_wave, avg_beat_plot
#-----------------------------------------------
#-----------------------------------------------
# HELPER FUNCTIONS FOR SINGLE AUDIO ANALYSIS V2
def getBeatsv2(audio:gr.Audio):
    """Detect beats in a recording, label them, and prepare the UI outputs.

    Args:
        audio: Audio input handed to ``getaudiodata``.

    Returns:
        A ``(fig, table_markdown, (sr, audiodata))`` tuple: the beat plot, the
        feature table rendered as Markdown, and the sample rate together with
        the audio returned by ``getBeats``.
    """
    sample_rate, raw_audio = getaudiodata(audio)
    _, detected_beats, processed_audio = getBeats(raw_audio, sample_rate)

    # Feature matrix: its four columns are named in the table built below
    features = find_s1s2(pd.DataFrame(data={"Beattimes": detected_beats}))
    feature_table = pd.DataFrame(
        data=features,
        columns=[
            "Beattimes",
            "S1 to S2",
            "S2 to S1",
            "Label (S1=0/S2=1)"]
    )

    # Split the beat times (column 0) by their label (column 3)
    labels = features[:, 3]
    label_one_times = features[labels == 1, 0]
    label_zero_times = features[labels == 0, 0]

    figure = plotBeattimes(label_one_times, processed_audio, sample_rate, label_zero_times)
    return figure, feature_table.to_markdown(), (sample_rate, processed_audio)
def updateBeatsv2(beattimes_table:gr.Markdown, audio:gr.Audio, uploadeddf:gr.File=None)-> go.Figure:
    """Redraw the beat plot from the current feature table or an uploaded CSV.

    Args:
        beattimes_table: Feature table as Markdown text; parsed only when no
            CSV is supplied.
        audio: Audio input handed to ``getaudiodata``.
        uploadeddf: Optional CSV file; when given, it replaces the table.

    Returns:
        A ``(fig, table_markdown)`` tuple: the updated beat plot and the table
        in use rendered as Markdown.

    Raises:
        KeyError: If the table lacks the ``Beattimes`` or label column.
        ValueError: If those columns contain values that are not numeric.
    """
    label_column = "Label (S1=0/S2=1)"
    sr, audiodata = getaudiodata(audio)

    # An uploaded CSV takes precedence over the table shown in the UI
    if uploadeddf is not None:
        table = pd.read_csv(uploadeddf)
    else:
        table = mdpd.from_md(beattimes_table)

    # A table parsed back from Markdown may hold text cells (assumption —
    # confirm against mdpd), so convert before comparing with 0 and 1. This
    # is a no-op for columns that are already numeric.
    labels = pd.to_numeric(table[label_column])
    times = pd.to_numeric(table["Beattimes"])
    s1_times = times[labels == 0].to_numpy()
    s2_times = times[labels == 1].to_numpy()

    # NOTE(review): getBeatsv2 passes the label-1 times as the first argument
    # of plotBeattimes, whereas this call passes the label-0 times first —
    # confirm which order plotBeattimes expects.
    fig = plotBeattimes(s1_times, audiodata, sr, s2_times)
    return fig, table.to_markdown()
def download_df (df: str):
    """Write the Markdown feature table to a CSV file and return its path.

    Args:
        df: Feature table as Markdown text.

    Returns:
        Path of ``feature_data.csv`` inside the system temporary directory.
    """
    table = mdpd.from_md(df)
    print(table)
    # Fixed file name in the temp directory; each call overwrites the last
    csv_path = os.path.join(tempfile.gettempdir(), "feature_data.csv")
    table.to_csv(csv_path, index=False)
    return csv_path
# Gradio user interface.
# NOTE(review): the nesting below is reconstructed from the statement order;
# confirm the intended layout of rows, columns and tabs.
with gr.Blocks() as app:

    gr.Markdown("# Heartbeat")
    gr.Markdown("This App helps to analyze and extract Information from Heartbeat Audios")

    # Shared audio input used by the preprocessing callbacks
    audiofile = gr.Audio(
        sources="upload",
        type="filepath",
        label="Upload the Audio of a Heartbeat",
    )

    with gr.Tab("Preprocessing"):

        getBeatsbtn = gr.Button("get Beats")
        cleanedaudio = gr.Audio(show_download_button=True, label="Cleaned Audio")
        beats_wave_plot = gr.Plot()

        with gr.Row():

            # Left: the feature table rendered as Markdown
            with gr.Column():
                beattimes_table = gr.Markdown()

            # Right: CSV export, re-plot trigger and CSV upload
            with gr.Column():
                csv_download = gr.DownloadButton()
                updateBeatsbtn = gr.Button("update Beats")
                uploadDF = gr.File(
                    label="upload a csv",
                    file_types=[".csv"],
                    file_count="single",
                    height=25,
                )

        # Event wiring (registration order kept as in the original)
        csv_download.click(
            download_df,
            inputs=[beattimes_table],
            outputs=[csv_download],
        )
        getBeatsbtn.click(
            getBeatsv2,
            inputs=audiofile,
            outputs=[beats_wave_plot, beattimes_table, cleanedaudio],
        )
        updateBeatsbtn.click(
            updateBeatsv2,
            inputs=[beattimes_table, audiofile, uploadDF],
            outputs=[beats_wave_plot, beattimes_table],
        )

        gr.Examples(
            examples=example_files,
            inputs=audiofile,
            fn=getBeatsv2,
            cache_examples=False,
        )

    with gr.Tab("Analysis"):
        gr.Markdown("🚨 Please make sure to first run the 'Preprocessing'")

app.launch()