"""Gradio app for inspecting heartbeat audio recordings.

The module builds a small UI that detects beats in an uploaded recording,
labels them as S1/S2, and lets the user correct the labels via a table that
can be downloaded and re-uploaded as CSV.  It also keeps a set of older
single-file analysis helpers (waveform / spectrogram / heart-rate plots).
"""
from io import StringIO
import itertools
import os
import random
import tempfile
from typing import Optional, Tuple

import gradio as gr
import librosa
import mdpd
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy.signal import butter, filtfilt, find_peaks

from utils import find_s1s2, getaudiodata, getBeats, plotBeattimes

# Audio file extensions picked up from the examples directory.
AUDIO_EXTENSIONS = (".wav", ".mp3", ".ogg")

# Column names shared by the beat table produced in getBeatsv2 and consumed
# in updateBeatsv2.
BEATTIME_COLUMN = "Beattimes"
LABEL_COLUMN = "Label (S1=0/S2=1)"

example_dir = "Examples"
# Sorted so the example list is deterministic; lower() so "X.WAV" is found too.
example_files = [
    os.path.join(example_dir, f)
    for f in sorted(os.listdir(example_dir))
    if f.lower().endswith(AUDIO_EXTENSIONS)
]
all_pairs = list(itertools.combinations(example_files, 2))
random.shuffle(all_pairs)
example_pairs = [list(pair) for pair in all_pairs[:25]]


def getHRV(beattimes: np.ndarray) -> np.ndarray:
    """Return the instantaneous heart rate for each inter-beat interval.

    Args:
        beattimes: Beat positions in seconds, in increasing order.

    Returns:
        Array of length ``len(beattimes) - 1`` with the heart rate in beats
        per minute for every consecutive pair of beats.  Intervals that are
        not strictly positive yield ``nan``.
    """
    intervals = np.diff(np.asarray(beattimes, dtype=float))
    instantaneous_hr = np.full(intervals.shape, np.nan)
    valid = intervals > 0
    # bpm is 60 seconds divided by the interval length (not multiplied by it).
    instantaneous_hr[valid] = 60.0 / intervals[valid]
    return instantaneous_hr


def create_average_heartbeat(audiodata, sr):
    """Average fixed-length windows that start at detected onset peaks.

    Args:
        audiodata: 1-D array of audio samples.
        sr: Sample rate of ``audiodata`` in Hz.

    Returns:
        Tuple ``(fig, avg_beat)`` where ``fig`` is a Plotly figure of the
        averaged window and ``avg_beat`` is the averaged samples (an empty
        array when no full window could be extracted).
    """
    hop_length = 512

    # 1. Detect individual heartbeats on the onset-strength envelope.
    onset_env = librosa.onset.onset_strength(
        y=audiodata, sr=sr, hop_length=hop_length
    )
    # The envelope is indexed in frames, so the 0.5 s minimum spacing between
    # beats has to be expressed in frames rather than samples.
    min_frame_gap = max(1, int(round(0.5 * sr / hop_length)))
    peak_frames, _ = find_peaks(onset_env, distance=min_frame_gap)
    # Convert frame indices back to sample positions before slicing the audio.
    peak_samples = librosa.frames_to_samples(peak_frames, hop_length=hop_length)

    # 2. Extract one second of audio starting at every peak.
    beat_length = int(sr)
    total = len(audiodata)
    beats = [
        audiodata[start:start + beat_length]
        for start in peak_samples
        if start + beat_length <= total
    ]

    # 3. Average the (equally long) windows.
    avg_beat = np.mean(beats, axis=0) if beats else np.array([])

    # 4. Plot the averaged window against time in seconds.
    time = np.arange(len(avg_beat)) / sr
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=time, y=avg_beat, mode='lines', name='Average Beat'))
    fig.update_layout(
        title='Average Heartbeat',
        xaxis_title='Time (s)',
        yaxis_title='Amplitude'
    )
    return fig, avg_beat


