|
import numpy as np |
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline |
|
from typing import Dict, List, Optional |
|
import torch |
|
import json |
|
|
|
class ClinicalAnalyzer: |
|
def __init__(self): |
|
self.initialize_models() |
|
self.condition_patterns = self._load_condition_patterns() |
|
|
|
def initialize_models(
    self,
    model_name: str = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
):
    """Load the tokenizer, classification model and text-classification pipeline.

    The checkpoint is loaded once. The same model and tokenizer objects are
    then handed to ``pipeline``, so a second copy is not downloaded and
    instantiated from the hub name.

    On any failure, all three attributes (``clinical_tokenizer``,
    ``clinical_model``, ``mental_health_pipeline``) are left as ``None``.
    Callers can then fall back to non-model analysis.

    Args:
        model_name: Hugging Face model id or local path. Defaults to the
            PubMedBERT checkpoint that was previously hard-coded.

    NOTE(review): nothing in this file fine-tunes the checkpoint. If it is a
    plain pretrained encoder, the sequence-classification head is freshly
    initialised and its labels/scores carry no clinical meaning. Confirm
    before relying on them.
    """
    import logging

    # Define the attributes up front so they exist whatever happens below.
    self.clinical_tokenizer = None
    self.clinical_model = None
    self.mental_health_pipeline = None

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        # Reuse the objects just loaded rather than passing the hub name again.
        classifier = pipeline(
            "text-classification",
            model=model,
            tokenizer=tokenizer,
        )
    except Exception as e:
        logging.getLogger(__name__).error("Error initializing models: %s", e)
        return

    self.clinical_tokenizer = tokenizer
    self.clinical_model = model
    self.mental_health_pipeline = classifier
