|
from typing import Dict, Optional, Tuple

import mne
import networkx as nx
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from scipy.interpolate import griddata
|
|
|
class BrainMapper:
    """Build Plotly figures of EEG features laid out on MNE's standard 10-20 montage.

    Electrode names and 3-D positions come from
    ``mne.channels.make_standard_montage('standard_1020')``. Per-electrode
    arrays passed to the plotting methods are matched to electrodes by index:
    element ``i`` belongs to ``self.ch_names[i]``.
    """

    # Number of samples along each axis of the interpolation grid.
    _GRID_SIZE = 100

    def __init__(self):
        self.montage = mne.channels.make_standard_montage('standard_1020')
        self._initialize_coordinates()

    def _initialize_coordinates(self):
        """Initialize electrode coordinates from standard montage.

        Sets ``self.coords`` (the montage's name -> xyz mapping), ``self.ch_names``
        (montage channel order) and the per-axis arrays ``self.x_coords``,
        ``self.y_coords`` and ``self.z_coords`` in that same order.
        """
        pos = self.montage.get_positions()
        self.coords = pos['ch_pos']
        self.ch_names = list(self.coords.keys())
        # One row per electrode, columns x/y/z; reshape keeps the (0, 3) case well-formed.
        xyz = np.array([self.coords[ch] for ch in self.ch_names], dtype=float).reshape(-1, 3)
        self.x_coords, self.y_coords, self.z_coords = xyz.T

    def create_visualization(self, features: Dict, map_type: str = "2D Topographic") -> go.Figure:
        """Create brain visualization based on the specified type.

        Args:
            features: Feature dictionary consumed by the selected builder.
            map_type: One of ``"2D Topographic"``, ``"3D Surface"`` or ``"Connectivity"``.

        Raises:
            ValueError: if ``map_type`` is not one of the supported names.
        """
        builders = {
            "2D Topographic": self._create_topographic_map,
            "3D Surface": self._create_3d_surface,
            "Connectivity": self._create_connectivity_map,
        }
        builder = builders.get(map_type)
        if builder is None:
            raise ValueError(f"Unsupported map type: {map_type}")
        return builder(features)

    def _make_grid(self):
        """Return ``(xi, yi)`` meshgrid arrays spanning the electrodes' x/y extent."""
        xs = np.linspace(self.x_coords.min(), self.x_coords.max(), self._GRID_SIZE)
        ys = np.linspace(self.y_coords.min(), self.y_coords.max(), self._GRID_SIZE)
        return np.meshgrid(xs, ys)

    def _interpolate_to_grid(self, values, xi, yi):
        """Interpolate one value per electrode onto the ``(xi, yi)`` mesh.

        Only the electrodes' x/y positions are used (z is dropped). Grid points
        outside the convex hull of the electrodes come back as NaN, which Plotly
        renders as gaps.

        Raises:
            ValueError: if ``values`` does not contain one entry per electrode.
        """
        values = np.asarray(values, dtype=float).ravel()
        if values.size != len(self.ch_names):
            raise ValueError(
                f"Expected {len(self.ch_names)} per-electrode values "
                f"(one per channel in self.ch_names), got {values.size}"
            )
        points = np.column_stack((self.x_coords, self.y_coords))
        return griddata(points, values, (xi, yi), method='cubic')

    def _create_topographic_map(self, features: Dict) -> go.Figure:
        """Create 2D topographic map of brain activity.

        ``features['band_powers']`` maps each band name to either one value per
        electrode (interpolated onto the grid) or an array of exactly
        ``_GRID_SIZE ** 2`` values already sampled on the grid. Each band adds
        a contour trace and an electrode-marker trace. A dropdown menu switches
        which band's pair of traces is visible.
        """
        band_powers = features['band_powers']
        band_names = list(band_powers)
        xi, yi = self._make_grid()  # identical for every band, so build once

        # Open on 'alpha' when present, otherwise the first band, so the figure
        # never starts with every trace hidden.
        if 'alpha' in band_powers:
            initial_band = 'alpha'
        elif band_names:
            initial_band = band_names[0]
        else:
            initial_band = None

        fig = go.Figure()
        for band_name in band_names:
            powers = np.asarray(band_powers[band_name], dtype=float)
            if powers.size == xi.size:
                # Input already sampled on the grid: keep the original reshape path.
                z_grid = powers.reshape(xi.shape)
            else:
                z_grid = self._interpolate_to_grid(powers, xi, yi)
            visible = band_name == initial_band

            fig.add_trace(go.Contour(
                x=xi[0],
                y=yi[:, 0],
                z=z_grid,
                name=band_name,
                colorscale='Viridis',
                showscale=True,
                visible=visible,
            ))
            fig.add_trace(go.Scatter(
                x=self.x_coords,
                y=self.y_coords,
                mode='markers+text',
                text=self.ch_names,
                textposition="top center",
                name='Electrodes',
                marker=dict(size=10, color='black'),
                visible=visible,
            ))

        # Traces were added in (contour, electrodes) pairs, so trace k belongs to band k // 2.
        n_traces = 2 * len(band_names)
        buttons = [
            {
                'label': band,
                'method': 'update',
                'args': [{'visible': [k // 2 == j for k in range(n_traces)]}],
            }
            for j, band in enumerate(band_names)
        ]
        menu = {'buttons': buttons, 'direction': 'down', 'showactive': True}
        if initial_band is not None:
            # Keep the dropdown's highlighted entry in sync with the band shown.
            menu['active'] = band_names.index(initial_band)

        fig.update_layout(
            title="Brain Activity Topographic Map",
            xaxis_title="X Position",
            yaxis_title="Y Position",
            showlegend=True,
            updatemenus=[menu],
        )
        return fig

    def _create_3d_surface(self, features: Dict) -> go.Figure:
        """Create 3D surface plot of brain activity.

        ``features['statistics']['mean']`` must hold one value per electrode.
        The values are interpolated over the electrodes' x/y positions and drawn
        as a height surface (z = activity). Electrode markers are placed at their
        own activity value so they lie on that surface.

        Raises:
            ValueError: if the mean array does not have one entry per electrode.
        """
        mean_activity = np.asarray(features['statistics']['mean'], dtype=float).ravel()
        xi, yi = self._make_grid()
        z_grid = self._interpolate_to_grid(mean_activity, xi, yi)

        fig = go.Figure()
        # go.Surface needs a 2-D z grid; an N x 1 column renders as a degenerate strip.
        fig.add_trace(go.Surface(
            x=xi,
            y=yi,
            z=z_grid,
            colorscale='Viridis',
            name='Brain Activity',
        ))
        fig.add_trace(go.Scatter3d(
            x=self.x_coords,
            y=self.y_coords,
            z=mean_activity,  # same units as the "Activity Level" z axis
            mode='markers+text',
            text=self.ch_names,
            marker=dict(size=5, color='red'),
            name='Electrodes',
        ))

        fig.update_layout(
            title="3D Brain Activity Surface",
            scene=dict(
                xaxis_title="X Position",
                yaxis_title="Y Position",
                zaxis_title="Activity Level",
                camera=dict(
                    up=dict(x=0, y=0, z=1),
                    center=dict(x=0, y=0, z=0),
                    eye=dict(x=1.5, y=1.5, z=1.5),
                ),
            ),
        )
        return fig

    def _create_connectivity_map(self, features: Dict, threshold: float = 0.0) -> go.Figure:
        """Create brain connectivity visualization.

        Builds an undirected graph from ``features['connectivity']['correlation']``
        (a square matrix). An edge joins nodes i != j when
        ``abs(corr[i, j]) > threshold``. Nodes are placed with a spring layout
        and coloured by their number of connections.

        Args:
            features: Feature dictionary containing the correlation matrix.
            threshold: Minimum absolute correlation for an edge. The default
                0.0 keeps every nonzero off-diagonal entry.
        """
        connectivity = np.asarray(features['connectivity']['correlation'], dtype=float)

        # Drop the diagonal (self-correlation is always 1 and would give every
        # node a self-loop) and weak or NaN entries; NaN > threshold is False.
        adjacency = np.where(np.abs(connectivity) > threshold, connectivity, 0.0)
        np.fill_diagonal(adjacency, 0.0)

        G = nx.from_numpy_array(adjacency)
        pos = nx.spring_layout(G, k=1, iterations=50)

        # Edges are drawn as one polyline; None breaks the line between edges.
        edge_x = []
        edge_y = []
        for u, v in G.edges():
            (x0, y0), (x1, y1) = pos[u], pos[v]
            edge_x.extend([x0, x1, None])
            edge_y.extend([y0, y1, None])

        edge_trace = go.Scatter(
            x=edge_x, y=edge_y,
            line=dict(width=0.5, color='#888'),
            hoverinfo='none',
            mode='lines',
        )

        nodes = list(G.nodes())
        node_x = [pos[node][0] for node in nodes]
        node_y = [pos[node][1] for node in nodes]

        # Montage names only make sense when the matrix covers exactly the montage
        # channels; otherwise fall back to node indices rather than mislabel nodes.
        if len(nodes) == len(self.ch_names):
            labels = self.ch_names
        else:
            labels = [str(node) for node in nodes]

        node_degrees = [degree for _, degree in G.degree(nodes)]

        node_trace = go.Scatter(
            x=node_x, y=node_y,
            mode='markers+text',
            hoverinfo='text',
            text=labels,
            marker=dict(
                showscale=True,
                colorscale='YlOrRd',
                size=10,
                color=node_degrees,
                colorbar=dict(
                    thickness=15,
                    # 'titleside' is deprecated/removed in recent Plotly; use title.side.
                    title=dict(text='Node Connections', side='right'),
                    xanchor='left',
                ),
            ),
        )

        fig = go.Figure(
            data=[edge_trace, node_trace],
            layout=go.Layout(
                title='Brain Connectivity Network',
                showlegend=False,
                hovermode='closest',
                margin=dict(b=20, l=5, r=5, t=40),
                xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            ),
        )
        return fig