MiloSobral committed
Commit 2cb7306 · 1 Parent(s): 763def2

Finished setting up the demo and fixed my git stupidity

portiloop/src/demo/demo.py CHANGED
@@ -1,137 +1,47 @@
  import gradio as gr
- import matplotlib.pyplot as plt
- import time
- import numpy as np
- import pandas as pd
- from portiloop.src.demo.demo_stimulator import DemoSleepSpindleRealTimeStimulator
- from portiloop.src.detection import SleepSpindleRealTimeDetector
 
- from portiloop.src.stimulation import UpStateDelayer
- plt.switch_backend('agg')
- from portiloop.src.processing import FilterPipeline
+ from portiloop.src.demo.offline import run_offline
 
 
- def do_treatment(csv_file, filtering, threshold, detect_channel, freq, spindle_freq, spindle_detection_mode, time_to_buffer):
-
-     # Read the csv file to a numpy array
-     data_whole = np.loadtxt(csv_file.name, delimiter=',')
-
-     # Get the data from the selected channel
-     detect_channel = int(detect_channel)
-     freq = int(freq)
-     data = data_whole[:, detect_channel - 1]
-
-     # Create the detector and the stimulator
-     detector = SleepSpindleRealTimeDetector(threshold=threshold, channel=1) # always 1 because we have only one channel
-     stimulator = DemoSleepSpindleRealTimeStimulator()
-     if spindle_detection_mode != 'Fast':
-         delayer = UpStateDelayer(freq, spindle_freq, spindle_detection_mode == 'Peak', time_to_buffer=time_to_buffer)
-         stimulator.add_delayer(delayer)
-
-     # Create the filtering pipeline
-     if filtering:
-         filter = FilterPipeline(nb_channels=1, sampling_rate=freq)
-
-     # Plotting variables
-     points = []
-     activations = []
-     delayed_activations = []
-
-     # Go through the data
-     for index, point in enumerate(data):
-         # Step the delayer if exists
-         if spindle_detection_mode != 'Fast':
-             delayed = delayer.step(point)
-             if delayed:
-                 delayed_activations.append(1)
-             else:
-                 delayed_activations.append(0)
-
-         # Filter the data
-         if filtering:
-             filtered_point = filter.filter(np.array([point]))
-         else:
-             filtered_point = point
-
-         filtered_point = filtered_point.tolist()
-
-         # Detect the spindles
-         result = detector.detect([filtered_point])
-
-         # Stimulate if necessary
-         stim = stimulator.stimulate(result)
-         if stim:
-             activations.append(1)
-         else:
-             activations.append(0)
-
-         # Add data to plotting buffer
-         points.append(filtered_point[0])
-
-         # Function to return a list of all indexes where activations have happened
-         def get_activations(activations):
-             return [i for i, x in enumerate(activations) if x == 1]
-
-         # Plot the data
-         if index % (10 * freq) == 0 and index >= (10 * freq):
-             plt.close()
-             fig = plt.figure(figsize=(20, 10))
-             plt.clf()
-             plt.plot(np.linspace(0, 10, num=freq*10), points[-10 * freq:], label="Data")
-             # Draw vertical lines for activations
-             for index in get_activations(activations[-10 * freq:]):
-                 plt.axvline(x=index / freq, color='r', label="Fast Stimulation")
-             if spindle_detection_mode != 'Fast':
-                 for index in get_activations(delayed_activations[-10 * freq:]):
-                     plt.axvline(x=index / freq, color='g', label="Delayed Stimulation")
-             # Add axis titles and legend
-             plt.legend()
-             plt.xlabel("Time (s)")
-             plt.ylabel("Amplitude")
-             yield fig, None
-
-     # Put all points and activations back in numpy arrays
-     points = np.array(points)
-     activations = np.array(activations)
-     delayed_activations = np.array(delayed_activations)
-     # Concatenate with the original data
-     data_whole = np.concatenate((data_whole, points.reshape(-1, 1), activations.reshape(-1, 1), delayed_activations.reshape(-1, 1)), axis=1)
-     # Output the data to a csv file
-     np.savetxt('output.csv', data_whole, delimiter=',')
-
-     yield None, "output.csv"
-
+ def on_upload_file(file):
+     # Check if file extension is .xdf
+     if file.name.split(".")[-1] != "xdf":
+         raise gr.Error("Please upload a .xdf file.")
+     else:
+         yield f"File {file.name} successfully uploaded!"
 
 
  with gr.Blocks() as demo:
      gr.Markdown("# Portiloop Demo")
      gr.Markdown("This Demo takes as input a csv file containing EEG data and outputs a csv file with the following added: \n * The data filtered by the Portiloop online filter \n * The stimulations made by Portiloop.")
      gr.Markdown("Upload your CSV file and click **Run Inference** to start the processing...")
 
-     # Row containing all inputs:
      with gr.Row():
-         # CSV file
-         csv_file = gr.UploadButton(label="CSV File", file_count="single")
-         # Filtering (Boolean)
-         filtering = gr.Checkbox(label="Filtering (On/Off)", value=True)
+         xdf_file = gr.UploadButton(label="Upload XDF File", type="file")
+
+         # Offline Filtering (Boolean)
+         offline_filtering = gr.Checkbox(label="Offline Filtering (On/Off)", value=True)
+         # Online Filtering (Boolean)
+         online_filtering = gr.Checkbox(label="Online Filtering (On/Off)", value=True)
+         # Lacourse's Method (Boolean)
+         lacourse = gr.Checkbox(label="Lacourse Detection (On/Off)", value=True)
+         # Wamsley's Method (Boolean)
+         wamsley = gr.Checkbox(label="Wamsley Detection (On/Off)", value=True)
+         # Online Detection (Boolean)
+         online_detection = gr.Checkbox(label="Online Detection (On/Off)", value=True)
+
          # Threshold value
          threshold = gr.Slider(0, 1, value=0.82, step=0.01, label="Threshold", interactive=True)
          # Detection Channel
-         detect_column = gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"], value="1", label="Detection Column in CSV", interactive=True)
+         detect_channel = gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8"], value="2", label="Detection Channel in XDF recording", interactive=True)
          # Frequency
          freq = gr.Dropdown(choices=["100", "200", "250", "256", "500", "512", "1000", "1024"], value="250", label="Sampling Frequency (Hz)", interactive=True)
-         # Spindle Frequency
-         spindle_freq = gr.Slider(10, 16, value=12, step=1, label="Spindle Frequency (Hz)", interactive=True)
-         # Spindle Detection Mode
-         spindle_detection_mode = gr.Dropdown(choices=["Fast", "Peak", "Valley"], value="Peak", label="Spindle Detection Mode", interactive=True)
-         # Time to buffer
-         time_to_buffer = gr.Slider(0, 1, value=0.3, step=0.01, label="Time to Buffer (s)", interactive=True)
 
-     # Output plot
+     # Output elements
+     update_text = gr.Textbox(value="Waiting for user input...", label="Status", interactive=False)
      output_plot = gr.Plot()
-     # Output file
      output_array = gr.File(label="Output CSV File")
+     xdf_file.upload(fn=on_upload_file, inputs=[xdf_file], outputs=[update_text])
 
      # Row containing all buttons:
      with gr.Row():
@@ -139,7 +49,19 @@ with gr.Blocks() as demo:
          run_inference = gr.Button(value="Run Inference")
          # Reset button
          reset = gr.Button(value="Reset", variant="secondary")
-         run_inference.click(fn=do_treatment, inputs=[csv_file, filtering, threshold, detect_column, freq, spindle_freq, spindle_detection_mode, time_to_buffer], outputs=[output_plot, output_array])
+         run_inference.click(
+             fn=run_offline,
+             inputs=[
+                 xdf_file,
+                 offline_filtering,
+                 online_filtering,
+                 online_detection,
+                 lacourse,
+                 wamsley,
+                 threshold,
+                 detect_channel,
+                 freq],
+             outputs=[output_plot, output_array, update_text])
 
  demo.queue()
  demo.launch(share=True)
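
The click handler now points at a generator (run_offline) instead of a plain function, so each intermediate yield refreshes the plot, the output file and the status textbox while processing is still running; demo.queue() is what enables that streaming behaviour in Gradio. A minimal sketch of the same pattern, separate from the Portiloop code (the slow_task function and its messages are invented for illustration):

import time
import gradio as gr

def slow_task(n_steps):
    # A generator event handler: each yield pushes a fresh value to the output.
    for step in range(int(n_steps)):
        time.sleep(0.5)  # stand-in for real processing work
        yield f"Step {step + 1}/{int(n_steps)} done"
    yield "All done!"

with gr.Blocks() as sketch:
    n_steps = gr.Slider(1, 10, value=3, step=1, label="Steps")
    status = gr.Textbox(label="Status")
    run = gr.Button("Run")
    run.click(fn=slow_task, inputs=[n_steps], outputs=[status])

sketch.queue()      # required for generator handlers to stream updates
# sketch.launch()   # uncomment to try it locally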
portiloop/src/demo/demo_stimulator.py DELETED
@@ -1,30 +0,0 @@
- import time
- from portiloop.src.stimulation import Stimulator
-
-
- class DemoSleepSpindleRealTimeStimulator(Stimulator):
-     def __init__(self):
-         self.last_detected_ts = time.time()
-         self.wait_t = 0.4 # 400 ms
-
-     def stimulate(self, detection_signal):
-         stim = False
-         for sig in detection_signal:
-             # We detect a stimulation
-             if sig:
-                 # Record time of stimulation
-                 ts = time.time()
-
-                 # Check if time since last stimulation is long enough
-                 if ts - self.last_detected_ts > self.wait_t:
-                     if self.delayer is not None:
-                         # If we have a delayer, notify it
-                         self.delayer.detected()
-                     stim = True
-
-                 self.last_detected_ts = ts
-         return stim
-
-     def add_delayer(self, delayer):
-         self.delayer = delayer
-         self.delayer.stimulate = lambda: True
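
DemoSleepSpindleRealTimeStimulator is removed here, but essentially the same logic returns below as OfflineSleepSpindleRealTimeStimulator in portiloop/src/demo/utils.py: its only behaviour is a 400 ms refractory window, so a detection arriving within 0.4 s of the previous one does not trigger a new stimulation. A small sketch of that behaviour, assuming the new utils module is importable:

import time
from portiloop.src.demo.utils import OfflineSleepSpindleRealTimeStimulator

stimulator = OfflineSleepSpindleRealTimeStimulator()

# First detection arrives right after construction, still inside the
# 400 ms refractory window, so no stimulation is produced.
print(stimulator.stimulate([True]))   # expected: False

# After waiting longer than wait_t (0.4 s), the next detection does stimulate.
time.sleep(0.5)
print(stimulator.stimulate([True]))   # expected: True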
portiloop/src/demo/offline.py ADDED
@@ -0,0 +1,137 @@
+ import matplotlib.pyplot as plt
+ import numpy as np
+ from portiloop.src.detection import SleepSpindleRealTimeDetector
+ plt.switch_backend('agg')
+ from portiloop.src.processing import FilterPipeline
+ from portiloop.src.demo.utils import xdf2array, offline_detect, offline_filter, OfflineSleepSpindleRealTimeStimulator
+ import gradio as gr
+
+
+ def run_offline(xdf_file, offline_filtering, online_filtering, online_detection, lacourse, wamsley, threshold, channel_num, freq):
+
+     print("Starting offline processing...")
+     # Make sure the inputs make sense:
+     if not offline_filtering and (lacourse or wamsley):
+         raise gr.Error("You can't use the offline detection methods without offline filtering.")
+
+     if not online_filtering and online_detection:
+         raise gr.Error("You can't use the online detection without online filtering.")
+
+     freq = int(freq)
+
+     # Read the xdf file to a numpy array
+     print("Loading xdf file...")
+     yield None, None, "Loading xdf file..."
+     data_whole, columns = xdf2array(xdf_file.name, int(channel_num))
+     print(data_whole.shape)
+     # Do the offline filtering of the data
+     print("Filtering offline...")
+     yield None, None, "Filtering offline..."
+     if offline_filtering:
+         offline_filtered_data = offline_filter(data_whole[:, columns.index("raw_signal")], freq)
+         # Expand the dimension of the filtered data to match the shape of the other columns
+         offline_filtered_data = np.expand_dims(offline_filtered_data, axis=1)
+         data_whole = np.concatenate((data_whole, offline_filtered_data), axis=1)
+         columns.append("offline_filtered_signal")
+
+     # Do Wamsley's method
+     print("Running Wamsley detection...")
+     yield None, None, "Running Wamsley detection..."
+     if wamsley:
+         wamsley_data = offline_detect("Wamsley", \
+             data_whole[:, columns.index("offline_filtered_signal")],\
+             data_whole[:, columns.index("time_stamps")],\
+             freq)
+         wamsley_data = np.expand_dims(wamsley_data, axis=1)
+         data_whole = np.concatenate((data_whole, wamsley_data), axis=1)
+         columns.append("wamsley_spindles")
+
+     # Do Lacourse's method
+     print("Running Lacourse detection...")
+     yield None, None, "Running Lacourse detection..."
+     if lacourse:
+         lacourse_data = offline_detect("Lacourse", \
+             data_whole[:, columns.index("offline_filtered_signal")],\
+             data_whole[:, columns.index("time_stamps")],\
+             freq)
+         lacourse_data = np.expand_dims(lacourse_data, axis=1)
+         data_whole = np.concatenate((data_whole, lacourse_data), axis=1)
+         columns.append("lacourse_spindles")
+
+     # Get the data from the raw signal column
+     data = data_whole[:, columns.index("raw_signal")]
+
+     # Create the online filtering pipeline
+     if online_filtering:
+         filter = FilterPipeline(nb_channels=1, sampling_rate=freq)
+
+     # Create the detector
+     if online_detection:
+         detector = SleepSpindleRealTimeDetector(threshold=threshold, channel=1) # always 1 because we have only one channel
+         stimulator = OfflineSleepSpindleRealTimeStimulator()
+
+     print("Running online filtering and detection...")
+     yield None, None, "Running online filtering and detection..."
+     if online_filtering or online_detection:
+         # Plotting variables
+         points = []
+         online_activations = []
+
+         # Go through the data
+         for index, point in enumerate(data):
+             # Filter the data
+             if online_filtering:
+                 filtered_point = filter.filter(np.array([point]))
+             else:
+                 filtered_point = point
+             filtered_point = filtered_point.tolist()
+             points.append(filtered_point[0])
+
+             if online_detection:
+                 # Detect the spindles
+                 result = detector.detect([filtered_point])
+
+                 # Stimulate if necessary
+                 stim = stimulator.stimulate(result)
+                 if stim:
+                     online_activations.append(1)
+                 else:
+                     online_activations.append(0)
+
+             # Function to return a list of all indexes where activations have happened
+             def get_activations(activations):
+                 return [i for i, x in enumerate(activations) if x == 1]
+
+             # Plot the data
+             if index % (10 * freq) == 0 and index >= (10 * freq):
+                 plt.close()
+                 fig = plt.figure(figsize=(20, 10))
+                 plt.clf()
+                 plt.plot(np.linspace(0, 10, num=freq*10), points[-10 * freq:], label="Data")
+                 # Draw vertical lines for activations
+                 for index in get_activations(online_activations[-10 * freq:]):
+                     plt.axvline(x=index / freq, color='r', label="Portiloop Stimulation")
+                 # Add axis titles and legend
+                 plt.legend()
+                 plt.xlabel("Time (s)")
+                 plt.ylabel("Amplitude")
+                 yield fig, None, "Running online filtering and detection..."
+
+     if online_filtering:
+         online_filtered = np.array(points)
+         online_filtered = np.expand_dims(online_filtered, axis=1)
+         data_whole = np.concatenate((data_whole, online_filtered), axis=1)
+         columns.append("online_filtered_signal")
+
+     if online_detection:
+         online_activations = np.array(online_activations)
+         online_activations = np.expand_dims(online_activations, axis=1)
+         data_whole = np.concatenate((data_whole, online_activations), axis=1)
+         columns.append("online_stimulations")
+
+     print("Saving output...")
+     # Output the data to a csv file
+     np.savetxt("output.csv", data_whole, delimiter=",", header=",".join(columns), comments="")
+
+     print("Done!")
+     yield None, "output.csv", "Done!"
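
run_offline is a generator that yields (plot, output file, status) triples, which is what lets the Gradio UI stream progress; it can also be driven outside the UI, as the new test file below does. A minimal standalone sketch, assuming the demo's dependencies are installed (run_offline only reads xdf_file.name, so a SimpleNamespace stands in for Gradio's upload object; the recording path is hypothetical):

from types import SimpleNamespace
from portiloop.src.demo.offline import run_offline

# Hypothetical path to an XDF recording; replace with a real file.
xdf_file = SimpleNamespace(name="recordings/night_01.xdf")

for fig, output_file, status in run_offline(
        xdf_file,
        offline_filtering=True,
        online_filtering=True,
        online_detection=True,
        lacourse=True,
        wamsley=True,
        threshold=0.82,
        channel_num=2,
        freq=250):
    print(status)                      # e.g. "Filtering offline...", "Done!"
    if output_file is not None:
        print(f"Results written to {output_file}")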
portiloop/src/demo/test_offline.py ADDED
@@ -0,0 +1,49 @@
+ import itertools
+ import unittest
+ from portiloop.src.demo.offline import run_offline
+ from pathlib import Path
+
+
+ class TestOffline(unittest.TestCase):
+
+     def setUp(self):
+         combinatorial_config = {
+             'offline_filtering': [True, False],
+             'online_filtering': [True, False],
+             'online_detection': [True, False],
+             'wamsley': [True, False],
+             'lacourse': [True, False],
+         }
+
+         self.exclusives = [("duplicate_as_window", "use_cnn_encoder")]
+
+         keys = list(combinatorial_config)
+         all_options_iterator = itertools.product(*map(combinatorial_config.get, keys))
+         all_options_dicts = [dict(zip(keys, values)) for values in all_options_iterator]
+         self.filtered_options = [value for value in all_options_dicts if (value['online_detection'] and value['online_filtering']) or not value['online_detection']]
+         self.xdf_file = Path(__file__).parents[3] / "test_xdf" / "test_file.xdf"
+
+     def test_all_options(self):
+         for config in self.filtered_options:
+             if config['online_detection']:
+                 self.assertTrue(config['online_filtering'])
+
+     def test_single_option(self):
+         res = list(run_offline(
+             self.xdf_file,
+             offline_filtering=True,
+             online_filtering=True,
+             online_detection=True,
+             wamsley=True,
+             lacourse=True,
+             threshold=0.5,
+             channel_num=2,
+             freq=250))
+         print(res)
+
+     def tearDown(self):
+         pass
+
+ if __name__ == '__main__':
+     unittest.main()
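
Assuming the package is importable (or the repository root is on PYTHONPATH) and a recording exists at test_xdf/test_file.xdf as setUp expects, the new tests run with the standard unittest runner; programmatically that is:

import unittest

# Equivalent to `python -m unittest portiloop.src.demo.test_offline -v`
# run from the repository root.
suite = unittest.defaultTestLoader.loadTestsFromName("portiloop.src.demo.test_offline")
unittest.TextTestRunner(verbosity=2).run(suite)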
portiloop/src/demo/utils.py ADDED
@@ -0,0 +1,132 @@
+ import numpy as np
+ import pyxdf
+ from wonambi.detect.spindle import DetectSpindle, detect_Lacourse2018, detect_Wamsley2012
+ from scipy.signal import butter, filtfilt, iirnotch, detrend
+ import time
+ from portiloop.src.stimulation import Stimulator
+
+
+ STREAM_NAMES = {
+     'filtered_data': 'Portiloop Filtered',
+     'raw_data': 'Portiloop Raw Data',
+     'stimuli': 'Portiloop_stimuli'
+ }
+
+
+ class OfflineSleepSpindleRealTimeStimulator(Stimulator):
+     def __init__(self):
+         self.last_detected_ts = time.time()
+         self.wait_t = 0.4 # 400 ms
+         self.delayer = None
+
+     def stimulate(self, detection_signal):
+         stim = False
+         for sig in detection_signal:
+             # We detect a stimulation
+             if sig:
+                 # Record time of stimulation
+                 ts = time.time()
+
+                 # Check if time since last stimulation is long enough
+                 if ts - self.last_detected_ts > self.wait_t:
+                     if self.delayer is not None:
+                         # If we have a delayer, notify it
+                         self.delayer.detected()
+                     stim = True
+
+                 self.last_detected_ts = ts
+         return stim
+
+     def add_delayer(self, delayer):
+         self.delayer = delayer
+         self.delayer.stimulate = lambda: True
+
+ def xdf2array(xdf_path, channel):
+     xdf_data, _ = pyxdf.load_xdf(xdf_path)
+
+     # Load all streams given their names
+     filtered_stream, raw_stream, markers = None, None, None
+     for stream in xdf_data:
+         # print(stream['info']['name'])
+         if stream['info']['name'][0] == STREAM_NAMES['filtered_data']:
+             filtered_stream = stream
+         elif stream['info']['name'][0] == STREAM_NAMES['raw_data']:
+             raw_stream = stream
+         elif stream['info']['name'][0] == STREAM_NAMES['stimuli']:
+             markers = stream
+
+     if filtered_stream is None or raw_stream is None:
+         raise ValueError("One of the necessary streams could not be found. Make sure that at least one signal stream is present in XDF recording")
+
+     # Add all samples from raw and filtered signals
+     csv_list = []
+     diffs = []
+     shortest_stream = min(int(filtered_stream['footer']['info']['sample_count'][0]),
+                           int(raw_stream['footer']['info']['sample_count'][0]))
+     for i in range(shortest_stream):
+         if markers is not None:
+             datapoint = [filtered_stream['time_stamps'][i],
+                          float(filtered_stream['time_series'][i, channel-1]),
+                          raw_stream['time_series'][i, channel-1],
+                          0]
+         else:
+             datapoint = [filtered_stream['time_stamps'][i],
+                          float(filtered_stream['time_series'][i, channel-1]),
+                          raw_stream['time_series'][i, channel-1]]
+         diffs.append(abs(filtered_stream['time_stamps'][i] - raw_stream['time_stamps'][i]))
+         csv_list.append(datapoint)
+
+     # Add markers
+     columns = ["time_stamps", "online_filtered_signal_portiloop", "raw_signal"]
+     if markers is not None:
+         columns.append("online_stimulations_portiloop")
+         for time_stamp in markers['time_stamps']:
+             new_index = np.abs(filtered_stream['time_stamps'] - time_stamp).argmin()
+             csv_list[new_index][3] = 1
+
+     return np.array(csv_list), columns
+
+
+ def offline_detect(method, data, timesteps, freq):
+     # Get the spindle data from the offline methods
+     time = np.arange(0, len(data)) / freq
+     if method == "Lacourse":
+         detector = DetectSpindle(method='Lacourse2018')
+         spindles, _, _ = detect_Lacourse2018(data, freq, time, detector)
+     elif method == "Wamsley":
+         detector = DetectSpindle(method='Wamsley2012')
+         spindles, _, _ = detect_Wamsley2012(data, freq, time, detector)
+     else:
+         raise ValueError("Invalid method")
+
+     # Convert the spindle data to a numpy array
+     spindle_result = np.zeros(data.shape)
+     for spindle in spindles:
+         start = spindle["start"]
+         end = spindle["end"]
+         # Find index of timestep closest to start and end
+         start_index = np.argmin(np.abs(timesteps - start))
+         end_index = np.argmin(np.abs(timesteps - end))
+         spindle_result[start_index:end_index] = 1
+     return spindle_result
+
+
+ def offline_filter(signal, freq):
+
+     # Notch filter
+     f0 = 60.0 # Frequency to be removed from signal (Hz)
+     Q = 100.0 # Quality factor
+     b, a = iirnotch(f0, Q, freq)
+     signal = filtfilt(b, a, signal)
+
+     # Bandpass filter
+     lowcut = 0.5
+     highcut = 40.0
+     order = 4
+     b, a = butter(order, [lowcut / (freq / 2.0), highcut / (freq / 2.0)], btype='bandpass')
+     signal = filtfilt(b, a, signal)
+
+     # Detrend the signal
+     signal = detrend(signal)
+
+     return signal
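
offline_filter chains a zero-phase 60 Hz notch, a 0.5–40 Hz Butterworth band-pass and a detrend, all applied with filtfilt so the output is not phase-shifted. A quick sanity check on synthetic data, assuming the demo's dependencies (including pyxdf and wonambi, which utils imports) are installed; the signal itself is made up for illustration:

import numpy as np
from portiloop.src.demo.utils import offline_filter

freq = 250  # Hz, the demo's default sampling rate
t = np.arange(0, 30, 1 / freq)

# Synthetic EEG-like trace: a 13 Hz "spindle" plus slow drift and 60 Hz mains noise.
signal = (10 * np.sin(2 * np.pi * 13 * t)
          + 50 * np.sin(2 * np.pi * 0.1 * t)
          + 20 * np.sin(2 * np.pi * 60 * t))

filtered = offline_filter(signal, freq)

# The 13 Hz component should survive; the drift and mains hum should be attenuated.
print(signal.std(), filtered.std())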
portiloop/src/detection.py CHANGED
@@ -154,7 +154,3 @@ class SleepSpindleRealTimeDetector(Detector):
          print(f"Computed output {output_data_y} in {end_time - start_time} seconds")
 
          return output_data_y, output_data_h
-
-
-
-
setup.py CHANGED
@@ -20,6 +20,7 @@ setup(
  'pylsl-coral',
  'pyalsaaudio'],
  'PC': ['gradio',
- 'tensorflow',]
+ 'tensorflow',
+ 'pyxdf']
  },
  )
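
Assuming 'PC' is a key of extras_require in this setup() call (the surrounding context suggests it is), the demo's new XDF dependency installs together with the other PC extras:

# From the repository root, an editable install with the PC extras
# (gradio, tensorflow and now pyxdf):
#
#     pip install -e ".[PC]"
#
# after which the XDF loader used by portiloop/src/demo/utils.py is available:
import pyxdf
print(pyxdf.load_xdf)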