# Heartbeat / utils.py
import librosa
import numpy as np
import plotly.graph_objects as go
from scipy.signal import savgol_filter, find_peaks, butter, filtfilt
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
import pywt
import pandas as pd
# GENERAL HELPER FUNCTIONS
def denoise_audio(audiodata: np.ndarray, sr: int) -> tuple[np.ndarray, int]:
"""
Enhanced denoising of audio signals optimized for heart sounds.
Uses a combination of bandpass filtering, adaptive wavelet denoising,
and improved spectral subtraction.
Parameters:
-----------
audiodata : np.ndarray
Input audio signal (1D numpy array)
sr : int
Sampling rate in Hz
Returns:
--------
tuple[np.ndarray, int]
Tuple containing (denoised_signal, sampling_rate)
"""
# Input validation and conversion
if not isinstance(audiodata, np.ndarray) or audiodata.ndim != 1:
raise ValueError("audiodata must be a 1D numpy array")
if not isinstance(sr, int) or sr <= 0:
raise ValueError("sr must be a positive integer")
    # Convert to float32 and normalize (guard against an all-zero signal)
    audio = audiodata.astype(np.float32)
    peak = np.max(np.abs(audio))
    if peak > 0:
        audio = audio / peak
# 1. Enhanced Bandpass Filter
# Optimize frequency range for heart sounds (20-200 Hz)
nyquist = sr / 2
low, high = 20 / nyquist, 200 / nyquist
order = 4 # Filter order
b, a = butter(order, [low, high], btype='band')
filtered = filtfilt(b, a, audio)
# 2. Adaptive Wavelet Denoising
def apply_wavelet_denoising(sig):
# Use sym4 wavelet (good for biomedical signals)
wavelet = 'sym4'
level = min(6, pywt.dwt_max_level(len(sig), pywt.Wavelet(wavelet).dec_len))
# Decompose signal
coeffs = pywt.wavedec(sig, wavelet, level=level)
# Adaptive thresholding based on level
for i in range(1, len(coeffs)):
# Calculate level-dependent threshold
sigma = np.median(np.abs(coeffs[i])) / 0.6745
threshold = sigma * np.sqrt(2 * np.log(len(coeffs[i])))
            # Adjust threshold based on decomposition level
            # (finer detail coefficients get lower thresholds)
            level_factor = 1 - (i / len(coeffs))
coeffs[i] = pywt.threshold(coeffs[i], threshold * level_factor, mode='soft')
return pywt.waverec(coeffs, wavelet)
# Apply wavelet denoising
denoised = apply_wavelet_denoising(filtered)
# Ensure consistent length
    if len(denoised) != len(audio):
        denoised = librosa.util.fix_length(denoised, size=len(audio))
# 3. Improved Spectral Subtraction
def spectral_subtract(sig):
# Parameters
frame_length = int(sr * 0.04) # 40ms frames
hop_length = frame_length // 2
# Compute STFT
D = librosa.stft(sig, n_fft=frame_length, hop_length=hop_length)
mag, phase = np.abs(D), np.angle(D)
# Estimate noise spectrum from low-energy frames
frame_energy = np.sum(mag**2, axis=0)
noise_threshold = np.percentile(frame_energy, 15)
noise_frames = mag[:, frame_energy < noise_threshold]
if noise_frames.size > 0:
noise_spectrum = np.median(noise_frames, axis=1)
# Oversubtraction factor (frequency-dependent)
freq_bins = np.fft.rfftfreq(frame_length, 1/sr)
alpha = 1.0 + 0.01 * (freq_bins / nyquist)
alpha = alpha[:len(noise_spectrum)].reshape(-1, 1)
# Spectral subtraction with flooring
mag_clean = np.maximum(mag - alpha * noise_spectrum.reshape(-1, 1), 0.01 * mag)
# Reconstruct signal
D_clean = mag_clean * np.exp(1j * phase)
return librosa.istft(D_clean, hop_length=hop_length)
return sig
# Apply spectral subtraction
final = spectral_subtract(denoised)
    # Final normalization (guard against an all-zero result)
    peak = np.max(np.abs(final))
    if peak > 0:
        final = final / peak
    return final, sr
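
# Example (illustrative sketch): denoising a synthetic noisy tone. The values
# below are made up for demonstration and are not from any real recording.
#
#   sr_demo = 16000
#   t = np.linspace(0, 2, 2 * sr_demo, endpoint=False)
#   noisy = (np.sin(2 * np.pi * 60 * t) + 0.3 * np.random.randn(t.size)).astype(np.float32)
#   denoised, sr_demo = denoise_audio(noisy, sr_demo)
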
def getaudiodata(filepath: str, target_sr: int = 16000) -> tuple[int, np.ndarray]:
"""
Load and process audio data with consistent output properties.
Parameters:
-----------
filepath : str
Path to the audio file
target_sr : int
Target sampling rate (default: 16000 Hz)
Returns:
--------
tuple[int, np.ndarray]
Sampling rate and processed audio data with consistent properties:
- dtype: float32
- shape: (N,) mono audio
- amplitude range: [-0.95, 0.95]
- no NaN or Inf values
- C-contiguous memory layout
"""
# Load audio with specified sampling rate
audiodata, sr = librosa.load(filepath, sr=target_sr)
# Ensure numpy array
audiodata = np.asarray(audiodata)
    # Convert to mono if multichannel (librosa returns shape (n_channels, n_samples))
    if audiodata.ndim > 1:
        audiodata = np.mean(audiodata, axis=0)
# Handle any NaN or Inf values
audiodata = np.nan_to_num(audiodata, nan=0.0, posinf=0.0, neginf=0.0)
# Normalize to prevent clipping while maintaining relative amplitudes
max_abs = np.max(np.abs(audiodata))
if max_abs > 0: # Avoid division by zero
audiodata = audiodata * (0.95 / max_abs)
# Ensure float32 dtype and memory contiguous
audiodata = np.ascontiguousarray(audiodata, dtype=np.float32)
return sr, audiodata
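
# Example (sketch; "recording.wav" is a placeholder path, not a file shipped here):
#
#   sr, audio = getaudiodata("recording.wav")
#   # audio is mono float32, C-contiguous, with peak amplitude <= 0.95
#   print(sr, audio.dtype, audio.shape)
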
def getBeats(audiodata: np.ndarray, sr: int, method='envelope') -> tuple[float, np.ndarray, np.ndarray]:
"""
Advanced heartbeat detection optimized for peak detection with improved sensitivity.
Parameters:
-----------
audiodata : np.ndarray
Audio time series
sr : int
Sampling rate
    method : str
        Detection method: 'envelope' (default), 'onset', or 'fusion'
Returns:
--------
tempo : float
Estimated heart rate in BPM
peak_times : np.ndarray
Times of detected heartbeat peaks
cleaned_audio : np.ndarray
Cleaned audio signal
"""
# Denoise and normalize
audiodata, sr = denoise_audio(audiodata, sr)
cleaned_audio = audiodata / np.max(np.abs(audiodata))
def get_envelope_peaks():
"""Detect peaks using enhanced envelope method with better sensitivity"""
# Calculate envelope using appropriate frame sizes
hop_length = int(sr * 0.01) # 10ms hop
frame_length = int(sr * 0.04) # 40ms window
# Calculate RMS energy
rms = librosa.feature.rms(
y=cleaned_audio,
frame_length=frame_length,
hop_length=hop_length
)[0]
# Smooth the envelope (less aggressive smoothing)
rms_smooth = savgol_filter(rms, 7, 3)
# Find peaks with more lenient thresholds
peaks, properties = find_peaks(
rms_smooth,
distance=int(0.2 * (sr / hop_length)), # Minimum 0.2s between peaks (300 BPM max)
height=np.mean(rms_smooth) + 0.1 * np.std(rms_smooth), # Lower height threshold
prominence=np.mean(rms_smooth) * 0.1, # Lower prominence threshold
width=(int(0.01 * (sr / hop_length)), int(0.2 * (sr / hop_length))) # 10-200ms width
)
# Refine peak locations using original signal
refined_peaks = []
window_size = int(0.05 * sr) # 50ms window for refinement
for peak in peaks:
# Convert envelope peak to sample domain
sample_idx = peak * hop_length
# Define window boundaries
start = max(0, sample_idx - window_size//2)
end = min(len(cleaned_audio), sample_idx + window_size//2)
# Find the maximum amplitude within the window
window = np.abs(cleaned_audio[int(start):int(end)])
max_idx = np.argmax(window)
refined_peaks.append(start + max_idx)
return np.array(refined_peaks), rms_smooth
def get_onset_peaks():
"""Enhanced onset detection with better sensitivity"""
# Multi-band onset detection with adjusted parameters
onset_env = librosa.onset.onset_strength(
y=cleaned_audio,
sr=sr,
hop_length=256, # Smaller hop length for better temporal resolution
aggregate=np.median,
n_mels=128
)
        # More lenient thresholding: offset above the local mean, passed to
        # peak_pick as `delta`. onset_detect normalizes the envelope to [0, 1]
        # before peak picking, so express the offset on that scale.
        env_max = np.max(onset_env)
        delta = 0.3 * np.std(onset_env / env_max) if env_max > 0 else 0.1
# Get onset positions
onset_frames = librosa.onset.onset_detect(
onset_envelope=onset_env,
sr=sr,
hop_length=256,
backtrack=True,
            delta=delta,
pre_max=20, # 20 frames before peak
post_max=20, # 20 frames after peak
pre_avg=25, # 25 frames before for mean
post_avg=25, # 25 frames after for mean
wait=10 # Wait 10 frames before detecting next onset
)
# Refine onset positions to peaks
refined_peaks = []
window_size = int(0.05 * sr) # 50ms window
for frame in onset_frames:
# Convert frame to sample index
sample_idx = frame * 256 # Using hop_length=256
# Define window boundaries
start = max(0, sample_idx - window_size//2)
end = min(len(cleaned_audio), sample_idx + window_size//2)
# Find the maximum amplitude within the window
window = np.abs(cleaned_audio[int(start):int(end)])
max_idx = np.argmax(window)
refined_peaks.append(start + max_idx)
return np.array(refined_peaks), onset_env
# Apply selected method
if method == 'envelope':
peaks, _ = get_envelope_peaks()
elif method == 'onset':
peaks, _ = get_onset_peaks()
else: # fusion method
# Get peaks from both methods
env_peaks, _ = get_envelope_peaks()
onset_peaks, _ = get_onset_peaks()
# Merge nearby peaks (within 50ms)
all_peaks = np.sort(np.concatenate([env_peaks, onset_peaks]))
merged_peaks = []
last_peak = -np.inf
for peak in all_peaks:
if (peak - last_peak) / sr > 0.05: # 50ms minimum separation
merged_peaks.append(peak)
last_peak = peak
peaks = np.array(merged_peaks)
# Convert peaks to times
peak_times = peaks / sr
# Calculate tempo using peak times
if len(peak_times) > 1:
# Use weighted average of intervals
intervals = np.diff(peak_times)
tempos = 60 / intervals # Convert intervals to BPM
# Remove physiologically impossible tempos
valid_tempos = tempos[(tempos >= 30) & (tempos <= 300)]
if len(valid_tempos) > 0:
tempo = np.median(valid_tempos) # Use median for robustness
else:
tempo = 0
else:
tempo = 0
return tempo, peak_times, cleaned_audio
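
# Example (sketch; assumes `sr` and `audio` were obtained via getaudiodata above):
#
#   tempo, peak_times, cleaned = getBeats(audio, sr, method='fusion')
#   print(f"Estimated heart rate: {tempo:.1f} BPM from {len(peak_times)} beats")
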
def plotBeattimes(beattimes: np.ndarray,
audiodata: np.ndarray,
sr: int,
beattimes2: np.ndarray = None) -> go.Figure:
"""
Plot audio waveform with beat markers for one or two sets of beat times.
Parameters:
-----------
beattimes : np.ndarray
Primary array of beat times in seconds (S1 beats if beattimes2 is provided)
audiodata : np.ndarray
Audio time series data
sr : int
Sampling rate
beattimes2 : np.ndarray, optional
Secondary array of beat times in seconds (S2 beats)
Returns:
--------
go.Figure
Plotly figure with waveform and beat markers
"""
# Calculate time array for the full audio
time = np.arange(len(audiodata)) / sr
# Create the figure
fig = go.Figure()
# Add waveform
fig.add_trace(
go.Scatter(
x=time,
y=audiodata,
mode='lines',
name='Waveform',
line=dict(color='blue', width=1)
)
)
    # Process and plot primary beat times (keep only beats that fall inside the audio)
    beat_indices = np.round(beattimes * sr).astype(int)
    valid_mask = beat_indices < len(audiodata)
    beat_indices = beat_indices[valid_mask]
    valid_beattimes = beattimes[valid_mask]
    beat_amplitudes = audiodata[beat_indices]
# Define beat name based on whether secondary beats are provided
beat_name = "Beats S1" if beattimes2 is not None else "Beats"
# Add primary beat markers
fig.add_trace(
go.Scatter(
            x=valid_beattimes,
y=beat_amplitudes,
mode='markers',
name=beat_name,
marker=dict(
color='red',
size=8,
symbol='circle',
line=dict(color='darkred', width=1)
)
)
)
    # Add primary beat vertical lines
    for beat_time in valid_beattimes:
fig.add_vline(
x=beat_time,
line=dict(color="rgba(255, 0, 0, 0.2)", width=1),
layer="below"
)
    # Process and plot secondary beat times if provided
    if beattimes2 is not None:
        beat_indices2 = np.round(beattimes2 * sr).astype(int)
        valid_mask2 = beat_indices2 < len(audiodata)
        beat_indices2 = beat_indices2[valid_mask2]
        valid_beattimes2 = beattimes2[valid_mask2]
        beat_amplitudes2 = audiodata[beat_indices2]
# Add secondary beat markers
fig.add_trace(
go.Scatter(
                x=valid_beattimes2,
y=beat_amplitudes2,
mode='markers',
name="Beats S2",
marker=dict(
color='green',
size=8,
symbol='circle',
line=dict(color='darkgreen', width=1)
)
)
)
        # Add secondary beat vertical lines
        for beat_time in valid_beattimes2:
fig.add_vline(
x=beat_time,
line=dict(color="rgba(0, 255, 0, 0.2)", width=1),
layer="below"
)
# Update layout
fig.update_layout(
title="Audio Waveform with Beat Detection",
xaxis_title="Time (seconds)",
yaxis_title="Amplitude",
        showlegend=True,  # show which beat set each marker belongs to
hovermode='closest',
plot_bgcolor='white',
legend=dict(
yanchor="top",
y=0.99,
xanchor="left",
x=0.01
)
)
return fig
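
# Example (sketch; assumes `peak_times`, `cleaned` and `sr` from getBeats above):
#
#   fig = plotBeattimes(peak_times, cleaned, sr)
#   fig.show()
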
def iterate_beat_segments(beat_times, sr, audio):
"""
Iterate over audio segments between beats.
Parameters:
- beat_times: np.ndarray of beat times in seconds
- sr: Sample rate of the audio
- audio: np.ndarray of audio data
    Yields:
    - Tuple of (start_sample, end_sample, audio_segment, segment_metrics)
"""
# Convert beat times to sample indices
beat_samples = librosa.time_to_samples(beat_times, sr=sr)
# Add start and end points
beat_samples = np.concatenate(([0], beat_samples, [len(audio)]))
    # Iterate over pairs of consecutive segment boundaries
    for start, end in zip(beat_samples[:-1], beat_samples[1:]):
        # Extract the audio segment and compute its metrics
        segment = audio[start:end]
        segment_metrics = segment_analysis(segment, sr)
        yield start, end, segment, segment_metrics
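
# Example (sketch; assumes `peak_times`, `sr` and `cleaned` from getBeats above):
#
#   for start, end, segment, metrics in iterate_beat_segments(peak_times, sr, cleaned):
#       rms_energy, mean_frequency, duration, s1_to_s2, s2_to_s1 = metrics
#       print(f"{start}-{end}: RMS={rms_energy:.4f}, duration={duration:.2f}s")
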
def segment_analysis(segment, sr):
"""
Analyze an audio segment and compute various metrics.
Parameters:
- segment: np.ndarray of audio segment data
- sr: Sample rate of the audio
Returns:
- List of computed metrics
"""
# Duration
duration = len(segment) / sr
# RMS Energy
rms_energy = np.sqrt(np.mean(segment**2))
    # Frequency content
    # Use the mean FFT magnitude as a simple proxy for overall spectral energy
    fft_magnitudes = np.abs(np.fft.rfft(segment))
    mean_frequency = np.mean(fft_magnitudes)
# Attempt to detect S1 and S2
# This is a simplified approach and may not be accurate for all cases
peaks, _ = find_peaks(np.abs(segment), distance=int(0.2*sr)) # Assume at least 0.2s between peaks
if len(peaks) >= 2:
s1_index, s2_index = peaks[:2]
s1_to_s2_duration = (s2_index - s1_index) / sr
s2_to_s1_duration = (len(segment) - s2_index + peaks[0]) / sr if len(peaks) > 2 else None
else:
s1_to_s2_duration = None
s2_to_s1_duration = None
return [
rms_energy,
mean_frequency,
duration,
s1_to_s2_duration,
s2_to_s1_duration
]
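
# Example (sketch): analyzing the first 0.8 s of a cleaned recording directly
# (assumes `cleaned` and `sr` from getBeats above):
#
#   metrics = segment_analysis(cleaned[: int(0.8 * sr)], sr)
#   rms_energy, mean_frequency, duration, s1_to_s2, s2_to_s1 = metrics
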
def find_s1s2(df: pd.DataFrame) -> np.ndarray:
    """
    Cluster beats into two groups (putative S1/S2) with K-means, using each
    peak's distance to its previous and next peak as features. Expects a
    'Beattimes' column in seconds; returns an (n_peaks, 4) array of
    [peak time, distance to previous peak, distance to next peak, cluster label].
    """
    times = df['Beattimes'].to_numpy()
    n_peaks = len(times)
# Initialize the feature array
feature_array = np.zeros((n_peaks, 4))
# Fill in the peak times (first column)
feature_array[:, 0] = times
# Calculate and fill distances to previous peaks (second column)
feature_array[1:, 1] = np.diff(times) # For all except first peak
feature_array[0, 1] = feature_array[1, 1] # First peak uses same as second
# Calculate and fill distances to next peaks (third column)
feature_array[:-1, 2] = np.diff(times) # For all except last peak
feature_array[-1, 2] = feature_array[-2, 2] # Last peak uses same as second-to-last
# Extract features (distances to prev and next peaks)
X = feature_array[:, 1:3]
# Scale features
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
# Apply K-means clustering
kmeans = KMeans(n_clusters=2, random_state=42)
labels = kmeans.fit_predict(X_scaled)
# Update the labels in the feature array
feature_array[:, 3] = labels
return feature_array
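
# Example (sketch; assumes `peak_times`, `cleaned` and `sr` from getBeats above,
# and at least two detected beats). Which cluster label corresponds to S1 vs. S2
# is arbitrary and would need to be checked against the signal:
#
#   df = pd.DataFrame({'Beattimes': peak_times})
#   features = find_s1s2(df)
#   s1_times = features[features[:, 3] == 0][:, 0]
#   s2_times = features[features[:, 3] == 1][:, 0]
#   fig = plotBeattimes(s1_times, cleaned, sr, beattimes2=s2_times)
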