# Provenance (page residue from the original upload): uploaded by
# invincible-jha in "Upload 5 files", commit a2fd99a (verified).
import numpy as np
import plotly.graph_objects as go
import mne
from typing import Dict, Optional, Tuple
import plotly.express as px
import networkx as nx
class BrainMapper:
    """Build Plotly figures that map EEG-derived features onto the scalp.

    Electrode positions come from MNE's built-in ``standard_1020`` montage.
    Per-channel arrays in ``features`` are matched to electrodes by index:
    value ``i`` belongs to ``self.ch_names[i]``.  NOTE(review): this assumes
    the feature extractor emits one value per montage channel, in montage
    order -- confirm against the caller.

    ``features`` keys read by the individual map builders:

    * ``'band_powers'``: mapping of band name -> per-channel values
      (2D topographic map).
    * ``'statistics'``: mapping containing ``'mean'`` -> per-channel values
      (3D surface).
    * ``'connectivity'``: mapping containing ``'correlation'`` -> square
      channel-by-channel matrix (connectivity network).
    """

    # Grid points per axis for the interpolated (topographic / surface) maps.
    _GRID_RESOLUTION = 100

    def __init__(self):
        """Load the standard 10-20 montage and cache its electrode coordinates."""
        self.montage = mne.channels.make_standard_montage('standard_1020')
        self._initialize_coordinates()

    def _initialize_coordinates(self):
        """Initialize electrode coordinates from the standard montage.

        Sets ``coords`` (channel name -> position), ``ch_names`` (montage
        order) and the per-axis arrays ``x_coords``, ``y_coords`` and
        ``z_coords``, all aligned with ``ch_names``.
        """
        pos = self.montage.get_positions()
        self.coords = pos['ch_pos']
        self.ch_names = list(self.coords.keys())
        xyz = np.array(
            [self.coords[ch] for ch in self.ch_names], dtype=float
        ).reshape(-1, 3)
        self.x_coords = xyz[:, 0].copy()
        self.y_coords = xyz[:, 1].copy()
        self.z_coords = xyz[:, 2].copy()

    def create_visualization(self, features: Dict, map_type: str = "2D Topographic") -> "go.Figure":
        """Create a brain visualization of the requested type.

        Args:
            features: Feature dictionary (see the class docstring for keys).
            map_type: One of ``"2D Topographic"``, ``"3D Surface"`` or
                ``"Connectivity"``.

        Returns:
            The Plotly figure produced by the matching builder.

        Raises:
            ValueError: If ``map_type`` is not one of the supported names, or
                if a builder rejects the feature data.
        """
        builders = {
            "2D Topographic": self._create_topographic_map,
            "3D Surface": self._create_3d_surface,
            "Connectivity": self._create_connectivity_map,
        }
        try:
            builder = builders[map_type]
        except KeyError:
            raise ValueError(f"Unsupported map type: {map_type}") from None
        return builder(features)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _idw_weights(x, y, grid_x, grid_y, power=2.0):
        """Return row-normalised inverse-distance weights for a regular grid.

        Row ``r * len(grid_x) + c`` holds the weight of every electrode for
        grid node ``(grid_x[c], grid_y[r])``.  Multiplying the matrix by a
        per-electrode value vector and reshaping the result to
        ``(len(grid_y), len(grid_x))`` gives the interpolated map.  A node that
        coincides exactly with an electrode takes that electrode's value.

        Raises:
            ValueError: If ``x``/``y`` are empty or differ in length.
        """
        x = np.asarray(x, dtype=float).ravel()
        y = np.asarray(y, dtype=float).ravel()
        if x.size == 0 or x.size != y.size:
            raise ValueError("x and y must be non-empty and of equal length")
        mesh_x, mesh_y = np.meshgrid(grid_x, grid_y)
        # (n_grid_nodes, n_electrodes) Euclidean distances.
        dist = np.hypot(mesh_x.reshape(-1, 1) - x, mesh_y.reshape(-1, 1) - y)
        on_electrode = dist == 0
        weights = np.zeros_like(dist)
        np.divide(1.0, dist ** power, out=weights, where=~on_electrode)
        # Exact hits replace the distance weights so the map passes through
        # the measured value instead of dividing by zero.
        hit_rows = on_electrode.any(axis=1)
        weights[hit_rows] = on_electrode[hit_rows]
        return weights / weights.sum(axis=1, keepdims=True)

    def _interpolation_grid(self, resolution):
        """Return ``(grid_x, grid_y, weights)`` over the electrodes' x/y extent."""
        grid_x = np.linspace(self.x_coords.min(), self.x_coords.max(), resolution)
        grid_y = np.linspace(self.y_coords.min(), self.y_coords.max(), resolution)
        weights = self._idw_weights(self.x_coords, self.y_coords, grid_x, grid_y)
        return grid_x, grid_y, weights

    def _electrode_values(self, values, label):
        """Return ``values`` as a 1-D float array with one entry per electrode.

        Raises:
            ValueError: If the number of values differs from ``len(ch_names)``.
        """
        arr = np.asarray(values, dtype=float).ravel()
        if arr.size != len(self.ch_names):
            raise ValueError(
                f"{label} has {arr.size} values but the montage defines "
                f"{len(self.ch_names)} electrodes"
            )
        return arr

    @staticmethod
    def _band_visibility(band_index, n_bands):
        """Visibility mask showing only band ``band_index``'s two traces.

        Each band contributes a contour trace followed by an electrode scatter
        trace, so band ``j`` owns trace indices ``2j`` and ``2j + 1``.
        """
        return [trace // 2 == band_index for trace in range(2 * n_bands)]

    # ------------------------------------------------------------ map builders

    def _create_topographic_map(self, features: Dict) -> "go.Figure":
        """Create a 2D topographic map with a dropdown to switch bands.

        Each band's per-electrode powers are interpolated (inverse-distance
        weighting) onto a regular grid over the electrodes' x/y extent.  The
        ``'alpha'`` band is shown first when present, otherwise the first band.
        """
        band_powers = features['band_powers']
        band_names = list(band_powers)
        if 'alpha' in band_powers:
            default_band = 'alpha'
        else:
            default_band = band_names[0] if band_names else None

        # The grid and weights depend only on electrode positions, so build
        # them once and reuse them for every band.
        grid_x, grid_y, weights = self._interpolation_grid(self._GRID_RESOLUTION)

        fig = go.Figure()
        for band_name in band_names:
            powers = self._electrode_values(
                band_powers[band_name], f"band_powers[{band_name!r}]"
            )
            grid_z = (weights @ powers).reshape(grid_y.size, grid_x.size)
            shown = band_name == default_band
            fig.add_trace(go.Contour(
                x=grid_x,
                y=grid_y,
                z=grid_z,
                name=band_name,
                colorscale='Viridis',
                showscale=True,
                visible=shown,
            ))
            fig.add_trace(go.Scatter(
                x=self.x_coords,
                y=self.y_coords,
                mode='markers+text',
                text=self.ch_names,
                textposition="top center",
                name='Electrodes',
                marker=dict(size=10, color='black'),
                visible=shown,
            ))

        fig.update_layout(
            title="Brain Activity Topographic Map",
            xaxis_title="X Position",
            yaxis_title="Y Position",
            showlegend=True,
            updatemenus=[{
                'buttons': [
                    {'label': band,
                     'method': 'update',
                     'args': [{'visible': self._band_visibility(j, len(band_names))}]}
                    for j, band in enumerate(band_names)
                ],
                'direction': 'down',
                'showactive': True,
            }],
        )
        return fig

    def _create_3d_surface(self, features: Dict) -> "go.Figure":
        """Create a 3D surface of mean activity over the electrode layout.

        The per-electrode means are interpolated onto a regular x/y grid to
        form the surface.  The electrodes are drawn at their own mean values,
        so they sit on the surface and share its "Activity Level" z axis.
        """
        mean_activity = self._electrode_values(
            features['statistics']['mean'], "statistics['mean']"
        )
        grid_x, grid_y, weights = self._interpolation_grid(self._GRID_RESOLUTION)
        grid_z = (weights @ mean_activity).reshape(grid_y.size, grid_x.size)

        fig = go.Figure()
        fig.add_trace(go.Surface(
            x=grid_x,
            y=grid_y,
            z=grid_z,
            colorscale='Viridis',
            name='Brain Activity',
        ))
        fig.add_trace(go.Scatter3d(
            x=self.x_coords,
            y=self.y_coords,
            z=mean_activity,
            mode='markers+text',
            text=self.ch_names,
            marker=dict(size=5, color='red'),
            name='Electrodes',
        ))
        fig.update_layout(
            title="3D Brain Activity Surface",
            scene=dict(
                xaxis_title="X Position",
                yaxis_title="Y Position",
                zaxis_title="Activity Level",
                camera=dict(
                    up=dict(x=0, y=0, z=1),
                    center=dict(x=0, y=0, z=0),
                    eye=dict(x=1.5, y=1.5, z=1.5),
                ),
            ),
        )
        return fig

    def _create_connectivity_map(self, features: Dict) -> "go.Figure":
        """Create a spring-layout network of channel-to-channel connectivity.

        Every non-zero off-diagonal entry of the correlation matrix becomes an
        edge.  The diagonal is cleared on a private copy first, so channels do
        not get self-loops that would inflate their connection count.  Nodes
        are coloured by degree and labelled with montage channel names by
        index (NOTE(review): assumes the matrix follows montage order).
        """
        connectivity = np.array(features['connectivity']['correlation'], dtype=float)
        np.fill_diagonal(connectivity, 0.0)
        G = nx.from_numpy_array(connectivity)
        pos = nx.spring_layout(G, k=1, iterations=50)

        # None separates consecutive segments so a single trace draws them all.
        edge_x = []
        edge_y = []
        for u, v in G.edges():
            x0, y0 = pos[u]
            x1, y1 = pos[v]
            edge_x.extend([x0, x1, None])
            edge_y.extend([y0, y1, None])
        edge_trace = go.Scatter(
            x=edge_x, y=edge_y,
            line=dict(width=0.5, color='#888'),
            hoverinfo='none',
            mode='lines',
        )

        nodes = list(G.nodes())
        node_trace = go.Scatter(
            x=[pos[node][0] for node in nodes],
            y=[pos[node][1] for node in nodes],
            mode='markers+text',
            hoverinfo='text',
            text=self.ch_names[:len(nodes)],
            marker=dict(
                showscale=True,
                colorscale='YlOrRd',
                size=10,
                # Colour each node by its number of connections.
                color=[G.degree(node) for node in nodes],
                colorbar=dict(
                    thickness=15,
                    title=dict(text='Node Connections', side='right'),
                    xanchor='left',
                ),
            ),
        )

        return go.Figure(
            data=[edge_trace, node_trace],
            layout=go.Layout(
                title='Brain Connectivity Network',
                showlegend=False,
                hovermode='closest',
                margin=dict(b=20, l=5, r=5, t=40),
                xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            ),
        )