|
|
|
def _load_condition_patterns(self) -> Dict: |
|
"""Load predefined patterns for mental health conditions""" |
|
return { |
|
'depression': { |
|
'eeg_patterns': { |
|
'alpha_asymmetry': True, |
|
'theta_increase': True, |
|
'beta_decrease': True |
|
}, |
|
'keywords': [ |
|
'depressed mood', 'loss of interest', 'fatigue', |
|
'sleep disturbance', 'concentration problems' |
|
] |
|
}, |
|
'anxiety': { |
|
'eeg_patterns': { |
|
'beta_increase': True, |
|
'alpha_decrease': True, |
|
'high_coherence': True |
|
}, |
|
'keywords': [ |
|
'anxiety', 'worry', 'restlessness', 'tension', |
|
'panic', 'nervousness' |
|
] |
|
}, |
|
'ptsd': { |
|
'eeg_patterns': { |
|
'alpha_suppression': True, |
|
'theta_increase': True, |
|
'beta_asymmetry': True |
|
}, |
|
'keywords': [ |
|
'trauma', 'flashbacks', 'nightmares', 'avoidance', |
|
'hypervigilance', 'startle response' |
|
] |
|
} |
|
} |
|
|
|
def analyze(self, features: Dict, clinical_notes: str) -> Dict: |
|
"""Perform comprehensive clinical analysis""" |
|
analysis_results = { |
|
'eeg_analysis': self._analyze_eeg_patterns(features), |
|
'text_analysis': self._analyze_clinical_text(clinical_notes), |
|
'condition_probabilities': self._calculate_condition_probabilities( |
|
features, clinical_notes |
|
), |
|
'severity_assessment': self._assess_severity(features, clinical_notes), |
|
'recommendations': self._generate_recommendations(features, clinical_notes) |
|
} |
|
return analysis_results |
|
|
|
def _analyze_eeg_patterns(self, features: Dict) -> Dict: |
|
"""Analyze EEG patterns for clinical significance""" |
|
eeg_analysis = {} |
|
|
|
|
|
band_powers = features['band_powers'] |
|
eeg_analysis['band_power_analysis'] = { |
|
band: { |
|
'mean': float(np.mean(powers)), |
|
'std': float(np.std(powers)), |
|
'clinical_significance': self._assess_band_significance(band, powers) |
|
} |
|
for band, powers in band_powers.items() |
|
} |
|
|
|
|
|
connectivity = features['connectivity'] |
|
eeg_analysis['connectivity_analysis'] = { |
|
'global_connectivity': float(np.mean(connectivity['correlation'])), |
|
'asymmetry_index': self._calculate_asymmetry_index(features) |
|
} |
|
|
|
return eeg_analysis |
|
|
|
def _analyze_clinical_text(self, clinical_notes: str) -> Dict: |
|
"""Analyze clinical notes using NLP""" |
|
if not clinical_notes: |
|
return {'error': 'No clinical notes provided'} |
|
|
|
try: |
|
if self.mental_health_pipeline: |
|
|
|
results = self.mental_health_pipeline(clinical_notes) |
|
text_analysis = { |
|
'sentiment': results[0]['label'], |
|
'confidence': float(results[0]['score']) |
|
} |
|
else: |
|
|
|
text_analysis = self._keyword_based_analysis(clinical_notes) |
|
|
|
|
|
text_analysis['identified_symptoms'] = self._extract_symptoms(clinical_notes) |
|
text_analysis['risk_factors'] = self._identify_risk_factors(clinical_notes) |
|
|
|
return text_analysis |
|
|
|
except Exception as e: |
|
return {'error': f'Text analysis failed: {str(e)}'} |
|
|
|
def _calculate_condition_probabilities( |
|
self, features: Dict, clinical_notes: str |
|
) -> Dict: |
|
"""Calculate probabilities for different mental health conditions""" |
|
probabilities = {} |
|
|
|
for condition, patterns in self.condition_patterns.items(): |
|
|
|
eeg_score = self._calculate_eeg_pattern_match( |
|
features, patterns['eeg_patterns'] |
|
) |
|
|
|
|
|
text_score = self._calculate_text_pattern_match( |
|
clinical_notes, patterns['keywords'] |
|
) |
|
|
|
|
|
combined_score = 0.6 * eeg_score + 0.4 * text_score |
|
probabilities[condition] = float(combined_score) |
|
|
|
return probabilities |
|
|
|
def _assess_severity(self, features: Dict, clinical_notes: str) -> Dict: |
|
"""Assess the severity of identified conditions""" |
|
severity = { |
|
'overall_severity': self._calculate_overall_severity(features, clinical_notes), |
|
'domain_severity': { |
|
'cognitive': self._assess_cognitive_severity(features), |
|
'emotional': self._assess_emotional_severity(features, clinical_notes), |
|
'behavioral': self._assess_behavioral_severity(clinical_notes) |
|
}, |
|
'risk_level': self._assess_risk_level(features, clinical_notes) |
|
} |
|
return severity |
|
|
|
def _generate_recommendations(self, features: Dict, clinical_notes: str) -> List[str]: |
|
"""Generate clinical recommendations based on analysis""" |
|
recommendations = [] |
|
|
|
|
|
severity = self._assess_severity(features, clinical_notes) |
|
conditions = self._calculate_condition_probabilities(features, clinical_notes) |
|
|
|
|
|
if severity['overall_severity'] > 0.7: |
|
recommendations.append("Immediate clinical intervention recommended") |
|
elif severity['overall_severity'] > 0.4: |
|
recommendations.append("Regular clinical monitoring recommended") |
|
|
|
|
|
for condition, probability in conditions.items(): |
|
if probability > 0.6: |
|
recommendations.extend( |
|
self._get_condition_specific_recommendations(condition) |
|
) |
|
|
|
return recommendations |
|
|
|
def _calculate_eeg_pattern_match(self, features: Dict, patterns: Dict) -> float: |
|
"""Calculate how well EEG features match condition patterns""" |
|
match_scores = [] |
|
|
|
for pattern, expected in patterns.items(): |
|
if pattern == 'alpha_asymmetry': |
|
score = self._check_alpha_asymmetry(features) |
|
elif pattern == 'beta_increase': |
|
score = self._check_beta_increase(features) |
|
elif pattern == 'theta_increase': |
|
score = self._check_theta_increase(features) |
|
else: |
|
score = 0.5 |
|
|
|
match_scores.append(score if expected else 1 - score) |
|
|
|
return np.mean(match_scores) if match_scores else 0.0 |
|
|
|
def _calculate_text_pattern_match(self, text: str, keywords: List[str]) -> float: |
|
"""Calculate how well clinical notes match condition keywords""" |
|
if not text: |
|
return 0.0 |
|
|
|
text = text.lower() |
|
matched_keywords = sum(1 for keyword in keywords if keyword.lower() in text) |
|
return matched_keywords / len(keywords) |
|
|
|
def _calculate_asymmetry_index(self, features: Dict) -> float: |
|
"""Calculate brain asymmetry index from EEG features""" |
|
try: |
|
|
|
alpha_powers = features['band_powers']['alpha'] |
|
left_channels = alpha_powers[:len(alpha_powers)//2] |
|
right_channels = alpha_powers[len(alpha_powers)//2:] |
|
|
|
asymmetry = np.log(np.mean(right_channels)) - np.log(np.mean(left_channels)) |
|
return float(asymmetry) |
|
except: |
|
return 0.0 |
|
|
|
def _assess_band_significance(self, band: str, powers: np.ndarray) -> str: |
|
"""Assess clinical significance of frequency band power""" |
|
mean_power = np.mean(powers) |
|
if band == 'alpha': |
|
if mean_power < 0.3: |
|
return "Significantly reduced alpha power" |
|
elif mean_power > 0.7: |
|
return "Elevated alpha power" |
|
elif band == 'beta': |
|
if mean_power > 0.7: |
|
return "Elevated beta power - possible anxiety" |
|
elif band == 'theta': |
|
if mean_power > 0.6: |
|
return "Elevated theta power - possible cognitive issues" |
|
|
|
return "Within normal range" |
|
|
|
def _get_condition_specific_recommendations(self, condition: str) -> List[str]: |
|
"""Get specific recommendations for identified conditions""" |
|
recommendations = { |
|
'depression': [ |
|
"Consider cognitive behavioral therapy", |
|
"Evaluate need for antidepressant medication", |
|
"Recommend regular physical activity", |
|
"Implement sleep hygiene practices" |
|
], |
|
'anxiety': [ |
|
"Consider anxiety-focused psychotherapy", |
|
"Evaluate need for anti-anxiety medication", |
|
"Recommend relaxation techniques", |
|
"Practice mindfulness meditation" |
|
], |
|
'ptsd': [ |
|
"Consider trauma-focused therapy", |
|
"Evaluate need for PTSD-specific medication", |
|
"Implement grounding techniques", |
|
"Develop safety and coping plans" |
|
] |
|
} |
|
return recommendations.get(condition, []) |