justus-tobias commited on
Commit
a9a2704
·
1 Parent(s): 5fb3911

updated csv handling + HRV plot

Browse files
Files changed (2) hide show
  1. app.py +60 -33
  2. utils.py +63 -22
app.py CHANGED
@@ -265,46 +265,28 @@ def get_visualizations(beattimes_table: str, cleanedaudio: gr.Audio):
265
  # 3. HRV (Heart Rate Variability)
266
  s1_durations = []
267
  s2_durations = []
 
268
  for segment in segment_metrics:
269
- if segment['s1_to_s2_duration']: # Check if list is not empty
270
  s1_durations.extend(segment['s1_to_s2_duration'])
271
- if segment['s2_to_s1_duration']: # Check if list is not empty
272
  s2_durations.extend(segment['s2_to_s1_duration'])
273
-
274
- t_interp, sdnn_interp, rmssd_interp, hr_interp = u.compute_and_plot_hrv(s1_durations, s2_durations, sr)
275
-
276
- # Add each HRV metric as a separate trace
277
- fig.add_trace(
278
- go.Scatter(
279
- x=t_interp,
280
- y=sdnn_interp,
281
- name='SDNN',
282
- line=dict(color='red', width=1)
283
- ),
284
- row=3, col=1
285
- )
286
-
287
- fig.add_trace(
288
- go.Scatter(
289
- x=t_interp,
290
- y=rmssd_interp,
291
- name='RMSSD',
292
- line=dict(color='blue', width=1)
293
- ),
294
- row=3, col=1
295
- )
296
-
297
  fig.add_trace(
298
  go.Scatter(
299
- x=t_interp,
300
- y=hr_interp,
301
- name='Heart Rate',
302
- line=dict(color='green', width=1),
303
- yaxis='y2' # Use secondary y-axis for heart rate
304
  ),
305
  row=3, col=1
306
  )
307
-
308
 
309
  # 4. Average Heartbeat Waveform
310
  max_len = max(len(metric['segment']) for metric in segment_metrics)
@@ -356,7 +338,11 @@ def get_visualizations(beattimes_table: str, cleanedaudio: gr.Audio):
356
  plot_bgcolor='white',
357
  paper_bgcolor='white'
358
  )
359
-
 
 
 
 
360
 
361
  # Update y-axes for fixed scales where needed
362
  # fig.update_yaxes(range=[-0.05, 0.05], row=5, col=1) # Fixed y-axis for overlaid segments
@@ -374,6 +360,33 @@ def get_visualizations(beattimes_table: str, cleanedaudio: gr.Audio):
374
  fig.update_xaxes(title_text="Time (s)", row=5, col=1, gridcolor='lightgray')
375
 
376
  return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
377
  #-----------------------------------------------
378
  #-----------------------------------------------
379
  # HELPER FUNCTIONS FOR SINGLE AUDIO ANALYSIS V2
@@ -423,6 +436,17 @@ def updateBeatsv2(audio:gr.Audio, uploadeddf:gr.File=None)-> go.Figure:
423
  sep=";",
424
  decimal=",",
425
  encoding="utf-8-sig")
 
 
 
 
 
 
 
 
 
 
 
426
  else:
427
  raise FileNotFoundError("No file uploaded")
428
 
@@ -513,7 +537,10 @@ with gr.Blocks() as app:
513
 
514
  plot = gr.Plot()
515
 
 
 
516
  analyzebtn.click(get_visualizations, inputs=[beattimes_table, cleanedaudio], outputs=[plot])
 
517
 
518
 
519
  app.launch()
 
265
  # 3. HRV (Heart Rate Variability)
266
  s1_durations = []
267
  s2_durations = []
268
+
269
  for segment in segment_metrics:
270
+ if segment['s1_to_s2_duration']:
271
  s1_durations.extend(segment['s1_to_s2_duration'])
272
+ if segment['s2_to_s1_duration']:
273
  s2_durations.extend(segment['s2_to_s1_duration'])
