ybouteiller commited on
Commit
120f728
·
1 Parent(s): 4711af1

Added inference code

Browse files
portiloop/capture.py CHANGED
@@ -399,7 +399,7 @@ def _capture_process(p_data_o, p_msg_io, duration, frequency, python_clock, time
399
 
400
 
401
  class Capture:
402
- def __init__(self):
403
  # {now.strftime('%m_%d_%Y_%H_%M_%S')}
404
  self.filename = EDF_PATH / 'recording.edf'
405
  self._p_capture = None
@@ -415,6 +415,9 @@ class Capture:
415
  self.custom_fir_cutoff = 30
416
  self.filter = True
417
  self.record = False
 
 
 
418
  self.lsl = False
419
  self.display = False
420
  self.python_clock = True
@@ -430,6 +433,8 @@ class Capture:
430
  self._t_capture = None
431
  self.channel_states = ['disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled']
432
 
 
 
433
  # widgets ===============================
434
 
435
  # CHANNELS ------------------------------
@@ -554,6 +559,12 @@ class Capture:
554
  disabled=False
555
  )
556
 
 
 
 
 
 
 
557
  self.b_polyak_mean = widgets.FloatText(
558
  value=self.polyak_mean,
559
  description='Polyak mean:',
@@ -610,6 +621,20 @@ class Capture:
610
  indent=False
611
  )
612
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
613
  self.b_record = widgets.Checkbox(
614
  value=self.record,
615
  description='Record EDF',
@@ -636,8 +661,10 @@ class Capture:
636
  self.b_capture.observe(self.on_b_capture, 'value')
637
  self.b_clock.observe(self.on_b_clock, 'value')
638
  self.b_frequency.observe(self.on_b_frequency, 'value')
 
639
  self.b_duration.observe(self.on_b_duration, 'value')
640
  self.b_filter.observe(self.on_b_filter, 'value')
 
641
  self.b_record.observe(self.on_b_record, 'value')
642
  self.b_lsl.observe(self.on_b_lsl, 'value')
643
  self.b_display.observe(self.on_b_display, 'value')
@@ -668,7 +695,8 @@ class Capture:
668
  self.b_filename,
669
  self.b_power_line,
670
  self.b_clock,
671
- widgets.HBox([self.b_filter, self.b_record, self.b_lsl, self.b_display]),
 
672
  self.b_accordion_filter,
673
  self.b_capture]))
674
 
@@ -677,6 +705,7 @@ class Capture:
677
  self.b_duration.disabled = False
678
  self.b_filename.disabled = False
679
  self.b_filter.disabled = False
 
680
  self.b_record.disabled = False
681
  self.b_record.lsl = False
682
  self.b_display.disabled = False
@@ -694,12 +723,16 @@ class Capture:
694
  self.b_custom_fir.disabled = False
695
  self.b_custom_fir_order.disabled = not self.custom_fir
696
  self.b_custom_fir_cutoff.disabled = not self.custom_fir
 
 
697
 
698
  def disable_buttons(self):
699
  self.b_frequency.disabled = True
700
  self.b_duration.disabled = True
701
  self.b_filename.disabled = True
702
  self.b_filter.disabled = True
 
 
703
  self.b_record.disabled = True
704
  self.b_record.lsl = True
705
  self.b_display.disabled = True
@@ -717,6 +750,7 @@ class Capture:
717
  self.b_custom_fir.disabled = True
718
  self.b_custom_fir_order.disabled = True
719
  self.b_custom_fir_cutoff.disabled = True
 
720
 
721
  def on_b_radio_ch2(self, value):
722
  self.channel_states[1] = value['new']
@@ -751,7 +785,7 @@ class Capture:
751
  warnings.warn("Capture already running, operation aborted.")
752
  return
753
  self._t_capture = Thread(target=self.start_capture,
754
- args=(self.filter, self.record, self.lsl, self.display, 500, self.python_clock))
755
  self._t_capture.start()
756
  elif val == 'Stop':
757
  with self._lock_msg_out:
@@ -790,6 +824,13 @@ class Capture:
790
  else:
791
  self.b_frequency.value = self.frequency
792
 
 
 
 
 
 
 
 
793
  def on_b_filename(self, value):
794
  val = value['new']
795
  if val != '':
@@ -844,6 +885,15 @@ class Capture:
844
  val = value['new']
845
  self.filter = val
846
 
 
 
 
 
 
 
 
 
 
847
  def on_b_record(self, value):
848
  val = value['new']
849
  self.record = val
