invincible-jha commited on
Commit
a2fd99a
1 Parent(s): a161463

Upload 5 files

Browse files
modules/brain_mapper.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import plotly.graph_objects as go
3
+ import mne
4
+ from typing import Dict, Optional, Tuple
5
+ import plotly.express as px
6
+ import networkx as nx
7
+
8
+ class BrainMapper:
9
+ def __init__(self):
10
+ self.montage = mne.channels.make_standard_montage('standard_1020')
11
+ self._initialize_coordinates()
12
+
13
+ def _initialize_coordinates(self):
14
+ """Initialize electrode coordinates from standard montage"""
15
+ pos = self.montage.get_positions()
16
+ self.coords = pos['ch_pos']
17
+
18
+ # Extract x, y, z coordinates
19
+ self.ch_names = list(self.coords.keys())
20
+ self.x_coords = np.array([self.coords[ch][0] for ch in self.ch_names])
21
+ self.y_coords = np.array([self.coords[ch][1] for ch in self.ch_names])
22
+ self.z_coords = np.array([self.coords[ch][2] for ch in self.ch_names])
23
+
24
+ def create_visualization(self, features: Dict, map_type: str = "2D Topographic") -> go.Figure:
25
+ """Create brain visualization based on the specified type"""
26
+ if map_type == "2D Topographic":
27
+ return self._create_topographic_map(features)
28
+ elif map_type == "3D Surface":
29
+ return self._create_3d_surface(features)
30
+ elif map_type == "Connectivity":
31
+ return self._create_connectivity_map(features)
32
+ else:
33
+ raise ValueError(f"Unsupported map type: {map_type}")
34
+
35
+ def _create_topographic_map(self, features: Dict) -> go.Figure:
36
+ """Create 2D topographic map of brain activity"""
37
+ # Extract band powers for visualization
38
+ band_powers = features['band_powers']
39
+
40
+ # Create figure with subplots for each frequency band
41
+ fig = go.Figure()
42
+
43
+ for band_name, powers in band_powers.items():
44
+ # Create interpolated grid
45
+ xi = np.linspace(min(self.x_coords), max(self.x_coords), 100)
46
+ yi = np.linspace(min(self.y_coords), max(self.y_coords), 100)
47
+ xi, yi = np.meshgrid(xi, yi)
48
+
49
+ # Add contour plot for each band
50
+ fig.add_trace(go.Contour(
51
+ x=xi[0],
52
+ y=yi[:, 0],
53
+ z=powers.reshape(xi.shape),
54
+ name=band_name,
55
+ colorscale='Viridis',
56
+ showscale=True,
57
+ visible=(band_name == 'alpha') # Show alpha band by default
58
+ ))
59
+
60
+ # Add scatter plot for electrode positions
61
+ fig.add_trace(go.Scatter(
62
+ x=self.x_coords,
63
+ y=self.y_coords,
64
+ mode='markers+text',
65
+ text=self.ch_names,
66
+ textposition="top center",
67
+ name='Electrodes',
68
+ marker=dict(size=10, color='black'),
69
+ visible=(band_name == 'alpha')
70
+ ))
71
+
72
+ # Update layout
73
+ fig.update_layout(
74
+ title="Brain Activity Topographic Map",
75
+ xaxis_title="X Position",
76
+ yaxis_title="Y Position",
77
+ showlegend=True,
78
+ updatemenus=[{
79
+ 'buttons': [
80
+ {'label': band,
81
+ 'method': 'update',
82
+ 'args': [{'visible': [i == j for i in range(len(band_powers)*2) for _ in range(2)]}]}
83
+ for j, band in enumerate(band_powers.keys())
84
+ ],
85
+ 'direction': 'down',
86
+ 'showactive': True,
87
+ }]
88
+ )
89
+
90
+ return fig
91
+
92
+ def _create_3d_surface(self, features: Dict) -> go.Figure:
93
+ """Create 3D surface plot of brain activity"""
94
+ # Create 3D surface using electrode positions
95
+ fig = go.Figure()
96
+
97
+ # Add surface plot
98
+ fig.add_trace(go.Surface(
99
+ x=self.x_coords.reshape(-1, 1),
100
+ y=self.y_coords.reshape(-1, 1),
101
+ z=features['statistics']['mean'].reshape(-1, 1),
102
+ colorscale='Viridis',
103
+ name='Brain Activity'
104
+ ))
105
+
106
+ # Add scatter plot for electrode positions
107
+ fig.add_trace(go.Scatter3d(
108
+ x=self.x_coords,
109
+ y=self.y_coords,
110
+ z=self.z_coords,
111
+ mode='markers+text',
112
+ text=self.ch_names,
113
+ marker=dict(size=5, color='red'),
114
+ name='Electrodes'
115
+ ))
116
+
117
+ # Update layout
118
+ fig.update_layout(
119
+ title="3D Brain Activity Surface",
120
+ scene=dict(
121
+ xaxis_title="X Position",
122
+ yaxis_title="Y Position",
123
+ zaxis_title="Activity Level",
124
+ camera=dict(
125
+ up=dict(x=0, y=0, z=1),
126
+ center=dict(x=0, y=0, z=0),
127
+ eye=dict(x=1.5, y=1.5, z=1.5)
128
+ )
129
+ )
130
+ )
131
+
132
+ return fig
133
+
134
+ def _create_connectivity_map(self, features: Dict) -> go.Figure:
135
+ """Create brain connectivity visualization"""
136
+ # Extract connectivity matrix
137
+ connectivity = features['connectivity']['correlation']
138
+
139
+ # Create graph
140
+ G = nx.from_numpy_array(connectivity)
141
+ pos = nx.spring_layout(G, k=1, iterations=50)
142
+
143
+ # Create edge trace
144
+ edge_x = []
145
+ edge_y = []
146
+ for edge in G.edges():
147
+ x0, y0 = pos[edge[0]]
148
+ x1, y1 = pos[edge[1]]
149
+ edge_x.extend([x0, x1, None])
150
+ edge_y.extend([y0, y1, None])
151
+
152
+ edge_trace = go.Scatter(
153
+ x=edge_x, y=edge_y,
154
+ line=dict(width=0.5, color='#888'),
155
+ hoverinfo='none',
156
+ mode='lines')
157
+
158
+ # Create node trace
159
+ node_x = []
160
+ node_y = []
161
+ for node in G.nodes():
162
+ x, y = pos[node]
163
+ node_x.append(x)
164
+ node_y.append(y)
165
+
166
+ node_trace = go.Scatter(
167
+ x=node_x, y=node_y,
168
+ mode='markers+text',
169
+ hoverinfo='text',
170
+ text=self.ch_names,
171
+ marker=dict(
172
+ showscale=True,
173
+ colorscale='YlOrRd',
174
+ size=10,
175
+ colorbar=dict(
176
+ thickness=15,
177
+ title='Node Connections',
178
+ xanchor='left',
179
+ titleside='right'
180
+ )
181
+ )
182
+ )
183
+
184
+ # Color node points by the number of connections
185
+ node_adjacencies = []
186
+ for node, adjacencies in enumerate(G.adjacency()):
187
+ node_adjacencies.append(len(adjacencies[1]))
188
+ node_trace.marker.color = node_adjacencies
189
+
190
+ # Create figure
191
+ fig = go.Figure(data=[edge_trace, node_trace],
192
+ layout=go.Layout(
193
+ title='Brain Connectivity Network',
194
+ showlegend=False,
195
+ hovermode='closest',
196
+ margin=dict(b=20,l=5,r=5,t=40),
197
+ xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
198
+ yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)
199
+ ))
200
+
201
+ return fig
modules/clinical_analyzer.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
3
+ from typing import Dict, List, Optional
4
+ import torch
5
+ import json
6
+
7
+ class ClinicalAnalyzer:
8
+ def __init__(self):
9
+ self.initialize_models()
10
+ self.condition_patterns = self._load_condition_patterns()
11
+
12
+ def initialize_models(self):
13
+ """Initialize transformer models for clinical analysis"""
14
+ try:
15
+ # Clinical text analysis model
16
+ self.clinical_tokenizer = AutoTokenizer.from_pretrained(
17
+ "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
18
+ )
19
+ self.clinical_model = AutoModelForSequenceClassification.from_pretrained(
20
+ "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
21
+ )
22
+
23
+ # Mental health assessment pipeline
24
+ self.mental_health_pipeline = pipeline(
25
+ "text-classification",
26
+ model="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
27
+ tokenizer="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
28
+ )
29
+ except Exception as e:
30
+ print(f"Error initializing models: {str(e)}")
31
+ # Fallback to rule-based analysis if models fail to load
32
+ self.clinical_model = None
33
+ self.clinical_tokenizer = None
34
+ self.mental_health_pipeline = None
35
+
36
+ def _load_condition_patterns(self) -> Dict:
37
+ """Load predefined patterns for mental health conditions"""
38
+ return {
39
+ 'depression': {
40
+ 'eeg_patterns': {
41
+ 'alpha_asymmetry': True,
42
+ 'theta_increase': True,
43
+ 'beta_decrease': True
44
+ },
45
+ 'keywords': [
46
+ 'depressed mood', 'loss of interest', 'fatigue',
47
+ 'sleep disturbance', 'concentration problems'
48
+ ]
49
+ },
50
+ 'anxiety': {
51
+ 'eeg_patterns': {
52
+ 'beta_increase': True,
53
+ 'alpha_decrease': True,
54
+ 'high_coherence': True
55
+ },
56
+ 'keywords': [
57
+ 'anxiety', 'worry', 'restlessness', 'tension',
58
+ 'panic', 'nervousness'
59
+ ]
60
+ },
61
+ 'ptsd': {
62
+ 'eeg_patterns': {
63
+ 'alpha_suppression': True,
64
+ 'theta_increase': True,
65
+ 'beta_asymmetry': True
66
+ },
67
+ 'keywords': [
68
+ 'trauma', 'flashbacks', 'nightmares', 'avoidance',
69
+ 'hypervigilance', 'startle response'
70
+ ]
71
+ }
72
+ }
73
+
74
+ def analyze(self, features: Dict, clinical_notes: str) -> Dict:
75
+ """Perform comprehensive clinical analysis"""
76
+ analysis_results = {
77
+ 'eeg_analysis': self._analyze_eeg_patterns(features),
78
+ 'text_analysis': self._analyze_clinical_text(clinical_notes),
79
+ 'condition_probabilities': self._calculate_condition_probabilities(
80
+ features, clinical_notes
81
+ ),
82
+ 'severity_assessment': self._assess_severity(features, clinical_notes),
83
+ 'recommendations': self._generate_recommendations(features, clinical_notes)
84
+ }
85
+ return analysis_results
86
+
87
+ def _analyze_eeg_patterns(self, features: Dict) -> Dict:
88
+ """Analyze EEG patterns for clinical significance"""
89
+ eeg_analysis = {}
90
+
91
+ # Analyze band power distributions
92
+ band_powers = features['band_powers']
93
+ eeg_analysis['band_power_analysis'] = {
94
+ band: {
95
+ 'mean': float(np.mean(powers)),
96
+ 'std': float(np.std(powers)),
97
+ 'clinical_significance': self._assess_band_significance(band, powers)
98
+ }
99
+ for band, powers in band_powers.items()
100
+ }
101
+
102
+ # Analyze connectivity patterns
103
+ connectivity = features['connectivity']
104
+ eeg_analysis['connectivity_analysis'] = {
105
+ 'global_connectivity': float(np.mean(connectivity['correlation'])),
106
+ 'asymmetry_index': self._calculate_asymmetry_index(features)
107
+ }
108
+
109
+ return eeg_analysis
110
+
111
+ def _analyze_clinical_text(self, clinical_notes: str) -> Dict:
112
+ """Analyze clinical notes using NLP"""
113
+ if not clinical_notes:
114
+ return {'error': 'No clinical notes provided'}
115
+
116
+ try:
117
+ if self.mental_health_pipeline:
118
+ # Use transformer model for analysis
119
+ results = self.mental_health_pipeline(clinical_notes)
120
+ text_analysis = {
121
+ 'sentiment': results[0]['label'],
122
+ 'confidence': float(results[0]['score'])
123
+ }
124
+ else:
125
+ # Fallback to keyword-based analysis
126
+ text_analysis = self._keyword_based_analysis(clinical_notes)
127
+
128
+ # Extract symptoms and severity
129
+ text_analysis['identified_symptoms'] = self._extract_symptoms(clinical_notes)
130
+ text_analysis['risk_factors'] = self._identify_risk_factors(clinical_notes)
131
+
132
+ return text_analysis
133
+
134
+ except Exception as e:
135
+ return {'error': f'Text analysis failed: {str(e)}'}
136
+
137
+ def _calculate_condition_probabilities(
138
+ self, features: Dict, clinical_notes: str
139
+ ) -> Dict:
140
+ """Calculate probabilities for different mental health conditions"""
141
+ probabilities = {}
142
+
143
+ for condition, patterns in self.condition_patterns.items():
144
+ # Calculate EEG pattern match score
145
+ eeg_score = self._calculate_eeg_pattern_match(
146
+ features, patterns['eeg_patterns']
147
+ )
148
+
149
+ # Calculate text pattern match score
150
+ text_score = self._calculate_text_pattern_match(
151
+ clinical_notes, patterns['keywords']
152
+ )
153
+
154
+ # Combine scores with weighted average
155
+ combined_score = 0.6 * eeg_score + 0.4 * text_score
156
+ probabilities[condition] = float(combined_score)
157
+
158
+ return probabilities
159
+
160
+ def _assess_severity(self, features: Dict, clinical_notes: str) -> Dict:
161
+ """Assess the severity of identified conditions"""
162
+ severity = {
163
+ 'overall_severity': self._calculate_overall_severity(features, clinical_notes),
164
+ 'domain_severity': {
165
+ 'cognitive': self._assess_cognitive_severity(features),
166
+ 'emotional': self._assess_emotional_severity(features, clinical_notes),
167
+ 'behavioral': self._assess_behavioral_severity(clinical_notes)
168
+ },
169
+ 'risk_level': self._assess_risk_level(features, clinical_notes)
170
+ }
171
+ return severity
172
+
173
+ def _generate_recommendations(self, features: Dict, clinical_notes: str) -> List[str]:
174
+ """Generate clinical recommendations based on analysis"""
175
+ recommendations = []
176
+
177
+ # Analyze severity and conditions
178
+ severity = self._assess_severity(features, clinical_notes)
179
+ conditions = self._calculate_condition_probabilities(features, clinical_notes)
180
+
181
+ # Generate general recommendations
182
+ if severity['overall_severity'] > 0.7:
183
+ recommendations.append("Immediate clinical intervention recommended")
184
+ elif severity['overall_severity'] > 0.4:
185
+ recommendations.append("Regular clinical monitoring recommended")
186
+
187
+ # Condition-specific recommendations
188
+ for condition, probability in conditions.items():
189
+ if probability > 0.6:
190
+ recommendations.extend(
191
+ self._get_condition_specific_recommendations(condition)
192
+ )
193
+
194
+ return recommendations
195
+
196
+ def _calculate_eeg_pattern_match(self, features: Dict, patterns: Dict) -> float:
197
+ """Calculate how well EEG features match condition patterns"""
198
+ match_scores = []
199
+
200
+ for pattern, expected in patterns.items():
201
+ if pattern == 'alpha_asymmetry':
202
+ score = self._check_alpha_asymmetry(features)
203
+ elif pattern == 'beta_increase':
204
+ score = self._check_beta_increase(features)
205
+ elif pattern == 'theta_increase':
206
+ score = self._check_theta_increase(features)
207
+ else:
208
+ score = 0.5 # Default score for unknown patterns
209
+
210
+ match_scores.append(score if expected else 1 - score)
211
+
212
+ return np.mean(match_scores) if match_scores else 0.0
213
+
214
+ def _calculate_text_pattern_match(self, text: str, keywords: List[str]) -> float:
215
+ """Calculate how well clinical notes match condition keywords"""
216
+ if not text:
217
+ return 0.0
218
+
219
+ text = text.lower()
220
+ matched_keywords = sum(1 for keyword in keywords if keyword.lower() in text)
221
+ return matched_keywords / len(keywords)
222
+
223
+ def _calculate_asymmetry_index(self, features: Dict) -> float:
224
+ """Calculate brain asymmetry index from EEG features"""
225
+ try:
226
+ # Calculate alpha asymmetry between left and right hemispheres
227
+ alpha_powers = features['band_powers']['alpha']
228
+ left_channels = alpha_powers[:len(alpha_powers)//2]
229
+ right_channels = alpha_powers[len(alpha_powers)//2:]
230
+
231
+ asymmetry = np.log(np.mean(right_channels)) - np.log(np.mean(left_channels))
232
+ return float(asymmetry)
233
+ except:
234
+ return 0.0
235
+
236
+ def _assess_band_significance(self, band: str, powers: np.ndarray) -> str:
237
+ """Assess clinical significance of frequency band power"""
238
+ mean_power = np.mean(powers)
239
+ if band == 'alpha':
240
+ if mean_power < 0.3:
241
+ return "Significantly reduced alpha power"
242
+ elif mean_power > 0.7:
243
+ return "Elevated alpha power"
244
+ elif band == 'beta':
245
+ if mean_power > 0.7:
246
+ return "Elevated beta power - possible anxiety"
247
+ elif band == 'theta':
248
+ if mean_power > 0.6:
249
+ return "Elevated theta power - possible cognitive issues"
250
+
251
+ return "Within normal range"
252
+
253
+ def _get_condition_specific_recommendations(self, condition: str) -> List[str]:
254
+ """Get specific recommendations for identified conditions"""
255
+ recommendations = {
256
+ 'depression': [
257
+ "Consider cognitive behavioral therapy",
258
+ "Evaluate need for antidepressant medication",
259
+ "Recommend regular physical activity",
260
+ "Implement sleep hygiene practices"
261
+ ],
262
+ 'anxiety': [
263
+ "Consider anxiety-focused psychotherapy",
264
+ "Evaluate need for anti-anxiety medication",
265
+ "Recommend relaxation techniques",
266
+ "Practice mindfulness meditation"
267
+ ],
268
+ 'ptsd': [
269
+ "Consider trauma-focused therapy",
270
+ "Evaluate need for PTSD-specific medication",
271
+ "Implement grounding techniques",
272
+ "Develop safety and coping plans"
273
+ ]
274
+ }
275
+ return recommendations.get(condition, [])
modules/eeg_processor.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import mne
2
+ import numpy as np
3
+ from scipy import signal
4
+ from typing import Dict, List, Tuple, Optional
5
+
6
+ class EEGProcessor:
7
+ def __init__(self):
8
+ self.sfreq = 250 # Default sampling frequency
9
+ self.freq_bands = {
10
+ 'delta': (0.5, 4),
11
+ 'theta': (4, 8),
12
+ 'alpha': (8, 13),
13
+ 'beta': (13, 30),
14
+ 'gamma': (30, 50)
15
+ }
16
+
17
+ def preprocess(self, raw: mne.io.Raw) -> mne.io.Raw:
18
+ """Preprocess raw EEG data"""
19
+ # Set montage if not present
20
+ if raw.get_montage() is None:
21
+ raw.set_montage('standard_1020')
22
+
23
+ # Basic preprocessing pipeline
24
+ raw_processed = raw.copy()
25
+
26
+ # Filter data
27
+ raw_processed.filter(l_freq=0.5, h_freq=50.0)
28
+ raw_processed.notch_filter(freqs=50) # Remove power line noise
29
+
30
+ # Detect and interpolate bad channels
31
+ raw_processed.interpolate_bads()
32
+
33
+ # Apply ICA for artifact removal
34
+ ica = mne.preprocessing.ICA(n_components=0.95, random_state=42)
35
+ ica.fit(raw_processed)
36
+
37
+ # Detect and remove eye blinks
38
+ eog_indices, eog_scores = ica.find_bads_eog(raw_processed)
39
+ if eog_indices:
40
+ ica.exclude = eog_indices
41
+ ica.apply(raw_processed)
42
+
43
+ return raw_processed
44
+
45
+ def extract_features(self, raw: mne.io.Raw) -> Dict:
46
+ """Extract relevant features from preprocessed EEG data"""
47
+ features = {}
48
+
49
+ # Get data and times
50
+ data, times = raw.get_data(return_times=True)
51
+
52
+ # Calculate power spectral density
53
+ psds, freqs = mne.time_frequency.psd_welch(
54
+ raw,
55
+ fmin=0.5,
56
+ fmax=50.0,
57
+ n_fft=int(raw.info['sfreq'] * 4),
58
+ n_overlap=int(raw.info['sfreq'] * 2)
59
+ )
60
+
61
+ # Extract band powers
62
+ features['band_powers'] = self._calculate_band_powers(psds, freqs)
63
+
64
+ # Calculate connectivity metrics
65
+ features['connectivity'] = self._calculate_connectivity(data)
66
+
67
+ # Extract statistical features
68
+ features['statistics'] = self._calculate_statistics(data)
69
+
70
+ return features
71
+
72
+ def _calculate_band_powers(self, psds: np.ndarray, freqs: np.ndarray) -> Dict:
73
+ """Calculate power in different frequency bands"""
74
+ band_powers = {}
75
+
76
+ for band_name, (fmin, fmax) in self.freq_bands.items():
77
+ # Find frequencies that fall within band
78
+ freq_mask = (freqs >= fmin) & (freqs <= fmax)
79
+
80
+ # Calculate average power in band
81
+ band_power = np.mean(psds[:, freq_mask], axis=1)
82
+ band_powers[band_name] = band_power
83
+
84
+ return band_powers
85
+
86
+ def _calculate_connectivity(self, data: np.ndarray) -> Dict:
87
+ """Calculate connectivity metrics between channels"""
88
+ n_channels = data.shape[0]
89
+ connectivity = {
90
+ 'correlation': np.corrcoef(data),
91
+ 'coherence': np.zeros((n_channels, n_channels))
92
+ }
93
+
94
+ # Calculate coherence between all channel pairs
95
+ for i in range(n_channels):
96
+ for j in range(i + 1, n_channels):
97
+ f, coh = signal.coherence(data[i], data[j], fs=self.sfreq)
98
+ connectivity['coherence'][i, j] = np.mean(coh)
99
+ connectivity['coherence'][j, i] = connectivity['coherence'][i, j]
100
+
101
+ return connectivity
102
+
103
+ def _calculate_statistics(self, data: np.ndarray) -> Dict:
104
+ """Calculate statistical features for each channel"""
105
+ stats = {
106
+ 'mean': np.mean(data, axis=1),
107
+ 'std': np.std(data, axis=1),
108
+ 'skewness': self._calculate_skewness(data),
109
+ 'kurtosis': self._calculate_kurtosis(data),
110
+ 'hjorth': self._calculate_hjorth_parameters(data)
111
+ }
112
+ return stats
113
+
114
+ def _calculate_skewness(self, data: np.ndarray) -> np.ndarray:
115
+ """Calculate skewness for each channel"""
116
+ return np.array([signal.skew(channel) for channel in data])
117
+
118
+ def _calculate_kurtosis(self, data: np.ndarray) -> np.ndarray:
119
+ """Calculate kurtosis for each channel"""
120
+ return np.array([signal.kurtosis(channel) for channel in data])
121
+
122
+ def _calculate_hjorth_parameters(self, data: np.ndarray) -> Dict:
123
+ """Calculate Hjorth parameters (activity, mobility, complexity)"""
124
+ activity = np.var(data, axis=1)
125
+
126
+ # First derivative variance
127
+ diff1 = np.diff(data, axis=1)
128
+ mobility = np.sqrt(np.var(diff1, axis=1) / activity)
129
+
130
+ # Second derivative variance
131
+ diff2 = np.diff(diff1, axis=1)
132
+ complexity = np.sqrt(np.var(diff2, axis=1) / np.var(diff1, axis=1)) / mobility
133
+
134
+ return {
135
+ 'activity': activity,
136
+ 'mobility': mobility,
137
+ 'complexity': complexity
138
+ }
modules/treatment_planner.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from typing import Dict, List, Optional
3
+ import json
4
+ from datetime import datetime, timedelta
5
+
6
+ class TreatmentPlanner:
7
+ def __init__(self):
8
+ self.treatment_protocols = self._load_treatment_protocols()
9
+ self.intervention_strategies = self._load_intervention_strategies()
10
+
11
+ def _load_treatment_protocols(self) -> Dict:
12
+ """Load predefined treatment protocols"""
13
+ return {
14
+ 'depression': {
15
+ 'psychotherapy': {
16
+ 'primary': 'Cognitive Behavioral Therapy (CBT)',
17
+ 'alternatives': [
18
+ 'Interpersonal Therapy (IPT)',
19
+ 'Behavioral Activation (BA)',
20
+ 'Mindfulness-Based Cognitive Therapy (MBCT)'
21
+ ],
22
+ 'duration': '12-16 weeks',
23
+ 'frequency': 'Weekly'
24
+ },
25
+ 'medication': {
26
+ 'classes': [
27
+ 'SSRIs',
28
+ 'SNRIs',
29
+ 'NDRIs',
30
+ 'Atypical antidepressants'
31
+ ],
32
+ 'duration': '6-12 months minimum',
33
+ 'monitoring': 'Every 2-4 weeks initially'
34
+ },
35
+ 'lifestyle': [
36
+ 'Regular exercise (30 minutes daily)',
37
+ 'Sleep hygiene improvement',
38
+ 'Social engagement activities',
39
+ 'Stress reduction techniques'
40
+ ]
41
+ },
42
+ 'anxiety': {
43
+ 'psychotherapy': {
44
+ 'primary': 'Cognitive Behavioral Therapy (CBT)',
45
+ 'alternatives': [
46
+ 'Exposure Therapy',
47
+ 'Acceptance and Commitment Therapy (ACT)',
48
+ 'Dialectical Behavior Therapy (DBT)'
49
+ ],
50
+ 'duration': '8-12 weeks',
51
+ 'frequency': 'Weekly'
52
+ },
53
+ 'medication': {
54
+ 'classes': [
55
+ 'SSRIs',
56
+ 'SNRIs',
57
+ 'Buspirone',
58
+ 'Beta-blockers'
59
+ ],
60
+ 'duration': 'As needed',
61
+ 'monitoring': 'Every 2-4 weeks initially'
62
+ },
63
+ 'lifestyle': [
64
+ 'Relaxation techniques',
65
+ 'Mindfulness meditation',
66
+ 'Regular exercise',
67
+ 'Stress management'
68
+ ]
69
+ },
70
+ 'ptsd': {
71
+ 'psychotherapy': {
72
+ 'primary': 'Trauma-Focused CBT',
73
+ 'alternatives': [
74
+ 'EMDR',
75
+ 'Prolonged Exposure Therapy',
76
+ 'Cognitive Processing Therapy'
77
+ ],
78
+ 'duration': '12-16 weeks',
79
+ 'frequency': 'Weekly'
80
+ },
81
+ 'medication': {
82
+ 'classes': [
83
+ 'SSRIs',
84
+ 'SNRIs',
85
+ 'Prazosin',
86
+ 'Antipsychotics'
87
+ ],
88
+ 'duration': '12 months minimum',
89
+ 'monitoring': 'Every 2-4 weeks initially'
90
+ },
91
+ 'lifestyle': [
92
+ 'Stress management techniques',
93
+ 'Sleep hygiene',
94
+ 'Grounding exercises',
95
+ 'Social support engagement'
96
+ ]
97
+ }
98
+ }
99
+
100
+ def _load_intervention_strategies(self) -> Dict:
101
+ """Load intervention strategies based on severity and symptoms"""
102
+ return {
103
+ 'mild': {
104
+ 'focus': 'Lifestyle modifications and psychoeducation',
105
+ 'monitoring': 'Monthly check-ins',
106
+ 'escalation_criteria': [
107
+ 'Symptom worsening',
108
+ 'Functional impairment',
109
+ 'Lack of improvement after 4-6 weeks'
110
+ ]
111
+ },
112
+ 'moderate': {
113
+ 'focus': 'Psychotherapy with optional medication',
114
+ 'monitoring': 'Bi-weekly check-ins',
115
+ 'escalation_criteria': [
116
+ 'Severe symptom exacerbation',
117
+ 'Development of risk factors',
118
+ 'Poor response to treatment'
119
+ ]
120
+ },
121
+ 'severe': {
122
+ 'focus': 'Intensive treatment with medication and therapy',
123
+ 'monitoring': 'Weekly check-ins',
124
+ 'escalation_criteria': [
125
+ 'Crisis development',
126
+ 'Safety concerns',
127
+ 'Treatment resistance'
128
+ ]
129
+ }
130
+ }
131
+
132
+ def generate_plan(self, analysis_results: Dict) -> Dict:
133
+ """Generate a comprehensive treatment plan based on analysis results"""
134
+ try:
135
+ # Extract relevant information from analysis
136
+ conditions = analysis_results['condition_probabilities']
137
+ severity = analysis_results['severity_assessment']
138
+
139
+ # Generate treatment plan components
140
+ treatment_plan = {
141
+ 'summary': self._generate_plan_summary(conditions, severity),
142
+ 'primary_interventions': self._select_primary_interventions(
143
+ conditions, severity
144
+ ),
145
+ 'therapy_recommendations': self._generate_therapy_recommendations(
146
+ conditions, severity
147
+ ),
148
+ 'medication_considerations': self._generate_medication_recommendations(
149
+ conditions, severity
150
+ ),
151
+ 'lifestyle_modifications': self._generate_lifestyle_recommendations(
152
+ conditions
153
+ ),
154
+ 'monitoring_plan': self._create_monitoring_plan(severity),
155
+ 'crisis_plan': self._create_crisis_plan(severity),
156
+ 'timeline': self._create_treatment_timeline(
157
+ conditions, severity
158
+ )
159
+ }
160
+
161
+ return treatment_plan
162
+
163
+ except Exception as e:
164
+ return {
165
+ 'error': f'Error generating treatment plan: {str(e)}',
166
+ 'recommendations': self._generate_fallback_recommendations()
167
+ }
168
+
169
+ def _generate_plan_summary(
170
+ self, conditions: Dict, severity: Dict
171
+ ) -> Dict:
172
+ """Generate a summary of the treatment plan"""
173
+ primary_condition = max(conditions.items(), key=lambda x: x[1])[0]
174
+ severity_level = self._determine_severity_level(severity['overall_severity'])
175
+
176
+ return {
177
+ 'primary_condition': primary_condition,
178
+ 'severity_level': severity_level,
179
+ 'treatment_approach': self.intervention_strategies[severity_level]['focus'],
180
+ 'estimated_duration': self._estimate_treatment_duration(
181
+ primary_condition, severity_level
182
+ )
183
+ }
184
+
185
+ def _select_primary_interventions(
186
+ self, conditions: Dict, severity: Dict
187
+ ) -> List[str]:
188
+ """Select primary interventions based on conditions and severity"""
189
+ interventions = []
190
+ severity_level = self._determine_severity_level(severity['overall_severity'])
191
+
192
+ for condition, probability in conditions.items():
193
+ if probability > 0.4: # Consider conditions with significant probability
194
+ protocol = self.treatment_protocols[condition]
195
+
196
+ # Add primary therapy
197
+ interventions.append(
198
+ f"Primary therapy: {protocol['psychotherapy']['primary']}"
199
+ )
200
+
201
+ # Add medication if moderate to severe
202
+ if severity_level in ['moderate', 'severe']:
203
+ interventions.append(
204
+ f"Consider medication: {', '.join(protocol['medication']['classes'][:2])}"
205
+ )
206
+
207
+ return interventions
208
+
209
+ def _generate_therapy_recommendations(
210
+ self, conditions: Dict, severity: Dict
211
+ ) -> Dict:
212
+ """Generate specific therapy recommendations"""
213
+ therapy_plan = {}
214
+ severity_level = self._determine_severity_level(severity['overall_severity'])
215
+
216
+ for condition, probability in conditions.items():
217
+ if probability > 0.3:
218
+ protocol = self.treatment_protocols[condition]['psychotherapy']
219
+ therapy_plan[condition] = {
220
+ 'primary_therapy': protocol['primary'],
221
+ 'alternatives': protocol['alternatives'][:2],
222
+ 'frequency': protocol['frequency'],
223
+ 'duration': protocol['duration']
224
+ }
225
+
226
+ return therapy_plan
227
+
228
+ def _generate_medication_recommendations(
229
+ self, conditions: Dict, severity: Dict
230
+ ) -> Dict:
231
+ """Generate medication recommendations"""
232
+ medication_plan = {}
233
+ severity_level = self._determine_severity_level(severity['overall_severity'])
234
+
235
+ if severity_level in ['moderate', 'severe']:
236
+ for condition, probability in conditions.items():
237
+ if probability > 0.4:
238
+ protocol = self.treatment_protocols[condition]['medication']
239
+ medication_plan[condition] = {
240
+ 'recommended_classes': protocol['classes'][:2],
241
+ 'duration': protocol['duration'],
242
+ 'monitoring': protocol['monitoring']
243
+ }
244
+
245
+ return medication_plan
246
+
247
+ def _generate_lifestyle_recommendations(self, conditions: Dict) -> List[str]:
248
+ """Generate lifestyle modification recommendations"""
249
+ recommendations = set()
250
+
251
+ for condition, probability in conditions.items():
252
+ if probability > 0.3:
253
+ recommendations.update(
254
+ self.treatment_protocols[condition]['lifestyle']
255
+ )
256
+
257
+ return list(recommendations)
258
+
259
+ def _create_monitoring_plan(self, severity: Dict) -> Dict:
260
+ """Create a monitoring plan based on severity"""
261
+ severity_level = self._determine_severity_level(severity['overall_severity'])
262
+ strategy = self.intervention_strategies[severity_level]
263
+
264
+ return {
265
+ 'frequency': strategy['monitoring'],
266
+ 'focus_areas': [
267
+ 'Symptom severity',
268
+ 'Treatment response',
269
+ 'Side effects',
270
+ 'Functional improvement'
271
+ ],
272
+ 'escalation_criteria': strategy['escalation_criteria']
273
+ }
274
+
275
+ def _create_crisis_plan(self, severity: Dict) -> Dict:
276
+ """Create a crisis intervention plan"""
277
+ return {
278
+ 'warning_signs': [
279
+ 'Suicidal ideation',
280
+ 'Severe anxiety attacks',
281
+ 'Dissociative episodes',
282
+ 'Severe mood changes'
283
+ ],
284
+ 'emergency_contacts': [
285
+ 'Primary therapist',
286
+ 'Crisis hotline',
287
+ 'Emergency services (911)',
288
+ 'Trusted support person'
289
+ ],
290
+ 'immediate_actions': [
291
+ 'Contact emergency services if in immediate danger',
292
+ 'Use prescribed crisis medication if available',
293
+ 'Apply learned coping strategies',
294
+ 'Reach out to support system'
295
+ ]
296
+ }
297
+
298
+ def _create_treatment_timeline(
299
+ self, conditions: Dict, severity: Dict
300
+ ) -> List[Dict]:
301
+ """Create a timeline for treatment implementation"""
302
+ timeline = []
303
+ start_date = datetime.now()
304
+
305
+ # Initial phase
306
+ timeline.append({
307
+ 'phase': 'Initial Assessment and Stabilization',
308
+ 'duration': '1-2 weeks',
309
+ 'start_date': start_date.strftime('%Y-%m-%d'),
310
+ 'focus': 'Assessment and immediate interventions'
311
+ })
312
+
313
+ # Acute phase
314
+ acute_start = start_date + timedelta(weeks=2)
315
+ timeline.append({
316
+ 'phase': 'Acute Treatment',
317
+ 'duration': '8-12 weeks',
318
+ 'start_date': acute_start.strftime('%Y-%m-%d'),
319
+ 'focus': 'Primary interventions and symptom reduction'
320
+ })
321
+
322
+ # Continuation phase
323
+ continuation_start = acute_start + timedelta(weeks=12)
324
+ timeline.append({
325
+ 'phase': 'Continuation',
326
+ 'duration': '4-6 months',
327
+ 'start_date': continuation_start.strftime('%Y-%m-%d'),
328
+ 'focus': 'Maintaining improvements and preventing relapse'
329
+ })
330
+
331
+ return timeline
332
+
333
+ def _determine_severity_level(self, severity_score: float) -> str:
334
+ """Determine severity level from score"""
335
+ if severity_score > 0.7:
336
+ return 'severe'
337
+ elif severity_score > 0.4:
338
+ return 'moderate'
339
+ else:
340
+ return 'mild'
341
+
342
+ def _estimate_treatment_duration(
343
+ self, condition: str, severity_level: str
344
+ ) -> str:
345
+ """Estimate treatment duration based on condition and severity"""
346
+ base_duration = {
347
+ 'mild': 3,
348
+ 'moderate': 6,
349
+ 'severe': 12
350
+ }
351
+
352
+ months = base_duration[severity_level]
353
+ return f"{months}-{months+3} months"
354
+
355
+ def _generate_fallback_recommendations(self) -> List[str]:
356
+ """Generate basic recommendations when full plan generation fails"""
357
+ return [
358
+ "Seek professional mental health evaluation",
359
+ "Consider psychotherapy",
360
+ "Maintain regular sleep schedule",
361
+ "Practice stress management techniques",
362
+ "Engage in regular physical activity",
363
+ "Build and maintain social support network"
364
+ ]