Spaces:
Sleeping
Sleeping
Commit
·
2cb7306
1
Parent(s):
763def2
Finished setting up the demo and fixed my git stupidity
Browse files- portiloop/src/demo/demo.py +37 -115
- portiloop/src/demo/demo_stimulator.py +0 -30
- portiloop/src/demo/offline.py +137 -0
- portiloop/src/demo/test_offline.py +49 -0
- portiloop/src/demo/utils.py +132 -0
- portiloop/src/detection.py +0 -4
- setup.py +2 -1
portiloop/src/demo/demo.py
CHANGED
@@ -1,137 +1,47 @@
|
|
1 |
import gradio as gr
|
2 |
-
import matplotlib.pyplot as plt
|
3 |
-
import time
|
4 |
-
import numpy as np
|
5 |
-
import pandas as pd
|
6 |
-
from portiloop.src.demo.demo_stimulator import DemoSleepSpindleRealTimeStimulator
|
7 |
-
from portiloop.src.detection import SleepSpindleRealTimeDetector
|
8 |
|
9 |
-
from portiloop.src.
|
10 |
-
plt.switch_backend('agg')
|
11 |
-
from portiloop.src.processing import FilterPipeline
|
12 |
|
13 |
-
|
14 |
-
def do_treatment(csv_file, filtering, threshold, detect_channel, freq, spindle_freq, spindle_detection_mode, time_to_buffer):
|
15 |
-
|
16 |
-
# Read the csv file to a numpy array
|
17 |
-
data_whole = np.loadtxt(csv_file.name, delimiter=',')
|
18 |
-
|
19 |
-
# Get the data from the selected channel
|
20 |
-
detect_channel = int(detect_channel)
|
21 |
-
freq = int(freq)
|
22 |
-
data = data_whole[:, detect_channel - 1]
|
23 |
-
|
24 |
-
# Create the detector and the stimulator
|
25 |
-
detector = SleepSpindleRealTimeDetector(threshold=threshold, channel=1) # always 1 because we have only one channel
|
26 |
-
stimulator = DemoSleepSpindleRealTimeStimulator()
|
27 |
-
if spindle_detection_mode != 'Fast':
|
28 |
-
delayer = UpStateDelayer(freq, spindle_freq, spindle_detection_mode == 'Peak', time_to_buffer=time_to_buffer)
|
29 |
-
stimulator.add_delayer(delayer)
|
30 |
-
|
31 |
-
# Create the filtering pipeline
|
32 |
-
if filtering:
|
33 |
-
filter = FilterPipeline(nb_channels=1, sampling_rate=freq)
|
34 |
-
|
35 |
-
# Plotting variables
|
36 |
-
points = []
|
37 |
-
activations = []
|
38 |
-
delayed_activations = []
|
39 |
-
|
40 |
-
# Go through the data
|
41 |
-
for index, point in enumerate(data):
|
42 |
-
# Step the delayer if exists
|
43 |
-
if spindle_detection_mode != 'Fast':
|
44 |
-
delayed = delayer.step(point)
|
45 |
-
if delayed:
|
46 |
-
delayed_activations.append(1)
|
47 |
-
else:
|
48 |
-
delayed_activations.append(0)
|
49 |
-
|
50 |
-
# Filter the data
|
51 |
-
if filtering:
|
52 |
-
filtered_point = filter.filter(np.array([point]))
|
53 |
-
else:
|
54 |
-
filtered_point = point
|
55 |
-
|
56 |
-
filtered_point = filtered_point.tolist()
|
57 |
-
|
58 |
-
# Detect the spindles
|
59 |
-
result = detector.detect([filtered_point])
|
60 |
-
|
61 |
-
# Stimulate if necessary
|
62 |
-
stim = stimulator.stimulate(result)
|
63 |
-
if stim:
|
64 |
-
activations.append(1)
|
65 |
-
else:
|
66 |
-
activations.append(0)
|
67 |
-
|
68 |
-
# Add data to plotting buffer
|
69 |
-
points.append(filtered_point[0])
|
70 |
-
|
71 |
-
# Function to return a list of all indexes where activations have happened
|
72 |
-
def get_activations(activations):
|
73 |
-
return [i for i, x in enumerate(activations) if x == 1]
|
74 |
-
|
75 |
-
# Plot the data
|
76 |
-
if index % (10 * freq) == 0 and index >= (10 * freq):
|
77 |
-
plt.close()
|
78 |
-
fig = plt.figure(figsize=(20, 10))
|
79 |
-
plt.clf()
|
80 |
-
plt.plot(np.linspace(0, 10, num=freq*10), points[-10 * freq:], label="Data")
|
81 |
-
# Draw vertical lines for activations
|
82 |
-
for index in get_activations(activations[-10 * freq:]):
|
83 |
-
plt.axvline(x=index / freq, color='r', label="Fast Stimulation")
|
84 |
-
if spindle_detection_mode != 'Fast':
|
85 |
-
for index in get_activations(delayed_activations[-10 * freq:]):
|
86 |
-
plt.axvline(x=index / freq, color='g', label="Delayed Stimulation")
|
87 |
-
# Add axis titles and legend
|
88 |
-
plt.legend()
|
89 |
-
plt.xlabel("Time (s)")
|
90 |
-
plt.ylabel("Amplitude")
|
91 |
-
yield fig, None
|
92 |
-
|
93 |
-
# Put all points and activations back in numpy arrays
|
94 |
-
points = np.array(points)
|
95 |
-
activations = np.array(activations)
|
96 |
-
delayed_activations = np.array(delayed_activations)
|
97 |
-
# Concatenate with the original data
|
98 |
-
data_whole = np.concatenate((data_whole, points.reshape(-1, 1), activations.reshape(-1, 1), delayed_activations.reshape(-1, 1)), axis=1)
|
99 |
-
# Output the data to a csv file
|
100 |
-
np.savetxt('output.csv', data_whole, delimiter=',')
|
101 |
-
|
102 |
-
yield None, "output.csv"
|
103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
|
105 |
-
|
106 |
|
107 |
with gr.Blocks() as demo:
|
108 |
gr.Markdown("# Portiloop Demo")
|
109 |
gr.Markdown("This Demo takes as input a csv file containing EEG data and outputs a csv file with the following added: \n * The data filtered by the Portiloop online filter \n * The stimulations made by Portiloop.")
|
110 |
gr.Markdown("Upload your CSV file and click **Run Inference** to start the processing...")
|
111 |
|
112 |
-
# Row containing all inputs:
|
113 |
with gr.Row():
|
114 |
-
|
115 |
-
|
116 |
-
# Filtering (Boolean)
|
117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
# Threshold value
|
119 |
threshold = gr.Slider(0, 1, value=0.82, step=0.01, label="Threshold", interactive=True)
|
120 |
# Detection Channel
|
121 |
-
|
122 |
# Frequency
|
123 |
freq = gr.Dropdown(choices=["100", "200", "250", "256", "500", "512", "1000", "1024"], value="250", label="Sampling Frequency (Hz)", interactive=True)
|
124 |
-
# Spindle Frequency
|
125 |
-
spindle_freq = gr.Slider(10, 16, value=12, step=1, label="Spindle Frequency (Hz)", interactive=True)
|
126 |
-
# Spindle Detection Mode
|
127 |
-
spindle_detection_mode = gr.Dropdown(choices=["Fast", "Peak", "Valley"], value="Peak", label="Spindle Detection Mode", interactive=True)
|
128 |
-
# Time to buffer
|
129 |
-
time_to_buffer = gr.Slider(0, 1, value=0.3, step=0.01, label="Time to Buffer (s)", interactive=True)
|
130 |
|
131 |
-
# Output
|
|
|
132 |
output_plot = gr.Plot()
|
133 |
-
# Output file
|
134 |
output_array = gr.File(label="Output CSV File")
|
|
|
135 |
|
136 |
# Row containing all buttons:
|
137 |
with gr.Row():
|
@@ -139,7 +49,19 @@ with gr.Blocks() as demo:
|
|
139 |
run_inference = gr.Button(value="Run Inference")
|
140 |
# Reset button
|
141 |
reset = gr.Button(value="Reset", variant="secondary")
|
142 |
-
run_inference.click(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
|
144 |
demo.queue()
|
145 |
demo.launch(share=True)
|
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
+
from portiloop.src.demo.offline import run_offline
|
|
|
|
|
4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
+
def on_upload_file(file):
|
7 |
+
# Check if file extension is .xdf
|
8 |
+
if file.name.split(".")[-1] != "xdf":
|
9 |
+
raise gr.Error("Please upload a .xdf file.")
|
10 |
+
else:
|
11 |
+
yield f"File {file.name} successfully uploaded!"
|
12 |
|
|
|
13 |
|
14 |
with gr.Blocks() as demo:
|
15 |
gr.Markdown("# Portiloop Demo")
|
16 |
gr.Markdown("This Demo takes as input a csv file containing EEG data and outputs a csv file with the following added: \n * The data filtered by the Portiloop online filter \n * The stimulations made by Portiloop.")
|
17 |
gr.Markdown("Upload your CSV file and click **Run Inference** to start the processing...")
|
18 |
|
|
|
19 |
with gr.Row():
|
20 |
+
xdf_file = gr.UploadButton(label="Upload XDF File", type="file")
|
21 |
+
|
22 |
+
# Offline Filtering (Boolean)
|
23 |
+
offline_filtering = gr.Checkbox(label="Offline Filtering (On/Off)", value=True)
|
24 |
+
# Online Filtering (Boolean)
|
25 |
+
online_filtering = gr.Checkbox(label="Online Filtering (On/Off)", value=True)
|
26 |
+
# Lacourse's Method (Boolean)
|
27 |
+
lacourse = gr.Checkbox(label="Lacourse Detection (On/Off)", value=True)
|
28 |
+
# Wamsley's Method (Boolean)
|
29 |
+
wamsley = gr.Checkbox(label="Wamsley Detection (On/Off)", value=True)
|
30 |
+
# Online Detection (Boolean)
|
31 |
+
online_detection = gr.Checkbox(label="Online Detection (On/Off)", value=True)
|
32 |
+
|
33 |
# Threshold value
|
34 |
threshold = gr.Slider(0, 1, value=0.82, step=0.01, label="Threshold", interactive=True)
|
35 |
# Detection Channel
|
36 |
+
detect_channel = gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8"], value="2", label="Detection Channel in XDF recording", interactive=True)
|
37 |
# Frequency
|
38 |
freq = gr.Dropdown(choices=["100", "200", "250", "256", "500", "512", "1000", "1024"], value="250", label="Sampling Frequency (Hz)", interactive=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
+
# Output elements
|
41 |
+
update_text = gr.Textbox(value="Waiting for user input...", label="Status", interactive=False)
|
42 |
output_plot = gr.Plot()
|
|
|
43 |
output_array = gr.File(label="Output CSV File")
|
44 |
+
xdf_file.upload(fn=on_upload_file, inputs=[xdf_file], outputs=[update_text])
|
45 |
|
46 |
# Row containing all buttons:
|
47 |
with gr.Row():
|
|
|
49 |
run_inference = gr.Button(value="Run Inference")
|
50 |
# Reset button
|
51 |
reset = gr.Button(value="Reset", variant="secondary")
|
52 |
+
run_inference.click(
|
53 |
+
fn=run_offline,
|
54 |
+
inputs=[
|
55 |
+
xdf_file,
|
56 |
+
offline_filtering,
|
57 |
+
online_filtering,
|
58 |
+
online_detection,
|
59 |
+
lacourse,
|
60 |
+
wamsley,
|
61 |
+
threshold,
|
62 |
+
detect_channel,
|
63 |
+
freq],
|
64 |
+
outputs=[output_plot, output_array, update_text])
|
65 |
|
66 |
demo.queue()
|
67 |
demo.launch(share=True)
|
portiloop/src/demo/demo_stimulator.py
DELETED
@@ -1,30 +0,0 @@
|
|
1 |
-
import time
|
2 |
-
from portiloop.src.stimulation import Stimulator
|
3 |
-
|
4 |
-
|
5 |
-
class DemoSleepSpindleRealTimeStimulator(Stimulator):
|
6 |
-
def __init__(self):
|
7 |
-
self.last_detected_ts = time.time()
|
8 |
-
self.wait_t = 0.4 # 400 ms
|
9 |
-
|
10 |
-
def stimulate(self, detection_signal):
|
11 |
-
stim = False
|
12 |
-
for sig in detection_signal:
|
13 |
-
# We detect a stimulation
|
14 |
-
if sig:
|
15 |
-
# Record time of stimulation
|
16 |
-
ts = time.time()
|
17 |
-
|
18 |
-
# Check if time since last stimulation is long enough
|
19 |
-
if ts - self.last_detected_ts > self.wait_t:
|
20 |
-
if self.delayer is not None:
|
21 |
-
# If we have a delayer, notify it
|
22 |
-
self.delayer.detected()
|
23 |
-
stim = True
|
24 |
-
|
25 |
-
self.last_detected_ts = ts
|
26 |
-
return stim
|
27 |
-
|
28 |
-
def add_delayer(self, delayer):
|
29 |
-
self.delayer = delayer
|
30 |
-
self.delayer.stimulate = lambda: True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
portiloop/src/demo/offline.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib.pyplot as plt
|
2 |
+
import numpy as np
|
3 |
+
from portiloop.src.detection import SleepSpindleRealTimeDetector
|
4 |
+
plt.switch_backend('agg')
|
5 |
+
from portiloop.src.processing import FilterPipeline
|
6 |
+
from portiloop.src.demo.utils import xdf2array, offline_detect, offline_filter, OfflineSleepSpindleRealTimeStimulator
|
7 |
+
import gradio as gr
|
8 |
+
|
9 |
+
|
10 |
+
def run_offline(xdf_file, offline_filtering, online_filtering, online_detection, lacourse, wamsley, threshold, channel_num, freq):
|
11 |
+
|
12 |
+
print("Starting offline processing...")
|
13 |
+
# Make sure the inputs make sense:
|
14 |
+
if not offline_filtering and (lacourse or wamsley):
|
15 |
+
raise gr.Error("You can't use the offline detection methods without offline filtering.")
|
16 |
+
|
17 |
+
if not online_filtering and online_detection:
|
18 |
+
raise gr.Error("You can't use the online detection without online filtering.")
|
19 |
+
|
20 |
+
freq = int(freq)
|
21 |
+
|
22 |
+
# Read the xdf file to a numpy array
|
23 |
+
print("Loading xdf file...")
|
24 |
+
yield None, None, "Loading xdf file..."
|
25 |
+
data_whole, columns = xdf2array(xdf_file.name, int(channel_num))
|
26 |
+
print(data_whole.shape)
|
27 |
+
# Do the offline filtering of the data
|
28 |
+
print("Filtering offline...")
|
29 |
+
yield None, None, "Filtering offline..."
|
30 |
+
if offline_filtering:
|
31 |
+
offline_filtered_data = offline_filter(data_whole[:, columns.index("raw_signal")], freq)
|
32 |
+
# Expand the dimension of the filtered data to match the shape of the other columns
|
33 |
+
offline_filtered_data = np.expand_dims(offline_filtered_data, axis=1)
|
34 |
+
data_whole = np.concatenate((data_whole, offline_filtered_data), axis=1)
|
35 |
+
columns.append("offline_filtered_signal")
|
36 |
+
|
37 |
+
# Do Wamsley's method
|
38 |
+
print("Running Wamsley detection...")
|
39 |
+
yield None, None, "Running Wamsley detection..."
|
40 |
+
if wamsley:
|
41 |
+
wamsley_data = offline_detect("Wamsley", \
|
42 |
+
data_whole[:, columns.index("offline_filtered_signal")],\
|
43 |
+
data_whole[:, columns.index("time_stamps")],\
|
44 |
+
freq)
|
45 |
+
wamsley_data = np.expand_dims(wamsley_data, axis=1)
|
46 |
+
data_whole = np.concatenate((data_whole, wamsley_data), axis=1)
|
47 |
+
columns.append("wamsley_spindles")
|
48 |
+
|
49 |
+
# Do Lacourse's method
|
50 |
+
print("Running Lacourse detection...")
|
51 |
+
yield None, None, "Running Lacourse detection..."
|
52 |
+
if lacourse:
|
53 |
+
lacourse_data = offline_detect("Lacourse", \
|
54 |
+
data_whole[:, columns.index("offline_filtered_signal")],\
|
55 |
+
data_whole[:, columns.index("time_stamps")],\
|
56 |
+
freq)
|
57 |
+
lacourse_data = np.expand_dims(lacourse_data, axis=1)
|
58 |
+
data_whole = np.concatenate((data_whole, lacourse_data), axis=1)
|
59 |
+
columns.append("lacourse_spindles")
|
60 |
+
|
61 |
+
# Get the data from the raw signal column
|
62 |
+
data = data_whole[:, columns.index("raw_signal")]
|
63 |
+
|
64 |
+
# Create the online filtering pipeline
|
65 |
+
if online_filtering:
|
66 |
+
filter = FilterPipeline(nb_channels=1, sampling_rate=freq)
|
67 |
+
|
68 |
+
# Create the detector
|
69 |
+
if online_detection:
|
70 |
+
detector = SleepSpindleRealTimeDetector(threshold=threshold, channel=1) # always 1 because we have only one channel
|
71 |
+
stimulator = OfflineSleepSpindleRealTimeStimulator()
|
72 |
+
|
73 |
+
print("Running online filtering and detection...")
|
74 |
+
yield None, None, "Running online filtering and detection..."
|
75 |
+
if online_filtering or online_detection:
|
76 |
+
# Plotting variables
|
77 |
+
points = []
|
78 |
+
online_activations = []
|
79 |
+
|
80 |
+
# Go through the data
|
81 |
+
for index, point in enumerate(data):
|
82 |
+
# Filter the data
|
83 |
+
if online_filtering:
|
84 |
+
filtered_point = filter.filter(np.array([point]))
|
85 |
+
else:
|
86 |
+
filtered_point = point
|
87 |
+
filtered_point = filtered_point.tolist()
|
88 |
+
points.append(filtered_point[0])
|
89 |
+
|
90 |
+
if online_detection:
|
91 |
+
# Detect the spindles
|
92 |
+
result = detector.detect([filtered_point])
|
93 |
+
|
94 |
+
# Stimulate if necessary
|
95 |
+
stim = stimulator.stimulate(result)
|
96 |
+
if stim:
|
97 |
+
online_activations.append(1)
|
98 |
+
else:
|
99 |
+
online_activations.append(0)
|
100 |
+
|
101 |
+
# Function to return a list of all indexes where activations have happened
|
102 |
+
def get_activations(activations):
|
103 |
+
return [i for i, x in enumerate(activations) if x == 1]
|
104 |
+
|
105 |
+
# Plot the data
|
106 |
+
if index % (10 * freq) == 0 and index >= (10 * freq):
|
107 |
+
plt.close()
|
108 |
+
fig = plt.figure(figsize=(20, 10))
|
109 |
+
plt.clf()
|
110 |
+
plt.plot(np.linspace(0, 10, num=freq*10), points[-10 * freq:], label="Data")
|
111 |
+
# Draw vertical lines for activations
|
112 |
+
for index in get_activations(online_activations[-10 * freq:]):
|
113 |
+
plt.axvline(x=index / freq, color='r', label="Portiloop Stimulation")
|
114 |
+
# Add axis titles and legend
|
115 |
+
plt.legend()
|
116 |
+
plt.xlabel("Time (s)")
|
117 |
+
plt.ylabel("Amplitude")
|
118 |
+
yield fig, None, "Running online filtering and detection..."
|
119 |
+
|
120 |
+
if online_filtering:
|
121 |
+
online_filtered = np.array(points)
|
122 |
+
online_filtered = np.expand_dims(online_filtered, axis=1)
|
123 |
+
data_whole = np.concatenate((data_whole, online_filtered), axis=1)
|
124 |
+
columns.append("online_filtered_signal")
|
125 |
+
|
126 |
+
if online_detection:
|
127 |
+
online_activations = np.array(online_activations)
|
128 |
+
online_activations = np.expand_dims(online_activations, axis=1)
|
129 |
+
data_whole = np.concatenate((data_whole, online_activations), axis=1)
|
130 |
+
columns.append("online_stimulations")
|
131 |
+
|
132 |
+
print("Saving output...")
|
133 |
+
# Output the data to a csv file
|
134 |
+
np.savetxt("output.csv", data_whole, delimiter=",", header=",".join(columns), comments="")
|
135 |
+
|
136 |
+
print("Done!")
|
137 |
+
yield None, "output.csv", "Done!"
|
portiloop/src/demo/test_offline.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import itertools
|
2 |
+
import unittest
|
3 |
+
from portiloop.src.demo.offline import run_offline
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
|
7 |
+
class TestOffline(unittest.TestCase):
|
8 |
+
|
9 |
+
def setUp(self):
|
10 |
+
combinatorial_config = {
|
11 |
+
'offline_filtering': [True, False],
|
12 |
+
'online_filtering': [True, False],
|
13 |
+
'online_detection': [True, False],
|
14 |
+
'wamsley': [True, False],
|
15 |
+
'lacourse': [True, False],
|
16 |
+
}
|
17 |
+
|
18 |
+
self.exclusives = [("duplicate_as_window", "use_cnn_encoder")]
|
19 |
+
|
20 |
+
keys = list(combinatorial_config)
|
21 |
+
all_options_iterator = itertools.product(*map(combinatorial_config.get, keys))
|
22 |
+
all_options_dicts = [dict(zip(keys, values)) for values in all_options_iterator]
|
23 |
+
self.filtered_options = [value for value in all_options_dicts if (value['online_detection'] and value['online_filtering']) or not value['online_detection']]
|
24 |
+
self.xdf_file = Path(__file__).parents[3] / "test_xdf" / "test_file.xdf"
|
25 |
+
|
26 |
+
|
27 |
+
def test_all_options(self):
|
28 |
+
for config in self.filtered_options:
|
29 |
+
if config['online_detection']:
|
30 |
+
self.assertTrue(config['online_filtering'])
|
31 |
+
|
32 |
+
def test_single_option(self):
|
33 |
+
res = list(run_offline(
|
34 |
+
self.xdf_file,
|
35 |
+
offline_filtering=True,
|
36 |
+
online_filtering=True,
|
37 |
+
online_detection=True,
|
38 |
+
wamsley=True,
|
39 |
+
lacourse=True,
|
40 |
+
threshold=0.5,
|
41 |
+
channel_num=2,
|
42 |
+
freq=250))
|
43 |
+
print(res)
|
44 |
+
|
45 |
+
def tearDown(self):
|
46 |
+
pass
|
47 |
+
|
48 |
+
if __name__ == '__main__':
|
49 |
+
unittest.main()
|
portiloop/src/demo/utils.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import pyxdf
|
3 |
+
from wonambi.detect.spindle import DetectSpindle, detect_Lacourse2018, detect_Wamsley2012
|
4 |
+
from scipy.signal import butter, filtfilt, iirnotch, detrend
|
5 |
+
import time
|
6 |
+
from portiloop.src.stimulation import Stimulator
|
7 |
+
|
8 |
+
|
9 |
+
STREAM_NAMES = {
|
10 |
+
'filtered_data': 'Portiloop Filtered',
|
11 |
+
'raw_data': 'Portiloop Raw Data',
|
12 |
+
'stimuli': 'Portiloop_stimuli'
|
13 |
+
}
|
14 |
+
|
15 |
+
|
16 |
+
class OfflineSleepSpindleRealTimeStimulator(Stimulator):
|
17 |
+
def __init__(self):
|
18 |
+
self.last_detected_ts = time.time()
|
19 |
+
self.wait_t = 0.4 # 400 ms
|
20 |
+
self.delayer = None
|
21 |
+
|
22 |
+
def stimulate(self, detection_signal):
|
23 |
+
stim = False
|
24 |
+
for sig in detection_signal:
|
25 |
+
# We detect a stimulation
|
26 |
+
if sig:
|
27 |
+
# Record time of stimulation
|
28 |
+
ts = time.time()
|
29 |
+
|
30 |
+
# Check if time since last stimulation is long enough
|
31 |
+
if ts - self.last_detected_ts > self.wait_t:
|
32 |
+
if self.delayer is not None:
|
33 |
+
# If we have a delayer, notify it
|
34 |
+
self.delayer.detected()
|
35 |
+
stim = True
|
36 |
+
|
37 |
+
self.last_detected_ts = ts
|
38 |
+
return stim
|
39 |
+
|
40 |
+
def add_delayer(self, delayer):
|
41 |
+
self.delayer = delayer
|
42 |
+
self.delayer.stimulate = lambda: True
|
43 |
+
|
44 |
+
def xdf2array(xdf_path, channel):
|
45 |
+
xdf_data, _ = pyxdf.load_xdf(xdf_path)
|
46 |
+
|
47 |
+
# Load all streams given their names
|
48 |
+
filtered_stream, raw_stream, markers = None, None, None
|
49 |
+
for stream in xdf_data:
|
50 |
+
# print(stream['info']['name'])
|
51 |
+
if stream['info']['name'][0] == STREAM_NAMES['filtered_data']:
|
52 |
+
filtered_stream = stream
|
53 |
+
elif stream['info']['name'][0] == STREAM_NAMES['raw_data']:
|
54 |
+
raw_stream = stream
|
55 |
+
elif stream['info']['name'][0] == STREAM_NAMES['stimuli']:
|
56 |
+
markers = stream
|
57 |
+
|
58 |
+
if filtered_stream is None or raw_stream is None:
|
59 |
+
raise ValueError("One of the necessary streams could not be found. Make sure that at least one signal stream is present in XDF recording")
|
60 |
+
|
61 |
+
# Add all samples from raw and filtered signals
|
62 |
+
csv_list = []
|
63 |
+
diffs = []
|
64 |
+
shortest_stream = min(int(filtered_stream['footer']['info']['sample_count'][0]),
|
65 |
+
int(raw_stream['footer']['info']['sample_count'][0]))
|
66 |
+
for i in range(shortest_stream):
|
67 |
+
if markers is not None:
|
68 |
+
datapoint = [filtered_stream['time_stamps'][i],
|
69 |
+
float(filtered_stream['time_series'][i, channel-1]),
|
70 |
+
raw_stream['time_series'][i, channel-1],
|
71 |
+
0]
|
72 |
+
else:
|
73 |
+
datapoint = [filtered_stream['time_stamps'][i],
|
74 |
+
float(filtered_stream['time_series'][i, channel-1]),
|
75 |
+
raw_stream['time_series'][i, channel-1]]
|
76 |
+
diffs.append(abs(filtered_stream['time_stamps'][i] - raw_stream['time_stamps'][i]))
|
77 |
+
csv_list.append(datapoint)
|
78 |
+
|
79 |
+
# Add markers
|
80 |
+
columns = ["time_stamps", "online_filtered_signal_portiloop", "raw_signal"]
|
81 |
+
if markers is not None:
|
82 |
+
columns.append("online_stimulations_portiloop")
|
83 |
+
for time_stamp in markers['time_stamps']:
|
84 |
+
new_index = np.abs(filtered_stream['time_stamps'] - time_stamp).argmin()
|
85 |
+
csv_list[new_index][3] = 1
|
86 |
+
|
87 |
+
return np.array(csv_list), columns
|
88 |
+
|
89 |
+
|
90 |
+
def offline_detect(method, data, timesteps, freq):
|
91 |
+
# Get the spindle data from the offline methods
|
92 |
+
time = np.arange(0, len(data)) / freq
|
93 |
+
if method == "Lacourse":
|
94 |
+
detector = DetectSpindle(method='Lacourse2018')
|
95 |
+
spindles, _, _ = detect_Lacourse2018(data, freq, time, detector)
|
96 |
+
elif method == "Wamsley":
|
97 |
+
detector = DetectSpindle(method='Wamsley2012')
|
98 |
+
spindles, _, _ = detect_Wamsley2012(data, freq, time, detector)
|
99 |
+
else:
|
100 |
+
raise ValueError("Invalid method")
|
101 |
+
|
102 |
+
# Convert the spindle data to a numpy array
|
103 |
+
spindle_result = np.zeros(data.shape)
|
104 |
+
for spindle in spindles:
|
105 |
+
start = spindle["start"]
|
106 |
+
end = spindle["end"]
|
107 |
+
# Find index of timestep closest to start and end
|
108 |
+
start_index = np.argmin(np.abs(timesteps - start))
|
109 |
+
end_index = np.argmin(np.abs(timesteps - end))
|
110 |
+
spindle_result[start_index:end_index] = 1
|
111 |
+
return spindle_result
|
112 |
+
|
113 |
+
|
114 |
+
def offline_filter(signal, freq):
|
115 |
+
|
116 |
+
# Notch filter
|
117 |
+
f0 = 60.0 # Frequency to be removed from signal (Hz)
|
118 |
+
Q = 100.0 # Quality factor
|
119 |
+
b, a = iirnotch(f0, Q, freq)
|
120 |
+
signal = filtfilt(b, a, signal)
|
121 |
+
|
122 |
+
# Bandpass filter
|
123 |
+
lowcut = 0.5
|
124 |
+
highcut = 40.0
|
125 |
+
order = 4
|
126 |
+
b, a = butter(order, [lowcut / (freq / 2.0), highcut / (freq / 2.0)], btype='bandpass')
|
127 |
+
signal = filtfilt(b, a, signal)
|
128 |
+
|
129 |
+
# Detrend the signal
|
130 |
+
signal = detrend(signal)
|
131 |
+
|
132 |
+
return signal
|
portiloop/src/detection.py
CHANGED
@@ -154,7 +154,3 @@ class SleepSpindleRealTimeDetector(Detector):
|
|
154 |
print(f"Computed output {output_data_y} in {end_time - start_time} seconds")
|
155 |
|
156 |
return output_data_y, output_data_h
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
|
|
154 |
print(f"Computed output {output_data_y} in {end_time - start_time} seconds")
|
155 |
|
156 |
return output_data_y, output_data_h
|
|
|
|
|
|
|
|
setup.py
CHANGED
@@ -20,6 +20,7 @@ setup(
|
|
20 |
'pylsl-coral',
|
21 |
'pyalsaaudio'],
|
22 |
'PC': ['gradio',
|
23 |
-
'tensorflow',
|
|
|
24 |
},
|
25 |
)
|
|
|
20 |
'pylsl-coral',
|
21 |
'pyalsaaudio'],
|
22 |
'PC': ['gradio',
|
23 |
+
'tensorflow',
|
24 |
+
'pyxdf']
|
25 |
},
|
26 |
)
|