274
+
275
+ # Compute HRV metrics
276
+ time, hrv_values, _ = u.compute_hrv(s1_durations, s2_durations, sr)
277
+
278
+ # Add HRV trace to the third subplot
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279
  fig.add_trace(
280
  go.Scatter(
281
+ x=time,
282
+ y=hrv_values,
283
+ name='HRV (RMSSD)',
284
+ line=dict(color='blue', width=1.5),
285
+ hovertemplate='Time: %{x:.1f}s<br>HRV: %{y:.1f}ms<extra></extra>'
286
  ),
287
  row=3, col=1
288
  )
289
+
290
 
291
  # 4. Average Heartbeat Waveform
292
  max_len = max(len(metric['segment']) for metric in segment_metrics)
 
338
  plot_bgcolor='white',
339
  paper_bgcolor='white'
340
  )
341
+ # # Update layout for the HRV subplot
342
+ # fig.update_yaxes(title_text="Heart Rate (BPM)",
343
+ # overlaying='y',
344
+ # side='right',
345
+ # row=3, col=1)
346
 
347
  # Update y-axes for fixed scales where needed
348
  # fig.update_yaxes(range=[-0.05, 0.05], row=5, col=1) # Fixed y-axis for overlaid segments
 
360
  fig.update_xaxes(title_text="Time (s)", row=5, col=1, gridcolor='lightgray')
361
 
362
  return fig
363
+
364
+ def download_all(beattimes_table:str, cleanedaudio:gr.Audio):
365
+
366
+ df = mdpd.from_md(beattimes_table)
367
+ df['Beattimes'] = df['Beattimes'].astype(float)
368
+ df['Label (S1=1/S2=0)'] = df['Label (S1=1/S2=0)'].astype(int)
369
+
370
+ sr, audiodata = cleanedaudio
371
+
372
+ segment_metrics = u.compute_segment_metrics(df, sr, audiodata)
373
+
374
+ downloaddf = pd.DataFrame(segment_metrics)
375
+
376
+
377
+ # Convert numpy floats to regular floats
378
+ downloaddf['rms_energy'] = downloaddf['rms_energy'].astype(float)
379
+ downloaddf['mean_frequency'] = downloaddf['mean_frequency'].astype(float)
380
+
381
+ temp_dir = tempfile.gettempdir()
382
+ temp_path = os.path.join(temp_dir, "segment_metrics.csv")
383
+
384
+ downloaddf.to_csv(temp_path, index=False)
385
+
386
+ return temp_path
387
+
388
+
389
+
390
  #-----------------------------------------------
391
  #-----------------------------------------------
392
  # HELPER FUNCTIONS FOR SINGLE AUDIO ANALYSIS V2
 
436
  sep=";",
437
  decimal=",",
438
  encoding="utf-8-sig")
439
+
440
+ # Drop rows where all columns are NaN (empty)
441
+ beattimes_table = beattimes_table.dropna(how='all')
442
+
443
+ # Reset the index after dropping rows
444
+ beattimes_table = beattimes_table.reset_index(drop=True)
445
+
446
+ # Check if the column "Label (S1=0/S2=1)" exists and rename it
447
+ if "Label (S1=0/S2=1)" in beattimes_table.columns:
448
+ beattimes_table = beattimes_table.rename(columns={"Label (S1=0/S2=1)": "Label (S1=1/S2=0)"})
449
+
450
  else:
451
  raise FileNotFoundError("No file uploaded")
452
 
 
537
 
538
  plot = gr.Plot()
539
 
540
+ download_btn = gr.DownloadButton()
541
+
542
  analyzebtn.click(get_visualizations, inputs=[beattimes_table, cleanedaudio], outputs=[plot])
543
+ download_btn.click(download_all, inputs=[beattimes_table, cleanedaudio], outputs=[download_btn])
544
 
545
 
546
  app.launch()
utils.py CHANGED
@@ -575,29 +575,70 @@ def compute_segment_metrics(beattimes: pd.DataFrame, sr: int, audio: np.ndarray)
575
 
576
  return segment_metrics
577
 
