MiloSobral commited on
Commit
1737659
·
1 Parent(s): 35af352

Added sleep trains options on the tool

Browse files
portiloop/src/demo/demo.py CHANGED
@@ -29,10 +29,15 @@ def main():
29
  # Threshold value
30
  threshold = gr.Slider(0, 1, value=0.82, step=0.01, label="Threshold", interactive=True)
31
  # Detection Channel
 
 
32
  detect_channel = gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8"], value="2", label="Detection Channel in XDF recording", interactive=True)
33
  # Frequency
34
  freq = gr.Dropdown(choices=["100", "200", "250", "256", "500", "512", "1000", "1024"], value="250", label="Sampling Frequency (Hz)", interactive=True)
35
 
 
 
 
36
  with gr.Row():
37
  output_array = gr.File(label="Output CSV File")
38
  output_table = gr.Markdown(label="Output Table")
@@ -45,7 +50,8 @@ def main():
45
  detect_filter,
46
  threshold,
47
  detect_channel,
48
- freq],
 
49
  outputs=[output_array, output_table])
50
 
51
  demo.queue()
 
29
  # Threshold value
30
  threshold = gr.Slider(0, 1, value=0.82, step=0.01, label="Threshold", interactive=True)
31
  # Detection Channel
32
+
33
+ with gr.Row():
34
  detect_channel = gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8"], value="2", label="Detection Channel in XDF recording", interactive=True)
35
  # Frequency
36
  freq = gr.Dropdown(choices=["100", "200", "250", "256", "500", "512", "1000", "1024"], value="250", label="Sampling Frequency (Hz)", interactive=True)
37
 
38
+ # Detect trains dropdown
39
+ detect_trains = gr.Dropdown(choices=["All Spindles", "Isolated & First", "Trains"], value="All Spindles", label="Detection mode:", interactive=True)
40
+
41
  with gr.Row():
42
  output_array = gr.File(label="Output CSV File")
43
  output_table = gr.Markdown(label="Output Table")
 
50
  detect_filter,
51
  threshold,
52
  detect_channel,
53
+ freq,
54
+ detect_trains],
55
  outputs=[output_array, output_table])
56
 
57
  demo.queue()
portiloop/src/demo/offline.py CHANGED
@@ -2,11 +2,11 @@ import numpy as np
2
  from portiloop.src.detection import SleepSpindleRealTimeDetector
3
  from portiloop.src.stimulation import UpStateDelayer
4
  from portiloop.src.processing import FilterPipeline
5
- from portiloop.src.demo.utils import compute_output_table, sleep_stage, xdf2array, offline_detect, offline_filter, OfflineSleepSpindleRealTimeStimulator
6
  import gradio as gr
7
 
8
 
9
- def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq, stimulation_phase="Fast", buffer_time=0.25):
10
  # Get the options from the checkbox group
11
  offline_filtering = 0 in detect_filter_opts
12
  lacourse = 1 in detect_filter_opts
@@ -76,7 +76,14 @@ def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq, stim
76
  # Create the detector
77
  if online_detection:
78
  detector = SleepSpindleRealTimeDetector(threshold=threshold, channel=1) # always 1 because we have only one channel
79
- stimulator = OfflineSleepSpindleRealTimeStimulator()
 
 
 
 
 
 
 
80
  if stimulation_phase != "Fast":
81
  stimulation_delayer = UpStateDelayer(freq, stimulation_phase == 'Peak', time_to_buffer=buffer_time, stimulate=lambda: None)
82
  stimulator.add_delayer(stimulation_delayer)
 
2
  from portiloop.src.detection import SleepSpindleRealTimeDetector
3
  from portiloop.src.stimulation import UpStateDelayer
4
  from portiloop.src.processing import FilterPipeline
5
+ from portiloop.src.demo.utils import OfflineIsolatedSpindleRealTimeStimulator, OfflineSpindleTrainRealTimeStimulator, compute_output_table, sleep_stage, xdf2array, offline_detect, offline_filter, OfflineSleepSpindleRealTimeStimulator
6
  import gradio as gr
