# Spaces status: Runtime error
import torch
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    WhisperForConditionalGeneration,
    WhisperProcessor,
)
class ModelManager:
    """Load and run the speech, emotion and clinical-text models.

    Models, tokenizers and processors are kept in dictionaries keyed by a
    short task name (``'whisper'``, ``'emotion'``, ``'clinical'``).  The
    device is CUDA when ``torch.cuda.is_available()`` is true, else CPU.
    ``load_models`` must be called before any inference method, since the
    dictionaries start out empty.
    """

    WHISPER_MODEL_ID = "openai/whisper-base"
    EMOTION_MODEL_ID = "arpanghoshal/EmoRoBERTa"
    CLINICAL_MODEL_ID = "emilyalsentzer/Bio_ClinicalBERT"

    # Keys returned by analyze_mental_health, in logit-column order.
    CLINICAL_LABELS = ("depression_risk", "anxiety_risk", "stress_level")

    def __init__(self):
        """Select the compute device and create the empty registries."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.models = {}
        self.tokenizers = {}
        self.processors = {}

    def load_models(self):
        """Download/load every checkpoint and move the models to ``self.device``."""
        # Whisper for speech recognition.
        self.processors['whisper'] = WhisperProcessor.from_pretrained(
            self.WHISPER_MODEL_ID
        )
        self.models['whisper'] = WhisperForConditionalGeneration.from_pretrained(
            self.WHISPER_MODEL_ID
        ).to(self.device)

        # EmoRoBERTa for emotion detection.
        # NOTE(review): this checkpoint appears to be published with
        # TensorFlow weights; if loading fails here, ``from_tf=True`` may be
        # required -- confirm against the model repository.
        self.tokenizers['emotion'] = AutoTokenizer.from_pretrained(
            self.EMOTION_MODEL_ID
        )
        self.models['emotion'] = AutoModelForSequenceClassification.from_pretrained(
            self.EMOTION_MODEL_ID
        ).to(self.device)

        # ClinicalBERT for analysis.
        self.tokenizers['clinical'] = AutoTokenizer.from_pretrained(
            self.CLINICAL_MODEL_ID
        )
        # The default classification head has 2 outputs, but
        # analyze_mental_health reads 3 scores, so request one output per label.
        # NOTE(review): the base checkpoint presumably ships no fine-tuned
        # classification head, so these scores are only meaningful after
        # fine-tuning -- confirm.
        self.models['clinical'] = AutoModelForSequenceClassification.from_pretrained(
            self.CLINICAL_MODEL_ID,
            num_labels=len(self.CLINICAL_LABELS),
        ).to(self.device)

    def transcribe(self, audio_input, sampling_rate=16000):
        """Return the Whisper transcription of ``audio_input``.

        Args:
            audio_input: Raw audio accepted by the Whisper processor.
            sampling_rate: Sampling rate of ``audio_input`` in Hz, forwarded
                to the processor so a mismatch is reported instead of being
                silently ignored.  Defaults to 16000.

        Returns:
            The decoded text of the first (only) sequence, with special
            tokens removed.
        """
        processor = self.processors['whisper']
        features = processor(
            audio_input,
            sampling_rate=sampling_rate,
            return_tensors="pt",
        ).input_features.to(self.device)
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            generated_ids = self.models['whisper'].generate(features)
        decoded = processor.batch_decode(generated_ids, skip_special_tokens=True)
        return decoded[0]

    def analyze_emotions(self, text):
        """Return a ``{label: probability}`` dict for ``text``.

        Label names come from the model's own ``config.id2label`` mapping so
        that every class the model predicts is reported under its real name;
        a class without a mapping entry is reported as ``LABEL_<index>``.
        Probabilities are the softmax over the logits of the first sequence.
        """
        model = self.models['emotion']
        inputs = self.tokenizers['emotion'](
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        ).to(self.device)
        with torch.no_grad():
            logits = model(**inputs).logits
        probs = torch.nn.functional.softmax(logits, dim=-1)[0]

        config = getattr(model, "config", None)
        id2label = getattr(config, "id2label", None) or {}
        result = {}
        for index, prob in enumerate(probs):
            # Accept both int and str keys in the mapping.
            label = id2label.get(index, id2label.get(str(index), f"LABEL_{index}"))
            result[str(label)] = float(prob)
        return result

    def analyze_mental_health(self, text):
        """Return sigmoid scores for ``text`` keyed by ``CLINICAL_LABELS``.

        Returns:
            A dict with keys ``'depression_risk'``, ``'anxiety_risk'`` and
            ``'stress_level'``, taken from the first three logit columns of
            the first sequence.

        Raises:
            IndexError: If the model produces fewer logits than labels.
        """
        inputs = self.tokenizers['clinical'](
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        ).to(self.device)
        with torch.no_grad():
            logits = self.models['clinical'](**inputs).logits
        scores = torch.sigmoid(logits)[0]

        expected = len(self.CLINICAL_LABELS)
        if scores.shape[0] < expected:
            raise IndexError(
                f"clinical model returned {scores.shape[0]} logits, "
                f"expected at least {expected}"
            )
        return {
            label: float(score)
            for label, score in zip(self.CLINICAL_LABELS, scores)
        }