# HELPER FUNCTIONS FOR SINGLE AUDIO ANALYSIS
def plotCombined(audiodata, sr, filename):
    """Build a three-row figure: waveform with beats, spectrogram, heart rate.

    Args:
        audiodata: 1-D array of audio samples.
        sr: Sample rate of ``audiodata`` in Hz.
        filename: Text used as the figure title.

    Returns:
        A Plotly figure with shared x-axes across the three rows.
    """
    # getBeats is unpacked with three values elsewhere in this file; accept
    # both the two- and three-value forms here.
    _, beattimes, *_ = getBeats(audiodata, sr)
    beattimes = np.asarray(beattimes, dtype=float)

    fig = make_subplots(
        rows=3, cols=1,
        shared_xaxes=True,
        vertical_spacing=0.1,
        subplot_titles=('Audio Waveform', 'Spectrogram', 'Heart Rate Variability')
    )

    # Time axis in seconds.  It must use the same scale as the spectrogram
    # and the x-range below because all rows share one x-axis.
    time = np.arange(len(audiodata)) / sr

    # Waveform plot
    fig.add_trace(
        go.Scatter(x=time, y=audiodata, mode='lines', name='Waveform',
                   line=dict(color='blue', width=1)),
        row=1, col=1
    )

    # Beat markers, placed at the waveform amplitude at each beat time.
    beat_amplitudes = np.interp(beattimes, time, audiodata)
    fig.add_trace(
        go.Scatter(x=beattimes, y=beat_amplitudes, mode='markers', name='Beats',
                   marker=dict(color='red', size=8, symbol='circle')),
        row=1, col=1
    )

    # Heart-rate plot: one value per interval, drawn at the interval's end.
    hrv = getHRV(beattimes)
    hrv_time = beattimes[1:len(hrv) + 1]
    fig.add_trace(
        go.Scatter(x=hrv_time, y=hrv, mode='lines', name='HRV',
                   line=dict(color='green', width=1)),
        row=3, col=1
    )

    # Spectrogram plot
    n_fft = 2048
    hop_length = n_fft // 4
    D = librosa.stft(audiodata, n_fft=n_fft, hop_length=hop_length)
    S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)
    spec_times = librosa.times_like(S_db, sr=sr, hop_length=hop_length)
    freqs = librosa.fft_frequencies(sr=sr, n_fft=n_fft)
    fig.add_trace(
        go.Heatmap(z=S_db, x=spec_times, y=freqs, colorscale='Viridis',
                   zmin=S_db.min(), zmax=S_db.max(),
                   colorbar=dict(title='Magnitude (dB)')),
        row=2, col=1
    )

    # Layout
    fig.update_layout(
        height=1000,
        title_text=filename,
        showlegend=False
    )
    fig.update_xaxes(title_text="Time (s)", row=2, col=1)
    fig.update_xaxes(range=[0, len(audiodata) / sr], row=2, col=1)
    fig.update_yaxes(title_text="Amplitude", row=1, col=1)
    fig.update_yaxes(title_text="HRV", row=3, col=1)
    fig.update_yaxes(title_text="Frequency (Hz)", type="log", row=2, col=1)

    return fig


def analyze_single(audio: str):
    """Run the single-file analysis and format a text summary.

    Args:
        audio: Path of the audio file to analyse.

    Returns:
        Tuple ``(results, spectogram_wave, avg_beat_plot)``: the summary
        text, the combined figure from ``plotCombined`` and the averaged-beat
        figure from ``create_average_heartbeat``.
    """
    filepath = audio
    # basename also handles platform-specific path separators.
    filename = os.path.basename(filepath)

    sr, audiodata = getaudiodata(filepath)

    zcr = librosa.feature.zero_crossing_rate(audiodata)[0]
    rms = librosa.feature.rms(y=audiodata)[0]

    # Accept both the two- and three-value return forms of getBeats.
    tempo, beattimes, *_ = getBeats(audiodata, sr)
    # tempo may be a scalar or a one-element array depending on the helper.
    tempo_value = float(np.atleast_1d(tempo)[0])

    spectogram_wave = plotCombined(audiodata, sr, filename)

    avg_beat_plot, avg_beat = create_average_heartbeat(audiodata, sr)
    avg_beat_duration = len(avg_beat) / sr
    avg_beat_energy = np.sum(np.square(avg_beat))

    beat_durations = np.diff(beattimes)
    # Avoid the empty-mean warning when fewer than two beats were found.
    mean_beat_duration = (
        float(np.mean(beat_durations)) if len(beat_durations) else float("nan")
    )

    results = f"""
    Average Heartbeat Analysis:
    - Duration: {avg_beat_duration:.3f} seconds
    - Energy: {avg_beat_energy:.3f}
    - Audio length: {len(audiodata) / sr:.2f} seconds
    - Sample rate: {sr} Hz
    - Mean Zero Crossing Rate: {np.mean(zcr):.4f}
    - Mean RMS Energy: {np.mean(rms):.4f}
    - Tempo: {tempo_value:.4f}
    - Beats: {beattimes}
    - Beat durations: {beat_durations}
    - Mean Beat Duration: {mean_beat_duration:.4f}
    """
    return results, spectogram_wave, avg_beat_plot