@@ -894,6 +944,8 @@ class Capture:
894
 
895
  def start_capture(self,
896
  filter,
 
 
897
  record,
898
  lsl,
899
  viz,
@@ -918,6 +970,9 @@ class Capture:
918
  alpha_avg=self.polyak_mean,
919
  alpha_std=self.polyak_std,
920
  epsilon=self.epsilon)
 
 
 
921
 
922
  self._p_capture = mp.Process(target=_capture_process,
923
  args=(p_data_o,
@@ -975,6 +1030,15 @@ class Capture:
975
 
976
  filtered_point = n_array.tolist()
977
 
 
 
 
 
 
 
 
 
 
978
  if lsl:
979
  lsl_outlet.push_sample(filtered_point[-1])
980
 
 
399
 
400
 
401
  class Capture:
402
+ def __init__(self, quantInferenceClass):
403
  # {now.strftime('%m_%d_%Y_%H_%M_%S')}
404
  self.filename = EDF_PATH / 'recording.edf'
405
  self._p_capture = None
 
415
  self.custom_fir_cutoff = 30
416
  self.filter = True
417
  self.record = False
418
+ self.detect = False
419
+ self.stimulate = False
420
+ self.threshold = 0.5
421
  self.lsl = False
422
  self.display = False
423
  self.python_clock = True
 
433
  self._t_capture = None
434
  self.channel_states = ['disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled']
435
 
436
+ self.quantInferenceClass = quantInferenceClass
437
+
438
  # widgets ===============================
439
 
440
  # CHANNELS ------------------------------
 
559
  disabled=False
560
  )
561
 
562
+ self.b_threshold = widgets.FloatText(
563
+ value=self.threshold,
564
+ description='Threshold:',
565
+ disabled=True
566
+ )
567
+
568
  self.b_polyak_mean = widgets.FloatText(
569
  value=self.polyak_mean,
570
  description='Polyak mean:',
 
621
  indent=False
622
  )
623
 
624
+ self.b_detect = widgets.Checkbox(
625
+ value=self.detect,
626
+ description='Detect',
627
+ disabled=False,
628
+ indent=False
629
+ )
630
+
631
+ self.b_stimulate = widgets.Checkbox(
632
+ value=self.stimulate,
633
+ description='Stimulate',
634
+ disabled=True,
635
+ indent=False
636
+ )
637
+
638
  self.b_record = widgets.Checkbox(
639
  value=self.record,
640
  description='Record EDF',
 
661
  self.b_capture.observe(self.on_b_capture, 'value')
662
  self.b_clock.observe(self.on_b_clock, 'value')
663
  self.b_frequency.observe(self.on_b_frequency, 'value')
664
+ self.b_threshold.observe(self.on_b_threshold, 'value')
665
  self.b_duration.observe(self.on_b_duration, 'value')
666
  self.b_filter.observe(self.on_b_filter, 'value')
667
+ self.b_detect.observe(self.on_b_detect, 'value')
668
  self.b_record.observe(self.on_b_record, 'value')
669
  self.b_lsl.observe(self.on_b_lsl, 'value')
670
  self.b_display.observe(self.on_b_display, 'value')
 
695
  self.b_filename,
696
  self.b_power_line,
697
  self.b_clock,
698
+ widgets.HBox([self.b_filter, self.b_detect, self.b_stimulate, self.b_record, self.b_lsl, self.b_display]),
699
+ self.b_threshold,
700
  self.b_accordion_filter,
701
  self.b_capture]))
702
 
 
705
  self.b_duration.disabled = False
706
  self.b_filename.disabled = False
707
  self.b_filter.disabled = False
708
+ self.b_detect.disabled = False
709
  self.b_record.disabled = False
710
  self.b_record.lsl = False
711
  self.b_display.disabled = False
 
723
  self.b_custom_fir.disabled = False
724
  self.b_custom_fir_order.disabled = not self.custom_fir
725
  self.b_custom_fir_cutoff.disabled = not self.custom_fir
726
+ self.b_stimulate.disabled = not self.detect
727
+ self.b_threshold.disabled = not self.detect
728
 
729
  def disable_buttons(self):
730
  self.b_frequency.disabled = True
731
  self.b_duration.disabled = True
732
  self.b_filename.disabled = True
733
  self.b_filter.disabled = True
734
+ self.b_stimulate.disabled = True
735
+ self.b_filter.disabled = True
736
  self.b_record.disabled = True
737
  self.b_record.lsl = True
738
  self.b_display.disabled = True
 
