Milo Sobral commited on
Commit
111f264
·
1 Parent(s): 98fb56f

Finished last few changes

Browse files
portiloop/src/demo/demo.py CHANGED
@@ -8,67 +8,46 @@ def on_upload_file(file):
8
  if file.name.split(".")[-1] != "xdf":
9
  raise gr.Error("Please upload a .xdf file.")
10
  else:
11
- yield f"File {file.name} successfully uploaded!"
12
 
13
 
14
- with gr.Blocks(title="Portiloop") as demo:
15
- gr.Markdown("# Portiloop Demo")
16
- gr.Markdown("This Demo takes as input an XDF file coming from the Portiloop EEG device and allows you to convert it to CSV and perform the following actions:: \n * Filter the data offline \n * Perform offline spindle detection using Wamsley or Lacourse. \n * Simulate the Portiloop online filtering and spindle detection with different parameters.")
17
- gr.Markdown("Upload your CSV file and click **Run Inference** to start the processing...")
 
18
 
19
- with gr.Row():
20
- xdf_file = gr.UploadButton(label="XDF File", type="file")
 
21
 
22
- # Offline Filtering (Boolean)
23
- offline_filtering = gr.Checkbox(label="Offline Filtering (On/Off)", value=True)
24
- # Online Filtering (Boolean)
25
- online_filtering = gr.Checkbox(label="Online Filtering (On/Off)", value=True)
26
- # Lacourse's Method (Boolean)
27
- lacourse = gr.Checkbox(label="Lacourse Detection (On/Off)", value=True)
28
- # Wamsley's Method (Boolean)
29
- wamsley = gr.Checkbox(label="Wamsley Detection (On/Off)", value=True)
30
- # Online Detection (Boolean)
31
- online_detection = gr.Checkbox(label="Online Detection (On/Off)", value=True)
32
 
33
- # Threshold value
34
- threshold = gr.Slider(0, 1, value=0.82, step=0.01, label="Threshold", interactive=True)
35
- # Detection Channel
36
- detect_channel = gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8"], value="2", label="Detection Channel in XDF recording", interactive=True)
37
- # Frequency
38
- freq = gr.Dropdown(choices=["100", "200", "250", "256", "500", "512", "1000", "1024"], value="250", label="Sampling Frequency (Hz)", interactive=True)
39
 
40
- # Output elements
41
- update_text = gr.Textbox(value="Upload an XDF File and click 'Run Inference'...", label="Status", interactive=False)
42
- output_plot = gr.Plot()
43
- output_array = gr.File(label="Output CSV File")
44
- xdf_file.upload(fn=on_upload_file, inputs=[xdf_file], outputs=[update_text])
 
45
 
46
- def clear():
47
- output_plot.clear()
48
- output_array.clear()
49
- update_text.clear()
50
- xdf_file.clear()
51
 
52
- # Row containing all buttons:
53
- with gr.Row():
54
- # Run inference button
55
- run_inference = gr.Button(
56
- value="Run Inference")
57
- # Reset button
58
- reset = gr.Button(value="Reset", variant="secondary", on_click=clear,)
59
- run_inference.click(
60
- fn=run_offline,
61
- inputs=[
62
- xdf_file,
63
- offline_filtering,
64
- online_filtering,
65
- online_detection,
66
- lacourse,
67
- wamsley,
68
- threshold,
69
- detect_channel,
70
- freq],
71
- outputs=[output_plot, output_array, update_text])
72
 
73
- demo.queue()
74
- demo.launch(share=False)
 
 
 
 
8
  if file.name.split(".")[-1] != "xdf":
9
  raise gr.Error("Please upload a .xdf file.")
10
  else:
11
+ return file.name
12
 
13
 
14
+ def main():
15
+ with gr.Blocks(title="Portiloop") as demo:
16
+ gr.Markdown("# Portiloop Demo")
17
+ gr.Markdown("This Demo takes as input an XDF file coming from the Portiloop EEG device and allows you to convert it to CSV and perform the following actions:: \n * Filter the data offline \n * Perform offline spindle detection using Wamsley or Lacourse. \n * Simulate the Portiloop online filtering and spindle detection with different parameters.")
18
+ gr.Markdown("Upload your XDF file and click **Run Inference** to start the processing...")
19
 
20
+ with gr.Row():
21
+ xdf_file_button = gr.UploadButton(label="Click to Upload", type="file", file_count="single")
22
+ xdf_file_static = gr.File(label="XDF File", type='file', interactive=False)
23
 