7
 
8
 
9
+ def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq, detect_trains, stimulation_phase="Fast", buffer_time=0.25):
10
  # Get the options from the checkbox group
11
  offline_filtering = 0 in detect_filter_opts
12
  lacourse = 1 in detect_filter_opts
 
76
  # Create the detector
77
  if online_detection:
78
  detector = SleepSpindleRealTimeDetector(threshold=threshold, channel=1) # always 1 because we have only one channel
79
+
80
+ if detect_trains == "All Spindles":
81
+ stimulator = OfflineSleepSpindleRealTimeStimulator()
82
+ elif detect_trains == "Trains":
83
+ stimulator = OfflineSpindleTrainRealTimeStimulator()
84
+ elif detect_trains == "Isolated & First":
85
+ stimulator = OfflineIsolatedSpindleRealTimeStimulator()
86
+
87
  if stimulation_phase != "Fast":
88
  stimulation_delayer = UpStateDelayer(freq, stimulation_phase == 'Peak', time_to_buffer=buffer_time, stimulate=lambda: None)
89
  stimulator.add_delayer(stimulation_delayer)
portiloop/src/demo/utils.py CHANGED
@@ -39,6 +39,7 @@ def sleep_stage(data, threshold=150, group_size=2):
39
  return unmasked_indices
40
 
41
 
 
42
  class OfflineSleepSpindleRealTimeStimulator(Stimulator):
43
  def __init__(self):
44
  self.last_detected_ts = time.time()
@@ -70,6 +71,54 @@ class OfflineSleepSpindleRealTimeStimulator(Stimulator):
70
  self.delayer = delayer
71
  self.delayer.stimulate = lambda: True
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  def xdf2array(xdf_path, channel):
74
  xdf_data, _ = pyxdf.load_xdf(xdf_path)
75
 
 
39
  return unmasked_indices
40
 
41
 
42
+
43
  class OfflineSleepSpindleRealTimeStimulator(Stimulator):
44
  def __init__(self):
45
  self.last_detected_ts = time.time()
 
71
  self.delayer = delayer
72
  self.delayer.stimulate = lambda: True
73
 
74
+
75
+ class OfflineSpindleTrainRealTimeStimulator(OfflineSleepSpindleRealTimeStimulator):
76
+ def __init__(self):
77
+ super().__init__()
78
+ self.max_spindle_train_t = 6.0
79
+
80
+ def stimulate(self, detection_signal):
81
+ self.index += 1
82
+ stim = False
83
+ for sig in detection_signal:
84
+ # We detect a stimulation
85
+ if sig:
86
+ # Record time of stimulation
87
+ ts = self.index
88
+
89
+ elapsed = ts - self.last_detected_ts
90
+ # Check if time since last stimulation is long enough
91
+ if self.wait_timesteps < elapsed < int(self.max_spindle_train_t * 250):
92
+ if self.delayer is not None:
93
+ # If we have a delayer, notify it
94
+ self.delayer.detected()
95
+ stim = True
96
+
97
+ self.last_detected_ts = ts
98
+ return stim
99
+
100
+ class OfflineIsolatedSpindleRealTimeStimulator(OfflineSpindleTrainRealTimeStimulator):
101
+ def stimulate(self, detection_signal):
102
+ self.index += 1
103
+ stim = False
104
+ for sig in detection_signal:
105
+ # We detect a stimulation
106
+ if sig:
107
+ # Record time of stimulation
108
+ ts = self.index
109
+
110
+ elapsed = ts - self.last_detected_ts
111
+ # Check if time since last stimulation is long enough
112
+ if int(self.max_spindle_train_t * 250) < elapsed:
113
+ if self.delayer is not None:
114
+ # If we have a delayer, notify it
115
+ self.delayer.detected()
116
+ stim = True
117
+
118
+ self.last_detected_ts = ts
119
+ return stim
120
+
121
+
122
  def xdf2array(xdf_path, channel):
123
  xdf_data, _ = pyxdf.load_xdf(xdf_path)
124