750
  self.b_custom_fir.disabled = True
751
  self.b_custom_fir_order.disabled = True
752
  self.b_custom_fir_cutoff.disabled = True
753
+ self.b_threshold.disabled = True
754
 
755
  def on_b_radio_ch2(self, value):
756
  self.channel_states[1] = value['new']
 
785
  warnings.warn("Capture already running, operation aborted.")
786
  return
787
  self._t_capture = Thread(target=self.start_capture,
788
+ args=(self.filter, self.detect, self.quantInferenceClass, self.record, self.lsl, self.display, 500, self.python_clock))
789
  self._t_capture.start()
790
  elif val == 'Stop':
791
  with self._lock_msg_out:
 
824
  else:
825
  self.b_frequency.value = self.frequency
826
 
827
+ def on_b_threshold(self, value):
828
+ val = value['new']
829
+ if val >= 0 and val <= 1:
830
+ self.threshold = val
831
+ else:
832
+ self.b_threshold.value = self.threshold
833
+
834
  def on_b_filename(self, value):
835
  val = value['new']
836
  if val != '':
 
885
  val = value['new']
886
  self.filter = val
887
 
888
+ def on_b_stimulate(self, value):
889
+ val = value['new']
890
+ self.stimulate = val
891
+
892
+ def on_b_detect(self, value):
893
+ val = value['new']
894
+ self.detect = val
895
+ self.enable_buttons()
896
+
897
  def on_b_record(self, value):
898
  val = value['new']
899
  self.record = val
 
944
 
945
  def start_capture(self,
946
  filter,
947
+ detect,
948
+ quantInferenceClass,
949
  record,
950
  lsl,
951
  viz,
 
970
  alpha_avg=self.polyak_mean,
971
  alpha_std=self.polyak_std,
972
  epsilon=self.epsilon)
973
+
974
+ if detect:
975
+ infer = quantInferenceClass()
976
 
977
  self._p_capture = mp.Process(target=_capture_process,
978
  args=(p_data_o,
 
1030
 
1031
  filtered_point = n_array.tolist()
1032
 
1033
+ if detect:
1034
+ results = infer.add_datapoints(filtered_points)
1035
+
1036
+ for r in results:
1037
+ print(r >= threshold)
1038
+
1039
+ if stimulate and True:
1040
+ print('stimulation')
1041
+
1042
  if lsl:
1043
  lsl_outlet.push_sample(filtered_point[-1])
1044
 
portiloop/inference.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pycoral.utils import edgetpu
2
+ import time
3
+ from abc import ABC, abstractmethod
4
+ from pathlib import Path
5
+ import numpy as np
6
+
7
+ DEFAULT_MODEL_PATH = str(Path(__file__).parent / "models/portiloop_model_quant.tflite")
8
+ print(DEFAULT_MODEL_PATH)
9
+
10
+ class AbstractQuantizedModelForInference(ABC):
11
+ @abstractmethod
12
+ def add_datapoints(self, input_float):
13
+ return NotImplemented
14
+
15
+ class QuantizedModelForInference(AbstractQuantizedModelForInference):
16
+ def __init__(self, num_models_parallel=8, window_size=54, seq_stride=42, model_path=None, verbose=False, channel=2):
17
+ model_path = DEFAULT_MODEL_PATH if model_path is None else model_path
18
+ self.verbose = verbose
19
+ self.channel = channel
20
+ self.num_models_parallel = num_models_parallel
21
+
22
+ self.interpreters = []
23
+ for i in range(self.num_models_parallel):
24
+ self.interpreters.append(edgetpu.make_interpreter(model_path))
25
+ self.interpreters[i].allocate_tensors()
26
+ self.interpreter_counter = 0
27
+
28
+ self.input_details = self.interpreters[0].get_input_details()
29
+ self.output_details = self.interpreters[0].get_output_details()
30
+
31
+ self.buffer = []
32
+ self.seq_stride = seq_stride
33
+ self.window_size = window_size
34
+
35
+ self.stride_counters = [np.floor((self.seq_stride / self.num_models_parallel) * i) for i in range(self.num_models_parallel)]
36
+ for idx, i in enumerate(self.stride_counters[1:]):
37
+ self.stride_counters[idx+1] = i - self.stride_counters[idx]
38
+ self.current_stride_counter = self.stride_counters[0] - 1
39
+
40
+
41
+ def add_datapoints(self, inputs_float):
42
+ res = []
43
+ for inp in inputs_float:
44
+ result = self.add_datapoint(inp)
45
+ if result is not None:
46
+ res.append(result)
47
+ return res
48
+
49
+
50
+ def add_datapoint(self, input_float):
51
+ input_float = input_float[self.channel-1]
52
+ result = None
53
+ self.buffer.append(input_float)
54
+ if len(self.buffer) > self.window_size:
55
+ self.buffer = self.buffer[1:]
56
+ self.current_stride_counter += 1
57
+ if self.current_stride_counter == self.stride_counter[self.interpreter_counter]:
58
+ result = self.call_model(self.interpreter_counter, self.buffer)
59
+ self.interpreter_counter += 1
60
+ self.interpreter_counter %= self.num_model_parallel
61
+ self.current_stride_counter = 0
62
+ return result
63
+
64
+
65
+
66
+ def call_model(self, idx, input_float=None):
67
+ if input_float is None:
68
+ # For debuggin purposes
69
+ input_shape = input_details[0]['shape']
70
+ input = np.array(np.random.random_sample(input_shape), dtype=np.int8)
71
+ else:
72
+ # Convert float input to Int
73
+ input_scale, input_zero_point = input_details[0]["quantization"]
74
+ input = np.asarray(input_float) / input_scale + input_zero_point
75
+ input = input.astype(input_details[0]["dtype"])
76
+
77
+ interpreter.set_tensor(input_details[0]['index'], input)
78
+ if self.verbose:
79
+ start_time = time.time()
80
+
81
+ interpreter.invoke()
82
+
83
+ if self.verbose:
84
+ end_time = time.time()
85
+
86
+ output = interpreter.get_tensor(output_details[0]['index'])
87
+ output_scale, output_zero_point = input_details[0]["quantization"]
88
+ output = float(output - output_zero_point) * output_scale
89
+
90
+ if self.verbose:
91
+ print(f"Computed output {output} in {end_time - start_time} seconds")
92
+
93
+ return output
94
+
95
+
portiloop/notebooks/tests.ipynb CHANGED
@@ -2,16 +2,47 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": null,
6
  "id": "7b2fc5da",
7
  "metadata": {
8
  "scrolled": false
9
  },
10
- "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  "source": [
12
  "from portiloop.capture import Capture\n",
 
13
  "\n",
14
- "cap = Capture()"
15
  ]
16
  }