24
+ xdf_file_button.upload(on_upload_file, xdf_file_button, xdf_file_static)
 
 
 
 
 
 
 
 
 
25
 
26
+ # Make a checkbox group for the options
27
+ detect_filter = gr.CheckboxGroup(['Offline Filtering', 'Lacourse Detection', 'Wamsley Detection', 'Online Filtering', 'Online Detection'], type='index', label="Filtering/Detection options")
 
 
 
 
28
 
29
+ # Threshold value
30
+ threshold = gr.Slider(0, 1, value=0.82, step=0.01, label="Threshold", interactive=True)
31
+ # Detection Channel
32
+ detect_channel = gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8"], value="2", label="Detection Channel in XDF recording", interactive=True)
33
+ # Frequency
34
+ freq = gr.Dropdown(choices=["100", "200", "250", "256", "500", "512", "1000", "1024"], value="250", label="Sampling Frequency (Hz)", interactive=True)
35
 
36
+ output_array = gr.File(label="Output CSV File")
 
 
 
 
37
 
38
+ run_inference = gr.Button(value="Run Inference")
39
+ run_inference.click(
40
+ fn=run_offline,
41
+ inputs=[
42
+ xdf_file_static,
43
+ detect_filter,
44
+ threshold,
45
+ detect_channel,
46
+ freq],
47
+ outputs=[output_array])
 
 
 
 
 
 
 
 
 
 
48
 
49
+ demo.queue()
50
+ demo.launch(share=True)
51
+
52
+ if __name__ == "__main__":
53
+ main()
portiloop/src/demo/offline.py CHANGED
@@ -7,9 +7,14 @@ from portiloop.src.demo.utils import xdf2array, offline_detect, offline_filter,
7
  import gradio as gr
8
 
9
 
10
- def run_offline(xdf_file, offline_filtering, online_filtering, online_detection, lacourse, wamsley, threshold, channel_num, freq):
 
 
 
 
 
 
11
 
12
- print("Starting offline processing...")
13
  # Make sure the inputs make sense:
14
  if not offline_filtering and (lacourse or wamsley):
15
  raise gr.Error("You can't use the offline detection methods without offline filtering.")