578
- def compute_and_plot_hrv(s1_to_s2, s2_to_s1, sampling_rate=1000):
579
- # Combine s1_to_s2 and s2_to_s1 to get RR intervals
580
- rr_intervals = np.array(s1_to_s2) + np.array(s2_to_s1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
581
 
582
  # Calculate cumulative time for each heartbeat
583
- time = np.cumsum(rr_intervals) / sampling_rate # Convert to seconds
584
 
585
  # Calculate instantaneous heart rate
586
- hr = 60 / rr_intervals # beats per minute
587
-
588
- # Compute rolling window HRV metrics
589
- window_size = 30 # 30-second window
590
- sdnn = np.array([np.std(rr_intervals[max(0, i-window_size):i+1])
591
- for i in range(len(rr_intervals))])
592
- rmssd = np.array([np.sqrt(np.mean(np.diff(rr_intervals[max(0, i-window_size):i+1])**2))
593
- for i in range(len(rr_intervals))])
594
-
595
- # Create evenly spaced time array for plotting
596
- t_interp = np.linspace(time.min(), time.max(), num=1000)
597
-
598
- # Interpolate HRV metrics for smooth plotting
599
- sdnn_interp = interp1d(time, sdnn, kind='cubic')(t_interp)
600
- rmssd_interp = interp1d(time, rmssd, kind='cubic')(t_interp)
601
- hr_interp = interp1d(time, hr, kind='cubic')(t_interp)
602
-
603
- return t_interp, sdnn_interp, rmssd_interp, hr_interp
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
575
 
576
  return segment_metrics
577
 
578
+ def compute_hrv(s1_to_s2, s2_to_s1, sampling_rate=1000):
579
+ """
580
+ Compute Heart Rate Variability with debug statements
581
+ """
582
+ # Convert to numpy arrays if not already
583
+ s1_to_s2 = np.array(s1_to_s2)
584
+ s2_to_s1 = np.array(s2_to_s1)
585
+
586
+ # Debug: Print input values
587
+ print("First few s1_to_s2 values:", s1_to_s2[:5])
588
+ print("First few s2_to_s1 values:", s2_to_s1[:5])
589
+
590
+ # Calculate RR intervals (full cardiac cycle)
591
+ rr_intervals = s1_to_s2 + s2_to_s1
592
+
593
+ # Debug: Print RR intervals
594
+ print("First few RR intervals (samples):", rr_intervals[:5])
595
+
596
+ # Convert to seconds
597
+ rr_intervals = rr_intervals / sampling_rate
598
+ print("First few RR intervals (seconds):", rr_intervals[:5])
599
 
600
  # Calculate cumulative time for each heartbeat
601
+ time = np.cumsum(rr_intervals)
602
 
603
  # Calculate instantaneous heart rate
604
+ heart_rate = 60 / rr_intervals # beats per minute
605
+ print("First few heart rate values:", heart_rate[:5])
606
+
607
+ # Compute RMSSD using a rolling window
608
+ window_size = int(30 / np.mean(rr_intervals)) # Approximate 30-second window
609
+ print("Window size:", window_size)
610
+
611
+ hrv_values = []
612
+
613
+ for i in range(len(rr_intervals)):
614
+ window_start = max(0, i - window_size)
615
+ window_data = rr_intervals[window_start:i+1]
616
+ if len(window_data) > 1:
617
+ # Debug: Print window data occasionally
618
+ if i % 100 == 0:
619
+ print(f"\nWindow {i}:")
620
+ print("Window data:", window_data)
621
+ print("Successive differences:", np.diff(window_data))
622
+
623
+ successive_diffs = np.diff(window_data)
624
+ rmssd = np.sqrt(np.mean(successive_diffs ** 2)) * 1000 # Convert to ms
625
+ hrv_values.append(rmssd)
626
+ else:
627
+ hrv_values.append(np.nan)
628
+
629
+ hrv_values = np.array(hrv_values)
630
+
631
+ # Debug: Print HRV statistics
632
+ print("\nHRV Statistics:")
633
+ print("Min HRV:", np.nanmin(hrv_values))
634
+ print("Max HRV:", np.nanmax(hrv_values))
635
+ print("Mean HRV:", np.nanmean(hrv_values))
636
+ print("Number of valid HRV values:", np.sum(~np.isnan(hrv_values)))
637
+
638
+ # Remove potential NaN values at the start
639
+ valid_idx = ~np.isnan(hrv_values)
640
+ time = time[valid_idx]
641
+ hrv_values = hrv_values[valid_idx]
642
+ heart_rate = heart_rate[valid_idx]
643
+
644
+ return time, hrv_values, heart_rate