# Heartbeat/utils.py
from scipy.signal import butter, filtfilt, savgol_filter, find_peaks
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
import plotly.graph_objects as go
import pandas as pd
import numpy as np
import librosa
import pywt
# GENERAL HELPER FUNCTIONS
def denoise_audio(audiodata: np.ndarray, sr: int) -> tuple[np.ndarray, int]:
"""
Enhanced denoising of audio signals optimized for heart sounds.
Uses a combination of bandpass filtering, adaptive wavelet denoising,
and improved spectral subtraction.
Parameters:
-----------
audiodata : np.ndarray
Input audio signal (1D numpy array)
sr : int
Sampling rate in Hz
Returns:
--------
tuple[np.ndarray, int]
Tuple containing (denoised_signal, sampling_rate)
"""
# Input validation and conversion
if not isinstance(audiodata, np.ndarray) or audiodata.ndim != 1:
raise ValueError("audiodata must be a 1D numpy array")
if not isinstance(sr, int) or sr <= 0:
raise ValueError("sr must be a positive integer")
    # Convert to float32 and normalize (guard against an all-zero signal)
    audio = audiodata.astype(np.float32)
    peak = np.max(np.abs(audio))
    if peak > 0:
        audio = audio / peak
# 1. Enhanced Bandpass Filter
# Optimize frequency range for heart sounds (20-200 Hz)
nyquist = sr / 2
low, high = 20 / nyquist, 200 / nyquist
order = 4 # Filter order
b, a = butter(order, [low, high], btype='band')
filtered = filtfilt(b, a, audio)
# 2. Adaptive Wavelet Denoising
def apply_wavelet_denoising(sig):
# Use sym4 wavelet (good for biomedical signals)
wavelet = 'sym4'
level = min(6, pywt.dwt_max_level(len(sig), pywt.Wavelet(wavelet).dec_len))
# Decompose signal
coeffs = pywt.wavedec(sig, wavelet, level=level)
# Adaptive thresholding based on level
for i in range(1, len(coeffs)):
# Calculate level-dependent threshold
sigma = np.median(np.abs(coeffs[i])) / 0.6745
threshold = sigma * np.sqrt(2 * np.log(len(coeffs[i])))
# Adjust threshold based on decomposition level
level_factor = 1 - (i / len(coeffs)) # Higher levels get lower thresholds
coeffs[i] = pywt.threshold(coeffs[i], threshold * level_factor, mode='soft')
return pywt.waverec(coeffs, wavelet)
# Apply wavelet denoising
denoised = apply_wavelet_denoising(filtered)
# Ensure consistent length
if len(denoised) != len(audio):
denoised = librosa.util.fix_length(denoised, size=len(audio))
# 3. Improved Spectral Subtraction
def spectral_subtract(sig):
# Parameters
frame_length = int(sr * 0.04) # 40ms frames
hop_length = frame_length // 2
# Compute STFT
D = librosa.stft(sig, n_fft=frame_length, hop_length=hop_length)
mag, phase = np.abs(D), np.angle(D)
# Estimate noise spectrum from low-energy frames
frame_energy = np.sum(mag**2, axis=0)
noise_threshold = np.percentile(frame_energy, 15)
noise_frames = mag[:, frame_energy < noise_threshold]
if noise_frames.size > 0:
noise_spectrum = np.median(noise_frames, axis=1)
# Oversubtraction factor (frequency-dependent)
freq_bins = np.fft.rfftfreq(frame_length, 1/sr)
alpha = 1.0 + 0.01 * (freq_bins / nyquist)
alpha = alpha[:len(noise_spectrum)].reshape(-1, 1)
# Spectral subtraction with flooring
mag_clean = np.maximum(mag - alpha * noise_spectrum.reshape(-1, 1), 0.01 * mag)
# Reconstruct signal
D_clean = mag_clean * np.exp(1j * phase)
return librosa.istft(D_clean, hop_length=hop_length)
return sig
# Apply spectral subtraction
final = spectral_subtract(denoised)
    # Final normalization (guarded against an all-zero result)
    peak = np.max(np.abs(final))
    if peak > 0:
        final = final / peak
return final, sr
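# Minimal usage sketch (an illustration, not part of the original module):
# "recording.wav" is a hypothetical path; getaudiodata (defined below)
# handles loading and normalization before denoising.
def _demo_denoise(filepath: str = "recording.wav") -> np.ndarray:
    sr, audio = getaudiodata(filepath)      # 16 kHz, float32, mono
    denoised, _ = denoise_audio(audio, sr)  # bandpass + wavelet + spectral subtraction
    return denoised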
def getaudiodata(filepath: str, target_sr: int = 16000) -> tuple[int, np.ndarray]:
"""
Load and process audio data with consistent output properties.
Parameters:
-----------
filepath : str
Path to the audio file
target_sr : int
Target sampling rate (default: 16000 Hz)
Returns:
--------
tuple[int, np.ndarray]
Sampling rate and processed audio data with consistent properties:
- dtype: float32
- shape: (N,) mono audio
- amplitude range: [-0.95, 0.95]
- no NaN or Inf values
- C-contiguous memory layout
"""
# Load audio with specified sampling rate
audiodata, sr = librosa.load(filepath, sr=target_sr)
# Ensure numpy array
audiodata = np.asarray(audiodata)
    # Convert to mono if multi-channel (librosa returns shape (channels, samples))
    if audiodata.ndim > 1:
        audiodata = np.mean(audiodata, axis=0)
# Handle any NaN or Inf values
audiodata = np.nan_to_num(audiodata, nan=0.0, posinf=0.0, neginf=0.0)
# Normalize to prevent clipping while maintaining relative amplitudes
max_abs = np.max(np.abs(audiodata))
if max_abs > 0: # Avoid division by zero
audiodata = audiodata * (0.95 / max_abs)
# Ensure float32 dtype and memory contiguous
audiodata = np.ascontiguousarray(audiodata, dtype=np.float32)
return sr, audiodata
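# Hedged check of the guarantees documented above (hypothetical path,
# illustrative only; not part of the original module):
def _demo_load(filepath: str = "recording.wav") -> None:
    sr, audio = getaudiodata(filepath)
    assert audio.dtype == np.float32 and audio.ndim == 1
    assert np.max(np.abs(audio)) <= 0.95 + 1e-6  # normalized with headroom
    assert audio.flags['C_CONTIGUOUS']
    print(f"loaded {len(audio) / sr:.2f}s of audio at {sr} Hz")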
def getBeats(audiodata: np.ndarray, sr: int, method='envelope') -> tuple[float, np.ndarray, np.ndarray]:
"""
Advanced heartbeat detection optimized for peak detection with improved sensitivity.
Parameters:
-----------
audiodata : np.ndarray
Audio time series
sr : int
Sampling rate
method : str
        Detection method: 'envelope' (default), 'onset', or 'fusion'
Returns:
--------
tempo : float
Estimated heart rate in BPM
peak_times : np.ndarray
Times of detected heartbeat peaks
cleaned_audio : np.ndarray
Cleaned audio signal
"""
# Denoise and normalize
audiodata, sr = denoise_audio(audiodata, sr)
# Normalize to prevent clipping while maintaining relative amplitudes
cleaned_audio = audiodata / np.max(np.abs(audiodata))
def get_envelope_peaks():
"""Detect peaks using enhanced envelope method with better sensitivity"""
# Calculate envelope using appropriate frame sizes
hop_length = int(sr * 0.01) # 10ms hop
frame_length = int(sr * 0.04) # 40ms window
# Calculate RMS energy
rms = librosa.feature.rms(
y=cleaned_audio,
frame_length=frame_length,
hop_length=hop_length
)[0]
# Smooth the envelope (less aggressive smoothing)
rms_smooth = savgol_filter(rms, 7, 3)
# Find peaks with more lenient thresholds
peaks, properties = find_peaks(
rms_smooth,
distance=int(0.2 * (sr / hop_length)), # Minimum 0.2s between peaks (300 BPM max)
height=np.mean(rms_smooth) + 0.1 * np.std(rms_smooth), # Lower height threshold
prominence=np.mean(rms_smooth) * 0.1, # Lower prominence threshold
width=(int(0.01 * (sr / hop_length)), int(0.2 * (sr / hop_length))) # 10-200ms width
)
# Refine peak locations using original signal
refined_peaks = []
window_size = int(0.05 * sr) # 50ms window for refinement
for peak in peaks:
# Convert envelope peak to sample domain
sample_idx = peak * hop_length
# Define window boundaries
start = max(0, sample_idx - window_size//2)
end = min(len(cleaned_audio), sample_idx + window_size//2)
# Find the maximum amplitude within the window
window = np.abs(cleaned_audio[int(start):int(end)])
max_idx = np.argmax(window)
refined_peaks.append(start + max_idx)
return np.array(refined_peaks), rms_smooth
def get_onset_peaks():
"""Enhanced onset detection with better sensitivity"""
# Multi-band onset detection with adjusted parameters
onset_env = librosa.onset.onset_strength(
y=cleaned_audio,
sr=sr,
hop_length=256, # Smaller hop length for better temporal resolution
aggregate=np.median,
n_mels=128
)
        # More lenient thresholding. librosa's onset_detect has no `threshold`
        # kwarg; the peak picker takes `delta`, an offset above the local mean.
        # onset_detect min-max normalizes the envelope internally, so compute
        # delta on a matching normalized copy.
        env_norm = onset_env - onset_env.min()
        if env_norm.max() > 0:
            env_norm = env_norm / env_norm.max()
        delta = np.mean(env_norm) + 0.3 * np.std(env_norm)
        # Get onset positions
        onset_frames = librosa.onset.onset_detect(
            onset_envelope=onset_env,
            sr=sr,
            hop_length=256,
            backtrack=True,
            delta=delta,
            pre_max=20,   # 20 frames before peak
            post_max=20,  # 20 frames after peak
            pre_avg=25,   # 25 frames before for mean
            post_avg=25,  # 25 frames after for mean
            wait=10       # wait 10 frames before detecting the next onset
        )
# Refine onset positions to peaks
refined_peaks = []
window_size = int(0.05 * sr) # 50ms window
for frame in onset_frames:
# Convert frame to sample index
sample_idx = frame * 256 # Using hop_length=256
# Define window boundaries
start = max(0, sample_idx - window_size//2)
end = min(len(cleaned_audio), sample_idx + window_size//2)
# Find the maximum amplitude within the window
window = np.abs(cleaned_audio[int(start):int(end)])
max_idx = np.argmax(window)
refined_peaks.append(start + max_idx)
return np.array(refined_peaks), onset_env
# Apply selected method
if method == 'envelope':
peaks, _ = get_envelope_peaks()
elif method == 'onset':
peaks, _ = get_onset_peaks()
else: # fusion method
# Get peaks from both methods
env_peaks, _ = get_envelope_peaks()
onset_peaks, _ = get_onset_peaks()
# Merge nearby peaks (within 50ms)
all_peaks = np.sort(np.concatenate([env_peaks, onset_peaks]))
merged_peaks = []
last_peak = -np.inf
for peak in all_peaks:
if (peak - last_peak) / sr > 0.05: # 50ms minimum separation
merged_peaks.append(peak)
last_peak = peak
peaks = np.array(merged_peaks)
# Convert peaks to times
peak_times = peaks / sr
# Calculate tempo using peak times
if len(peak_times) > 1:
# Use weighted average of intervals
intervals = np.diff(peak_times)
tempos = 60 / intervals # Convert intervals to BPM
# Remove physiologically impossible tempos
valid_tempos = tempos[(tempos >= 30) & (tempos <= 300)]
if len(valid_tempos) > 0:
tempo = np.median(valid_tempos) # Use median for robustness
else:
tempo = 0
else:
tempo = 0
return tempo, peak_times, cleaned_audio
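# Usage sketch tying the loader and detector together (hypothetical path;
# the 'fusion' choice is illustrative, any documented method works):
def _demo_beats(filepath: str = "recording.wav") -> None:
    sr, audio = getaudiodata(filepath)
    tempo, peak_times, cleaned = getBeats(audio, sr, method='fusion')
    print(f"estimated heart rate: {tempo:.1f} BPM from {len(peak_times)} beats")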
def plotBeattimes(beattimes: np.ndarray,
audiodata: np.ndarray,
sr: int,
beattimes2: np.ndarray = None) -> go.Figure:
"""
Plot audio waveform with beat markers for one or two sets of beat times.
Parameters:
-----------
beattimes : np.ndarray
Primary array of beat times in seconds (S1 beats if beattimes2 is provided)
audiodata : np.ndarray
Audio time series data
sr : int
Sampling rate
beattimes2 : np.ndarray, optional
Secondary array of beat times in seconds (S2 beats)
Returns:
--------
go.Figure
Plotly figure with waveform and beat markers
"""
# Calculate time array for the full audio
time = np.arange(len(audiodata)) / sr
# Create the figure
fig = go.Figure()
# Add waveform
fig.add_trace(
go.Scatter(
x=time,
y=audiodata,
mode='lines',
name='Waveform',
line=dict(color='blue', width=1)
)
)
    # Process and plot primary beat times (accept comma-decimal strings too).
    # Note: filter times and sample indices with the same mask so their
    # lengths stay consistent.
    if len(beattimes) > 0 and isinstance(beattimes[0], str):
        times1 = np.array([float(bt.replace(',', '.')) for bt in beattimes])
    else:
        times1 = np.asarray(beattimes, dtype=float)
    beat_indices = np.round(times1 * sr).astype(int)
    valid = beat_indices < len(audiodata)
    beat_indices, times1 = beat_indices[valid], times1[valid]
    beat_amplitudes = audiodata[beat_indices]
    # Define beat name based on whether secondary beats are provided
    beat_name = "Beats S1" if beattimes2 is not None else "Beats"
    # Add primary beat markers
    fig.add_trace(
        go.Scatter(
            x=times1,
            y=beat_amplitudes,
            mode='markers',
            name=beat_name,
            marker=dict(
                color='red',
                size=8,
                symbol='circle',
                line=dict(color='darkred', width=1)
            )
        )
    )
    # Add primary beat vertical lines
    for beat_time in times1:
        fig.add_vline(
            x=beat_time,
            line=dict(color="rgba(255, 0, 0, 0.2)", width=1),
            layer="below"
        )
    # Process and plot secondary beat times if provided
    if beattimes2 is not None:
        if len(beattimes2) > 0 and isinstance(beattimes2[0], str):
            times2 = np.array([float(bt.replace(',', '.')) for bt in beattimes2])
        else:
            times2 = np.asarray(beattimes2, dtype=float)
        beat_indices2 = np.round(times2 * sr).astype(int)
        valid2 = beat_indices2 < len(audiodata)
        beat_indices2, times2 = beat_indices2[valid2], times2[valid2]
        beat_amplitudes2 = audiodata[beat_indices2]
        # Add secondary beat markers
        fig.add_trace(
            go.Scatter(
                x=times2,
                y=beat_amplitudes2,
                mode='markers',
                name="Beats S2",
                marker=dict(
                    color='green',
                    size=8,
                    symbol='circle',
                    line=dict(color='darkgreen', width=1)
                )
            )
        )
        # Add secondary beat vertical lines
        for beat_time in times2:
            fig.add_vline(
                x=beat_time,
                line=dict(color="rgba(0, 255, 0, 0.2)", width=1),
                layer="below"
            )
# Update layout
fig.update_layout(
title="Audio Waveform with Beat Detection",
xaxis_title="Time (seconds)",
yaxis_title="Amplitude",
        showlegend=True,  # show the legend to distinguish beat types
hovermode='closest',
plot_bgcolor='white',
legend=dict(
yanchor="top",
y=0.99,
xanchor="left",
x=0.01
)
)
return fig
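# Sketch rendering detected beats (hypothetical path; call fig.show() on the
# returned figure to open the interactive Plotly view):
def _demo_plot(filepath: str = "recording.wav") -> go.Figure:
    sr, audio = getaudiodata(filepath)
    _, peak_times, cleaned = getBeats(audio, sr)
    return plotBeattimes(peak_times, cleaned, sr)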
def iterate_beat_segments(beat_times, sr, audio):
"""
Iterate over audio segments between beats marked with label 1.
Parameters:
- beat_times: df of beattimes and labels as DataFrame
- sr: Sample rate of the audio
- audio: np.ndarray of audio data
Yields:
- List of segment metrics with associated beat information
"""
    # Get positional indices where label is 1 (S1), so the iloc slicing
    # below stays correct even if the DataFrame index is not a RangeIndex
    label_ones = np.flatnonzero(
        beat_times['Label (S1=1/S2=0)'].to_numpy() == 1
    ).tolist()
segment_metrics = []
# Iterate through pairs of label 1 indices
for i in range(len(label_ones) - 1):
start_idx = label_ones[i]
end_idx = label_ones[i + 1]
# Get all beats between two label 1 beats (inclusive)
segment_beats = beat_times.iloc[start_idx:end_idx + 1]
# Create list of tuples (label, beattime)
beat_info = list(zip(segment_beats['Label (S1=1/S2=0)'],
segment_beats['Beattimes']))
# Get start and end samples
start_sample = librosa.time_to_samples(segment_beats.iloc[0]['Beattimes'], sr=sr)
end_sample = librosa.time_to_samples(segment_beats.iloc[-1]['Beattimes'], sr=sr)
# Extract audio segment
segment = audio[start_sample:end_sample]
# Analyze segment with beat information if not empty
if len(segment) > 0:
segment_metrics.append(segment_analysis(segment, sr, beat_info))
return segment_metrics
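# The expected input is a DataFrame with the two columns used above. A
# hedged sketch of its construction (column names come from this module;
# the timing values are illustrative only):
def _demo_beat_df() -> pd.DataFrame:
    return pd.DataFrame({
        'Beattimes': [0.10, 0.42, 0.95, 1.27, 1.80],  # seconds
        'Label (S1=1/S2=0)': [1, 0, 1, 0, 1],         # alternating S1/S2
    })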
def segment_analysis(segment, sr, s1s2: list):
    """
    Analyze an audio segment and compute various metrics.
    Parameters:
    - segment: np.ndarray of audio segment data
    - sr: Sample rate of the audio
    - s1s2: list of (label, beattime) tuples, where label 1 = S1 and 0 = S2
    Returns:
    - dict of computed metrics (see keys in the return statement)
    """
# Duration
duration = len(segment) / sr
# RMS Energy
rms_energy = np.sqrt(np.mean(segment**2))
# Calculate frequency spectrum and find dominant frequencies
fft = np.abs(np.fft.rfft(segment))
freqs = np.fft.rfftfreq(len(segment), d=1/sr)
    # Focus on the frequency range typical for heart sounds (20-200 Hz).
    # The key name "mean_frequency" is kept for compatibility, but the value
    # is the dominant (peak) frequency within that band.
    mask = (freqs >= 20) & (freqs <= 200)
    dominant_freq_idx = np.argmax(fft[mask])
    mean_frequency = freqs[mask][dominant_freq_idx]
s1_to_s2_duration = []
s2_to_s1_duration = []
prev = s1s2[0]
for i in range(1, len(s1s2)):
if prev[0] == 0 and s1s2[i][0] == 1:
s2_to_s1_duration.append(s1s2[i][1] - prev[1])
elif prev[0] == 1 and s1s2[i][0] == 0:
s1_to_s2_duration.append(s1s2[i][1] - prev[1])
prev = s1s2[i]
return {
"rms_energy": rms_energy,
"mean_frequency": mean_frequency,
"duration": duration,
"s1_to_s2_duration": s1_to_s2_duration,
"s2_to_s1_duration": s2_to_s1_duration,
"segment": segment
}
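# Illustrative call showing the returned metric keys (a synthetic 60 Hz tone
# stands in for a real segment; the beat labels/times are made up):
def _demo_segment_metrics(sr: int = 16000) -> dict:
    t = np.arange(sr) / sr
    segment = np.sin(2 * np.pi * 60 * t).astype(np.float32)
    beat_info = [(1, 0.10), (0, 0.42), (1, 0.95)]  # (label, beattime) pairs
    return segment_analysis(segment, sr, beat_info)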
def find_s1s2(df: pd.DataFrame):
    """
    Heuristically label peaks as S1/S2 by clustering inter-peak intervals.
    Builds an (n_peaks, 4) feature array of [time, dist_to_prev, dist_to_next,
    label] and assigns labels with 2-cluster K-means. Note that the cluster
    ids are arbitrary 0/1.
    """
    times = df['Beattimes'].to_numpy()
n_peaks = len(times)
# Initialize the feature array
feature_array = np.zeros((n_peaks, 4))
# Fill in the peak times (first column)
feature_array[:, 0] = times
# Calculate and fill distances to previous peaks (second column)
feature_array[1:, 1] = np.diff(times) # For all except first peak
feature_array[0, 1] = feature_array[1, 1] # First peak uses same as second
# Calculate and fill distances to next peaks (third column)
feature_array[:-1, 2] = np.diff(times) # For all except last peak
feature_array[-1, 2] = feature_array[-2, 2] # Last peak uses same as second-to-last
# Extract features (distances to prev and next peaks)
X = feature_array[:, 1:3]
# Scale features
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
    # Apply K-means clustering (n_init pinned for reproducibility across
    # scikit-learn versions)
    kmeans = KMeans(n_clusters=2, n_init=10, random_state=42)
    labels = kmeans.fit_predict(X_scaled)
# Update the labels in the feature array
feature_array[:, 3] = labels
return feature_array
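# Sketch wiring find_s1s2 back into the labeled-DataFrame convention used by
# iterate_beat_segments. Note the K-means cluster ids are arbitrary 0/1 and
# may need flipping to match S1=1/S2=0 (an assumption to verify per
# recording); needs at least two peaks:
def _demo_label_beats(peak_times: np.ndarray) -> pd.DataFrame:
    features = find_s1s2(pd.DataFrame({'Beattimes': peak_times}))
    return pd.DataFrame({
        'Beattimes': features[:, 0],
        'Label (S1=1/S2=0)': features[:, 3].astype(int),
    })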
# ANALYZE
def compute_segment_metrics(beattimes: pd.DataFrame, sr: int, audio: np.ndarray):
segment_metrics = iterate_beat_segments(beattimes, sr, audio)
print("segment_metrics", segment_metrics)
return segment_metrics
def compute_hrv(s1_to_s2, s2_to_s1, sampling_rate):
    """
    Compute Heart Rate Variability (rolling-window RMSSD) with debug output.
    Durations are expected in samples; sampling_rate converts them to seconds.
    """
# Convert to numpy arrays if not already
s1_to_s2 = np.array(s1_to_s2)
s2_to_s1 = np.array(s2_to_s1)
# Debug: Print input values
print("First few s1_to_s2 values:", s1_to_s2[:5])
print("First few s2_to_s1 values:", s2_to_s1[:5])
    # Calculate RR intervals (full cardiac cycle); trim to a common length
    # in case one duration list has one extra entry
    n = min(len(s1_to_s2), len(s2_to_s1))
    rr_intervals = s1_to_s2[:n] + s2_to_s1[:n]
# Debug: Print RR intervals
print("First few RR intervals (samples):", rr_intervals[:5])
# Convert to seconds
rr_intervals = rr_intervals / sampling_rate
print("First few RR intervals (seconds):", rr_intervals[:5])
# Calculate cumulative time for each heartbeat
time = np.cumsum(rr_intervals)
# Calculate instantaneous heart rate
heart_rate = 60 / rr_intervals # beats per minute
print("First few heart rate values:", heart_rate[:5])
# Compute RMSSD using a rolling window
window_size = int(30 / np.mean(rr_intervals)) # Approximate 30-second window
print("Window size:", window_size)
hrv_values = []
for i in range(len(rr_intervals)):
window_start = max(0, i - window_size)
window_data = rr_intervals[window_start:i+1]
if len(window_data) > 1:
# Debug: Print window data occasionally
if i % 100 == 0:
print(f"\nWindow {i}:")
print("Window data:", window_data)
print("Successive differences:", np.diff(window_data))
successive_diffs = np.diff(window_data)
rmssd = np.sqrt(np.mean(successive_diffs ** 2)) * 1000 # Convert to ms
hrv_values.append(rmssd)
else:
hrv_values.append(np.nan)
hrv_values = np.array(hrv_values)
# Debug: Print HRV statistics
print("\nHRV Statistics:")
print("Min HRV:", np.nanmin(hrv_values))
print("Max HRV:", np.nanmax(hrv_values))
print("Mean HRV:", np.nanmean(hrv_values))
print("Number of valid HRV values:", np.sum(~np.isnan(hrv_values)))
# Remove potential NaN values at the start
valid_idx = ~np.isnan(hrv_values)
time = time[valid_idx]
hrv_values = hrv_values[valid_idx]
heart_rate = heart_rate[valid_idx]
return time, hrv_values, heart_rate
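# End-to-end HRV sketch under this module's assumption that the durations
# are given in samples (sampling_rate converts them to seconds). The values
# below are illustrative, not measured:
def _demo_hrv(sr: int = 16000) -> None:
    s1_to_s2 = np.array([0.30, 0.31, 0.29]) * sr  # seconds -> samples
    s2_to_s1 = np.array([0.50, 0.52, 0.49]) * sr
    _, hrv_values, heart_rate = compute_hrv(s1_to_s2, s2_to_s1, sr)
    print("mean HR:", np.nanmean(heart_rate), "BPM; mean RMSSD:",
          np.nanmean(hrv_values), "ms")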