@@ -17,17 +22,17 @@ def run_offline(xdf_file, offline_filtering, online_filtering, online_detection,
17
  if not online_filtering and online_detection:
18
  raise gr.Error("You can't use the online detection without online filtering.")
19
 
 
 
 
20
  freq = int(freq)
21
 
22
  # Read the xdf file to a numpy array
23
  print("Loading xdf file...")
24
- yield None, None, "Loading xdf file..."
25
  data_whole, columns = xdf2array(xdf_file.name, int(channel_num))
26
- print(data_whole.shape)
27
  # Do the offline filtering of the data
28
- print("Filtering offline...")
29
- yield None, None, "Filtering offline..."
30
  if offline_filtering:
 
31
  offline_filtered_data = offline_filter(data_whole[:, columns.index("raw_signal")], freq)
32
  # Expand the dimension of the filtered data to match the shape of the other columns
33
  offline_filtered_data = np.expand_dims(offline_filtered_data, axis=1)
@@ -35,9 +40,8 @@ def run_offline(xdf_file, offline_filtering, online_filtering, online_detection,
35
  columns.append("offline_filtered_signal")
36
 
37
  # Do Wamsley's method
38
- print("Running Wamsley detection...")
39
- yield None, None, "Running Wamsley detection..."
40
  if wamsley:
 
41
  wamsley_data = offline_detect("Wamsley", \
42
  data_whole[:, columns.index("offline_filtered_signal")],\
43
  data_whole[:, columns.index("time_stamps")],\
@@ -47,9 +51,8 @@ def run_offline(xdf_file, offline_filtering, online_filtering, online_detection,
47
  columns.append("wamsley_spindles")
48
 
49
  # Do Lacourse's method
50
- print("Running Lacourse detection...")
51
- yield None, None, "Running Lacourse detection..."
52
  if lacourse:
 
53
  lacourse_data = offline_detect("Lacourse", \
54
  data_whole[:, columns.index("offline_filtered_signal")],\
55
  data_whole[:, columns.index("time_stamps")],\
@@ -70,10 +73,9 @@ def run_offline(xdf_file, offline_filtering, online_filtering, online_detection,
70
  detector = SleepSpindleRealTimeDetector(threshold=threshold, channel=1) # always 1 because we have only one channel
71
  stimulator = OfflineSleepSpindleRealTimeStimulator()
72
 
73
- print("Running online filtering and detection...")
74
- yield None, None, "Running online filtering and detection..."
75
  if online_filtering or online_detection:
76
- # Plotting variables
 
77
  points = []
78
  online_activations = []
79
 
@@ -97,25 +99,6 @@ def run_offline(xdf_file, offline_filtering, online_filtering, online_detection,
97
  online_activations.append(1)
98
  else:
99
  online_activations.append(0)
100
-
101
- # Function to return a list of all indexes where activations have happened
102
- def get_activations(activations):
103
- return [i for i, x in enumerate(activations) if x == 1]
104
-
105
- # Plot the data
106
- if index % (10 * freq) == 0 and index >= (10 * freq):
107
- plt.close()
108
- fig = plt.figure(figsize=(20, 10))
109
- plt.clf()
110
- plt.plot(np.linspace(0, 10, num=freq*10), points[-10 * freq:], label="Data")
111
- # Draw vertical lines for activations
112
- for index in get_activations(online_activations[-10 * freq:]):
113
- plt.axvline(x=index / freq, color='r', label="Portiloop Stimulation")
114
- # Add axis titles and legend
115
- plt.legend()
116
- plt.xlabel("Time (s)")
117
- plt.ylabel("Amplitude")
118
- yield fig, None, f"Running online filtering and detection {index}/{len(data)}..."
119
 
120
  if online_filtering:
121
  online_filtered = np.array(points)
@@ -134,4 +117,4 @@ def run_offline(xdf_file, offline_filtering, online_filtering, online_detection,
134
  np.savetxt("output.csv", data_whole, delimiter=",", header=",".join(columns), comments="")
135
 
136
  print("Done!")
137
- yield None, "output.csv", "Done!"
 
7
  import gradio as gr
8
 
9
 
10
+ def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq):
11
+ # Get the options from the checkbox group
12
+ offline_filtering = 0 in detect_filter_opts
13
+ lacourse = 1 in detect_filter_opts
14
+ wamsley = 2 in detect_filter_opts
15
+ online_filtering = 3 in detect_filter_opts
16
+ online_detection = 4 in detect_filter_opts
17
 
 
18
  # Make sure the inputs make sense:
19
  if not offline_filtering and (lacourse or wamsley):
20
  raise gr.Error("You can't use the offline detection methods without offline filtering.")
 
22
  if not online_filtering and online_detection:
23
  raise gr.Error("You can't use the online detection without online filtering.")
24
 
25
+ if xdf_file is None:
26
+ raise gr.Error("Please upload a .xdf file.")
27
+
28
  freq = int(freq)
29
 
30
  # Read the xdf file to a numpy array
31
  print("Loading xdf file...")
 
32
  data_whole, columns = xdf2array(xdf_file.name, int(channel_num))
 
33
  # Do the offline filtering of the data
 
 
34
  if offline_filtering:
35
+ print("Filtering offline...")
36
  offline_filtered_data = offline_filter(data_whole[:, columns.index("raw_signal")], freq)
37
  # Expand the dimension of the filtered data to match the shape of the other columns
38
  offline_filtered_data = np.expand_dims(offline_filtered_data, axis=1)
 
40
  columns.append("offline_filtered_signal")
41
 
42
  # Do Wamsley's method
 
 
43
  if wamsley:
44
+ print("Running Wamsley detection...")
45
  wamsley_data = offline_detect("Wamsley", \
46
  data_whole[:, columns.index("offline_filtered_signal")],\
47
  data_whole[:, columns.index("time_stamps")],\
 
51
  columns.append("wamsley_spindles")
52
 
53
  # Do Lacourse's method
 
 
54
  if lacourse:
55
+ print("Running Lacourse detection...")
56
  lacourse_data = offline_detect("Lacourse", \
57
  data_whole[:, columns.index("offline_filtered_signal")],\
58
  data_whole[:, columns.index("time_stamps")],\
 
73
  detector = SleepSpindleRealTimeDetector(threshold=threshold, channel=1) # always 1 because we have only one channel
74
  stimulator = OfflineSleepSpindleRealTimeStimulator()
75
 
 
 
76
  if online_filtering or online_detection:
77
+ print("Running online filtering and detection...")
78
+
79
  points = []
80
  online_activations = []
81
 
 
99
  online_activations.append(1)
100
  else:
101
  online_activations.append(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  if online_filtering:
104
  online_filtered = np.array(points)
 
117
  np.savetxt("output.csv", data_whole, delimiter=",", header=",".join(columns), comments="")
118
 
119
  print("Done!")
120
+ return "output.csv"