17
  ],
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 1,
6
  "id": "7b2fc5da",
7
  "metadata": {
8
  "scrolled": false
9
  },
10
+ "outputs": [
11
+ {
12
+ "data": {
13
+ "application/vnd.jupyter.widget-view+json": {
14
+ "model_id": "910f8e489b6341119f4d6e17a5b2aedc",
15
+ "version_major": 2,
16
+ "version_minor": 0
17
+ },
18
+ "text/plain": [
19
+ "VBox(children=(Accordion(children=(GridBox(children=(Label(value='CH1'), Label(value='CH2'), Label(value='CH3'…"
20
+ ]
21
+ },
22
+ "metadata": {},
23
+ "output_type": "display_data"
24
+ },
25
+ {
26
+ "name": "stderr",
27
+ "output_type": "stream",
28
+ "text": [
29
+ "Process Process-1:\n",
30
+ "Traceback (most recent call last):\n",
31
+ " File \"/usr/lib/python3.7/multiprocessing/process.py\", line 297, in _bootstrap\n",
32
+ " self.run()\n",
33
+ " File \"/usr/lib/python3.7/multiprocessing/process.py\", line 99, in run\n",
34
+ " self._target(*self._args, **self._kwargs)\n",
35
+ " File \"/home/mendel/software/portiloop-software/portiloop/capture.py\", line 325, in _capture_process\n",
36
+ " assert data == [0x3E], \"The communication with the ADS cannot be established.\"\n",
37
+ "AssertionError: The communication with the ADS cannot be established.\n"
38
+ ]
39
+ }
40
+ ],
41
  "source": [
42
  "from portiloop.capture import Capture\n",
43
+ "from portiloop.inference import QuantizedModelForInference\n",
44
  "\n",
45
+ "cap = Capture(QuantizedModelForInference)"
46
  ]
47
  }
48
  ],
setup.py CHANGED
@@ -14,6 +14,7 @@ setup(
14
  'python-periphery',
15
  'spidev',
16
  'pylsl-coral',
17
- 'scipy'
 
18
  ]
19
  )
 
14
  'python-periphery',
15
  'spidev',
16
  'pylsl-coral',
17
+ 'scipy',
18
+ 'pycoral'
19
  ]
20
  )