Spaces:
Sleeping
Sleeping
ybouteiller
commited on
Commit
·
120f728
1
Parent(s):
4711af1
Added inference code
Browse files- portiloop/capture.py +67 -3
- portiloop/inference.py +95 -0
- portiloop/notebooks/tests.ipynb +34 -3
- setup.py +2 -1
portiloop/capture.py
CHANGED
@@ -399,7 +399,7 @@ def _capture_process(p_data_o, p_msg_io, duration, frequency, python_clock, time
|
|
399 |
|
400 |
|
401 |
class Capture:
|
402 |
-
def __init__(self):
|
403 |
# {now.strftime('%m_%d_%Y_%H_%M_%S')}
|
404 |
self.filename = EDF_PATH / 'recording.edf'
|
405 |
self._p_capture = None
|
@@ -415,6 +415,9 @@ class Capture:
|
|
415 |
self.custom_fir_cutoff = 30
|
416 |
self.filter = True
|
417 |
self.record = False
|
|
|
|
|
|
|
418 |
self.lsl = False
|
419 |
self.display = False
|
420 |
self.python_clock = True
|
@@ -430,6 +433,8 @@ class Capture:
|
|
430 |
self._t_capture = None
|
431 |
self.channel_states = ['disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled']
|
432 |
|
|
|
|
|
433 |
# widgets ===============================
|
434 |
|
435 |
# CHANNELS ------------------------------
|
@@ -554,6 +559,12 @@ class Capture:
|
|
554 |
disabled=False
|
555 |
)
|
556 |
|
|
|
|
|
|
|
|
|
|
|
|
|
557 |
self.b_polyak_mean = widgets.FloatText(
|
558 |
value=self.polyak_mean,
|
559 |
description='Polyak mean:',
|
@@ -610,6 +621,20 @@ class Capture:
|
|
610 |
indent=False
|
611 |
)
|
612 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
613 |
self.b_record = widgets.Checkbox(
|
614 |
value=self.record,
|
615 |
description='Record EDF',
|
@@ -636,8 +661,10 @@ class Capture:
|
|
636 |
self.b_capture.observe(self.on_b_capture, 'value')
|
637 |
self.b_clock.observe(self.on_b_clock, 'value')
|
638 |
self.b_frequency.observe(self.on_b_frequency, 'value')
|
|
|
639 |
self.b_duration.observe(self.on_b_duration, 'value')
|
640 |
self.b_filter.observe(self.on_b_filter, 'value')
|
|
|
641 |
self.b_record.observe(self.on_b_record, 'value')
|
642 |
self.b_lsl.observe(self.on_b_lsl, 'value')
|
643 |
self.b_display.observe(self.on_b_display, 'value')
|
@@ -668,7 +695,8 @@ class Capture:
|
|
668 |
self.b_filename,
|
669 |
self.b_power_line,
|
670 |
self.b_clock,
|
671 |
-
widgets.HBox([self.b_filter, self.b_record, self.b_lsl, self.b_display]),
|
|
|
672 |
self.b_accordion_filter,
|
673 |
self.b_capture]))
|
674 |
|
@@ -677,6 +705,7 @@ class Capture:
|
|
677 |
self.b_duration.disabled = False
|
678 |
self.b_filename.disabled = False
|
679 |
self.b_filter.disabled = False
|
|
|
680 |
self.b_record.disabled = False
|
681 |
self.b_record.lsl = False
|
682 |
self.b_display.disabled = False
|
@@ -694,12 +723,16 @@ class Capture:
|
|
694 |
self.b_custom_fir.disabled = False
|
695 |
self.b_custom_fir_order.disabled = not self.custom_fir
|
696 |
self.b_custom_fir_cutoff.disabled = not self.custom_fir
|
|
|
|
|
697 |
|
698 |
def disable_buttons(self):
|
699 |
self.b_frequency.disabled = True
|
700 |
self.b_duration.disabled = True
|
701 |
self.b_filename.disabled = True
|
702 |
self.b_filter.disabled = True
|
|
|
|
|
703 |
self.b_record.disabled = True
|
704 |
self.b_record.lsl = True
|
705 |
self.b_display.disabled = True
|
@@ -717,6 +750,7 @@ class Capture:
|
|
717 |
self.b_custom_fir.disabled = True
|
718 |
self.b_custom_fir_order.disabled = True
|
719 |
self.b_custom_fir_cutoff.disabled = True
|
|
|
720 |
|
721 |
def on_b_radio_ch2(self, value):
|
722 |
self.channel_states[1] = value['new']
|
@@ -751,7 +785,7 @@ class Capture:
|
|
751 |
warnings.warn("Capture already running, operation aborted.")
|
752 |
return
|
753 |
self._t_capture = Thread(target=self.start_capture,
|
754 |
-
args=(self.filter, self.record, self.lsl, self.display, 500, self.python_clock))
|
755 |
self._t_capture.start()
|
756 |
elif val == 'Stop':
|
757 |
with self._lock_msg_out:
|
@@ -790,6 +824,13 @@ class Capture:
|
|
790 |
else:
|
791 |
self.b_frequency.value = self.frequency
|
792 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
793 |
def on_b_filename(self, value):
|
794 |
val = value['new']
|
795 |
if val != '':
|
@@ -844,6 +885,15 @@ class Capture:
|
|
844 |
val = value['new']
|
845 |
self.filter = val
|
846 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
847 |
def on_b_record(self, value):
|
848 |
val = value['new']
|
849 |
self.record = val
|
@@ -894,6 +944,8 @@ class Capture:
|
|
894 |
|
895 |
def start_capture(self,
|
896 |
filter,
|
|
|
|
|
897 |
record,
|
898 |
lsl,
|
899 |
viz,
|
@@ -918,6 +970,9 @@ class Capture:
|
|
918 |
alpha_avg=self.polyak_mean,
|
919 |
alpha_std=self.polyak_std,
|
920 |
epsilon=self.epsilon)
|
|
|
|
|
|
|
921 |
|
922 |
self._p_capture = mp.Process(target=_capture_process,
|
923 |
args=(p_data_o,
|
@@ -975,6 +1030,15 @@ class Capture:
|
|
975 |
|
976 |
filtered_point = n_array.tolist()
|
977 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
978 |
if lsl:
|
979 |
lsl_outlet.push_sample(filtered_point[-1])
|
980 |
|
|
|
399 |
|
400 |
|
401 |
class Capture:
|
402 |
+
def __init__(self, quantInferenceClass):
|
403 |
# {now.strftime('%m_%d_%Y_%H_%M_%S')}
|
404 |
self.filename = EDF_PATH / 'recording.edf'
|
405 |
self._p_capture = None
|
|
|
415 |
self.custom_fir_cutoff = 30
|
416 |
self.filter = True
|
417 |
self.record = False
|
418 |
+
self.detect = False
|
419 |
+
self.stimulate = False
|
420 |
+
self.threshold = 0.5
|
421 |
self.lsl = False
|
422 |
self.display = False
|
423 |
self.python_clock = True
|
|
|
433 |
self._t_capture = None
|
434 |
self.channel_states = ['disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled']
|
435 |
|
436 |
+
self.quantInferenceClass = quantInferenceClass
|
437 |
+
|
438 |
# widgets ===============================
|
439 |
|
440 |
# CHANNELS ------------------------------
|
|
|
559 |
disabled=False
|
560 |
)
|
561 |
|
562 |
+
self.b_threshold = widgets.FloatText(
|
563 |
+
value=self.threshold,
|
564 |
+
description='Threshold:',
|
565 |
+
disabled=True
|
566 |
+
)
|
567 |
+
|
568 |
self.b_polyak_mean = widgets.FloatText(
|
569 |
value=self.polyak_mean,
|
570 |
description='Polyak mean:',
|
|
|
621 |
indent=False
|
622 |
)
|
623 |
|
624 |
+
self.b_detect = widgets.Checkbox(
|
625 |
+
value=self.detect,
|
626 |
+
description='Detect',
|
627 |
+
disabled=False,
|
628 |
+
indent=False
|
629 |
+
)
|
630 |
+
|
631 |
+
self.b_stimulate = widgets.Checkbox(
|
632 |
+
value=self.stimulate,
|
633 |
+
description='Stimulate',
|
634 |
+
disabled=True,
|
635 |
+
indent=False
|
636 |
+
)
|
637 |
+
|
638 |
self.b_record = widgets.Checkbox(
|
639 |
value=self.record,
|
640 |
description='Record EDF',
|
|
|
661 |
self.b_capture.observe(self.on_b_capture, 'value')
|
662 |
self.b_clock.observe(self.on_b_clock, 'value')
|
663 |
self.b_frequency.observe(self.on_b_frequency, 'value')
|
664 |
+
self.b_threshold.observe(self.on_b_threshold, 'value')
|
665 |
self.b_duration.observe(self.on_b_duration, 'value')
|
666 |
self.b_filter.observe(self.on_b_filter, 'value')
|
667 |
+
self.b_detect.observe(self.on_b_detect, 'value')
|
668 |
self.b_record.observe(self.on_b_record, 'value')
|
669 |
self.b_lsl.observe(self.on_b_lsl, 'value')
|
670 |
self.b_display.observe(self.on_b_display, 'value')
|
|
|
695 |
self.b_filename,
|
696 |
self.b_power_line,
|
697 |
self.b_clock,
|
698 |
+
widgets.HBox([self.b_filter, self.b_detect, self.b_stimulate, self.b_record, self.b_lsl, self.b_display]),
|
699 |
+
self.b_threshold,
|
700 |
self.b_accordion_filter,
|
701 |
self.b_capture]))
|
702 |
|
|
|
705 |
self.b_duration.disabled = False
|
706 |
self.b_filename.disabled = False
|
707 |
self.b_filter.disabled = False
|
708 |
+
self.b_detect.disabled = False
|
709 |
self.b_record.disabled = False
|
710 |
self.b_record.lsl = False
|
711 |
self.b_display.disabled = False
|
|
|
723 |
self.b_custom_fir.disabled = False
|
724 |
self.b_custom_fir_order.disabled = not self.custom_fir
|
725 |
self.b_custom_fir_cutoff.disabled = not self.custom_fir
|
726 |
+
self.b_stimulate.disabled = not self.detect
|
727 |
+
self.b_threshold.disabled = not self.detect
|
728 |
|
729 |
def disable_buttons(self):
|
730 |
self.b_frequency.disabled = True
|
731 |
self.b_duration.disabled = True
|
732 |
self.b_filename.disabled = True
|
733 |
self.b_filter.disabled = True
|
734 |
+
self.b_stimulate.disabled = True
|
735 |
+
self.b_filter.disabled = True
|
736 |
self.b_record.disabled = True
|
737 |
self.b_record.lsl = True
|
738 |
self.b_display.disabled = True
|
|
|
750 |
self.b_custom_fir.disabled = True
|
751 |
self.b_custom_fir_order.disabled = True
|
752 |
self.b_custom_fir_cutoff.disabled = True
|
753 |
+
self.b_threshold.disabled = True
|
754 |
|
755 |
def on_b_radio_ch2(self, value):
|
756 |
self.channel_states[1] = value['new']
|
|
|
785 |
warnings.warn("Capture already running, operation aborted.")
|
786 |
return
|
787 |
self._t_capture = Thread(target=self.start_capture,
|
788 |
+
args=(self.filter, self.detect, self.quantInferenceClass, self.record, self.lsl, self.display, 500, self.python_clock))
|
789 |
self._t_capture.start()
|
790 |
elif val == 'Stop':
|
791 |
with self._lock_msg_out:
|
|
|
824 |
else:
|
825 |
self.b_frequency.value = self.frequency
|
826 |
|
827 |
+
def on_b_threshold(self, value):
|
828 |
+
val = value['new']
|
829 |
+
if val >= 0 and val <= 1:
|
830 |
+
self.threshold = val
|
831 |
+
else:
|
832 |
+
self.b_threshold.value = self.threshold
|
833 |
+
|
834 |
def on_b_filename(self, value):
|
835 |
val = value['new']
|
836 |
if val != '':
|
|
|
885 |
val = value['new']
|
886 |
self.filter = val
|
887 |
|
888 |
+
def on_b_stimulate(self, value):
|
889 |
+
val = value['new']
|
890 |
+
self.stimulate = val
|
891 |
+
|
892 |
+
def on_b_detect(self, value):
|
893 |
+
val = value['new']
|
894 |
+
self.detect = val
|
895 |
+
self.enable_buttons()
|
896 |
+
|
897 |
def on_b_record(self, value):
|
898 |
val = value['new']
|
899 |
self.record = val
|
|
|
944 |
|
945 |
def start_capture(self,
|
946 |
filter,
|
947 |
+
detect,
|
948 |
+
quantInferenceClass,
|
949 |
record,
|
950 |
lsl,
|
951 |
viz,
|
|
|
970 |
alpha_avg=self.polyak_mean,
|
971 |
alpha_std=self.polyak_std,
|
972 |
epsilon=self.epsilon)
|
973 |
+
|
974 |
+
if detect:
|
975 |
+
infer = quantInferenceClass()
|
976 |
|
977 |
self._p_capture = mp.Process(target=_capture_process,
|
978 |
args=(p_data_o,
|
|
|
1030 |
|
1031 |
filtered_point = n_array.tolist()
|
1032 |
|
1033 |
+
if detect:
|
1034 |
+
results = infer.add_datapoints(filtered_points)
|
1035 |
+
|
1036 |
+
for r in results:
|
1037 |
+
print(r >= threshold)
|
1038 |
+
|
1039 |
+
if stimulate and True:
|
1040 |
+
print('stimulation')
|
1041 |
+
|
1042 |
if lsl:
|
1043 |
lsl_outlet.push_sample(filtered_point[-1])
|
1044 |
|
portiloop/inference.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pycoral.utils import edgetpu
|
2 |
+
import time
|
3 |
+
from abc import ABC, abstractmethod
|
4 |
+
from pathlib import Path
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
DEFAULT_MODEL_PATH = str(Path(__file__).parent / "models/portiloop_model_quant.tflite")
|
8 |
+
print(DEFAULT_MODEL_PATH)
|
9 |
+
|
10 |
+
class AbstractQuantizedModelForInference(ABC):
|
11 |
+
@abstractmethod
|
12 |
+
def add_datapoints(self, input_float):
|
13 |
+
return NotImplemented
|
14 |
+
|
15 |
+
class QuantizedModelForInference(AbstractQuantizedModelForInference):
|
16 |
+
def __init__(self, num_models_parallel=8, window_size=54, seq_stride=42, model_path=None, verbose=False, channel=2):
|
17 |
+
model_path = DEFAULT_MODEL_PATH if model_path is None else model_path
|
18 |
+
self.verbose = verbose
|
19 |
+
self.channel = channel
|
20 |
+
self.num_models_parallel = num_models_parallel
|
21 |
+
|
22 |
+
self.interpreters = []
|
23 |
+
for i in range(self.num_models_parallel):
|
24 |
+
self.interpreters.append(edgetpu.make_interpreter(model_path))
|
25 |
+
self.interpreters[i].allocate_tensors()
|
26 |
+
self.interpreter_counter = 0
|
27 |
+
|
28 |
+
self.input_details = self.interpreters[0].get_input_details()
|
29 |
+
self.output_details = self.interpreters[0].get_output_details()
|
30 |
+
|
31 |
+
self.buffer = []
|
32 |
+
self.seq_stride = seq_stride
|
33 |
+
self.window_size = window_size
|
34 |
+
|
35 |
+
self.stride_counters = [np.floor((self.seq_stride / self.num_models_parallel) * i) for i in range(self.num_models_parallel)]
|
36 |
+
for idx, i in enumerate(self.stride_counters[1:]):
|
37 |
+
self.stride_counters[idx+1] = i - self.stride_counters[idx]
|
38 |
+
self.current_stride_counter = self.stride_counters[0] - 1
|
39 |
+
|
40 |
+
|
41 |
+
def add_datapoints(self, inputs_float):
|
42 |
+
res = []
|
43 |
+
for inp in inputs_float:
|
44 |
+
result = self.add_datapoint(inp)
|
45 |
+
if result is not None:
|
46 |
+
res.append(result)
|
47 |
+
return res
|
48 |
+
|
49 |
+
|
50 |
+
def add_datapoint(self, input_float):
|
51 |
+
input_float = input_float[self.channel-1]
|
52 |
+
result = None
|
53 |
+
self.buffer.append(input_float)
|
54 |
+
if len(self.buffer) > self.window_size:
|
55 |
+
self.buffer = self.buffer[1:]
|
56 |
+
self.current_stride_counter += 1
|
57 |
+
if self.current_stride_counter == self.stride_counter[self.interpreter_counter]:
|
58 |
+
result = self.call_model(self.interpreter_counter, self.buffer)
|
59 |
+
self.interpreter_counter += 1
|
60 |
+
self.interpreter_counter %= self.num_model_parallel
|
61 |
+
self.current_stride_counter = 0
|
62 |
+
return result
|
63 |
+
|
64 |
+
|
65 |
+
|
66 |
+
def call_model(self, idx, input_float=None):
|
67 |
+
if input_float is None:
|
68 |
+
# For debuggin purposes
|
69 |
+
input_shape = input_details[0]['shape']
|
70 |
+
input = np.array(np.random.random_sample(input_shape), dtype=np.int8)
|
71 |
+
else:
|
72 |
+
# Convert float input to Int
|
73 |
+
input_scale, input_zero_point = input_details[0]["quantization"]
|
74 |
+
input = np.asarray(input_float) / input_scale + input_zero_point
|
75 |
+
input = input.astype(input_details[0]["dtype"])
|
76 |
+
|
77 |
+
interpreter.set_tensor(input_details[0]['index'], input)
|
78 |
+
if self.verbose:
|
79 |
+
start_time = time.time()
|
80 |
+
|
81 |
+
interpreter.invoke()
|
82 |
+
|
83 |
+
if self.verbose:
|
84 |
+
end_time = time.time()
|
85 |
+
|
86 |
+
output = interpreter.get_tensor(output_details[0]['index'])
|
87 |
+
output_scale, output_zero_point = input_details[0]["quantization"]
|
88 |
+
output = float(output - output_zero_point) * output_scale
|
89 |
+
|
90 |
+
if self.verbose:
|
91 |
+
print(f"Computed output {output} in {end_time - start_time} seconds")
|
92 |
+
|
93 |
+
return output
|
94 |
+
|
95 |
+
|
portiloop/notebooks/tests.ipynb
CHANGED
@@ -2,16 +2,47 @@
|
|
2 |
"cells": [
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
-
"execution_count":
|
6 |
"id": "7b2fc5da",
|
7 |
"metadata": {
|
8 |
"scrolled": false
|
9 |
},
|
10 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
"source": [
|
12 |
"from portiloop.capture import Capture\n",
|
|
|
13 |
"\n",
|
14 |
-
"cap = Capture()"
|
15 |
]
|
16 |
}
|
17 |
],
|
|
|
2 |
"cells": [
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
"id": "7b2fc5da",
|
7 |
"metadata": {
|
8 |
"scrolled": false
|
9 |
},
|
10 |
+
"outputs": [
|
11 |
+
{
|
12 |
+
"data": {
|
13 |
+
"application/vnd.jupyter.widget-view+json": {
|
14 |
+
"model_id": "910f8e489b6341119f4d6e17a5b2aedc",
|
15 |
+
"version_major": 2,
|
16 |
+
"version_minor": 0
|
17 |
+
},
|
18 |
+
"text/plain": [
|
19 |
+
"VBox(children=(Accordion(children=(GridBox(children=(Label(value='CH1'), Label(value='CH2'), Label(value='CH3'…"
|
20 |
+
]
|
21 |
+
},
|
22 |
+
"metadata": {},
|
23 |
+
"output_type": "display_data"
|
24 |
+
},
|
25 |
+
{
|
26 |
+
"name": "stderr",
|
27 |
+
"output_type": "stream",
|
28 |
+
"text": [
|
29 |
+
"Process Process-1:\n",
|
30 |
+
"Traceback (most recent call last):\n",
|
31 |
+
" File \"/usr/lib/python3.7/multiprocessing/process.py\", line 297, in _bootstrap\n",
|
32 |
+
" self.run()\n",
|
33 |
+
" File \"/usr/lib/python3.7/multiprocessing/process.py\", line 99, in run\n",
|
34 |
+
" self._target(*self._args, **self._kwargs)\n",
|
35 |
+
" File \"/home/mendel/software/portiloop-software/portiloop/capture.py\", line 325, in _capture_process\n",
|
36 |
+
" assert data == [0x3E], \"The communication with the ADS cannot be established.\"\n",
|
37 |
+
"AssertionError: The communication with the ADS cannot be established.\n"
|
38 |
+
]
|
39 |
+
}
|
40 |
+
],
|
41 |
"source": [
|
42 |
"from portiloop.capture import Capture\n",
|
43 |
+
"from portiloop.inference import QuantizedModelForInference\n",
|
44 |
"\n",
|
45 |
+
"cap = Capture(QuantizedModelForInference)"
|
46 |
]
|
47 |
}
|
48 |
],
|
setup.py
CHANGED
@@ -14,6 +14,7 @@ setup(
|
|
14 |
'python-periphery',
|
15 |
'spidev',
|
16 |
'pylsl-coral',
|
17 |
-
'scipy'
|
|
|
18 |
]
|
19 |
)
|
|
|
14 |
'python-periphery',
|
15 |
'spidev',
|
16 |
'pylsl-coral',
|
17 |
+
'scipy',
|
18 |
+
'pycoral'
|
19 |
]
|
20 |
)
|