# -----------------------------------------------
# -----------------------------------------------
# HELPER FUNCTIONS FOR SINGLE AUDIO ANALYSIS V2
def getBeatsv2(audio: str):
    """Detect beats in a recording and label them as S1/S2.

    Args:
        audio: Path of the audio file to analyse.

    Returns:
        Tuple ``(fig, table_markdown, (sr, audiodata))``: the beat plot, the
        feature table rendered as markdown, and the audio returned by
        ``getBeats`` together with its sample rate.
    """
    sr, audiodata = getaudiodata(audio)
    # The third return value replaces the raw samples for plotting/playback.
    _, beattimes, audiodata = getBeats(audiodata, sr)

    beattimes_table = pd.DataFrame(data={BEATTIME_COLUMN: beattimes})
    feature_array = find_s1s2(beattimes_table)

    featuredf = pd.DataFrame(
        data=feature_array,
        columns=[
            BEATTIME_COLUMN,
            "S1 to S2",
            "S2 to S1",
            LABEL_COLUMN,
        ]
    )

    # Split the beat times by label (column 3): 1 = S2, 0 = S1.
    mask_ones = feature_array[:, 3] == 1
    mask_zeros = feature_array[:, 3] == 0
    times_label_one = feature_array[mask_ones, 0]
    times_label_zero = feature_array[mask_zeros, 0]

    # NOTE(review): label-1 times are passed first here, while updateBeatsv2
    # passes label-0 times first - confirm against plotBeattimes' signature.
    fig = plotBeattimes(times_label_one, audiodata, sr, times_label_zero)

    return fig, featuredf.to_markdown(), (sr, audiodata)


def updateBeatsv2(
    beattimes_table: str,
    audio: str,
    uploadeddf: Optional[str] = None,
) -> Tuple[go.Figure, str]:
    """Re-plot the beats from an edited table or an uploaded CSV.

    Args:
        beattimes_table: Markdown text of the current beat table.
        audio: Path of the audio file the beats belong to.
        uploadeddf: Optional CSV file; when given it replaces the markdown
            table as the source of beat times and labels.

    Returns:
        Tuple ``(fig, table_markdown)`` with the refreshed beat plot and the
        table that was used, rendered as markdown.
    """
    sr, audiodata = getaudiodata(audio)

    if uploadeddf is not None:
        table = pd.read_csv(uploadeddf)
    else:
        table = mdpd.from_md(beattimes_table)

    # Cells parsed from markdown may be text; coerce before comparing labels.
    labels = pd.to_numeric(table[LABEL_COLUMN], errors="coerce")
    times = pd.to_numeric(table[BEATTIME_COLUMN], errors="coerce")

    s1_times = times[labels == 0].dropna().to_numpy()
    s2_times = times[labels == 1].dropna().to_numpy()

    fig = plotBeattimes(s1_times, audiodata, sr, s2_times)

    return fig, table.to_markdown()


def download_df(df: str):
    """Write the markdown beat table to a CSV file and return its path.

    Args:
        df: Markdown text of the beat table.

    Returns:
        Path of ``feature_data.csv`` inside the system temp directory.
    """
    table = mdpd.from_md(df)

    temp_dir = tempfile.gettempdir()
    temp_path = os.path.join(temp_dir, "feature_data.csv")

    table.to_csv(temp_path, index=False)

    return temp_path


with gr.Blocks() as app:

    gr.Markdown("# Heartbeat")
    gr.Markdown("This App helps to analyze and extract Information from Heartbeat Audios")

    audiofile = gr.Audio(
        type="filepath",
        label="Upload the Audio of a Heartbeat",
        sources=["upload"])

    with gr.Tab("Preprocessing"):

        getBeatsbtn = gr.Button("get Beats")
        cleanedaudio = gr.Audio(label="Cleaned Audio", show_download_button=True)

        beats_wave_plot = gr.Plot()

        with gr.Row():
            with gr.Column():
                beattimes_table = gr.Markdown()

            with gr.Column():
                csv_download = gr.DownloadButton()
                updateBeatsbtn = gr.Button("update Beats")
                uploadDF = gr.File(
                    file_count="single",
                    file_types=[".csv"],
                    label="upload a csv",
                    height=25
                )

        csv_download.click(download_df, inputs=[beattimes_table], outputs=[csv_download])
        getBeatsbtn.click(getBeatsv2, inputs=audiofile, outputs=[beats_wave_plot, beattimes_table, cleanedaudio])
        updateBeatsbtn.click(updateBeatsv2, inputs=[beattimes_table, audiofile, uploadDF], outputs=[beats_wave_plot, beattimes_table])

        gr.Examples(
            examples=example_files,
            inputs=audiofile,
            fn=getBeatsv2,
            cache_examples=False
        )

    with gr.Tab("Analysis"):

        gr.Markdown("🚨 Please make sure to first run the 'Preprocessing'")

app.launch()