# Uploaded by invincible-jha ("Upload 4 files"), commit d894230 (verified).
from transformers import (
WhisperProcessor, WhisperForConditionalGeneration,
AutoModelForSequenceClassification, AutoTokenizer
)
import torch
class ModelManager:
    """Load and run the speech, emotion and clinical-text models.

    Call :meth:`load_models` once before any inference method; the inference
    methods look their models, tokenizers and processors up in the dicts that
    :meth:`load_models` populates.
    """

    # Whisper feature extraction is done at 16 kHz; the rate is forwarded to
    # the processor explicitly instead of leaving it unspecified.
    WHISPER_SAMPLING_RATE = 16000

    # Output names for analyze_mental_health, in logit order. load_models sizes
    # the clinical classification head from this tuple so the two stay in sync.
    MENTAL_HEALTH_LABELS = ('depression_risk', 'anxiety_risk', 'stress_level')

    def __init__(self):
        # Use the GPU when torch can see one, otherwise fall back to CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.models = {}      # 'whisper' / 'emotion' / 'clinical' -> model
        self.tokenizers = {}  # 'emotion' / 'clinical' -> tokenizer
        self.processors = {}  # 'whisper' -> WhisperProcessor

    def load_models(self):
        """Load every model/tokenizer/processor and move the models to ``self.device``."""
        # Whisper for speech recognition.
        self.processors['whisper'] = WhisperProcessor.from_pretrained("openai/whisper-base")
        self.models['whisper'] = WhisperForConditionalGeneration.from_pretrained(
            "openai/whisper-base"
        ).to(self.device)

        # EmoRoBERTa for emotion detection.
        # NOTE(review): this checkpoint may publish TensorFlow weights only; if
        # loading fails, ``from_tf=True`` may be required — confirm.
        self.tokenizers['emotion'] = AutoTokenizer.from_pretrained("arpanghoshal/EmoRoBERTa")
        self.models['emotion'] = AutoModelForSequenceClassification.from_pretrained(
            "arpanghoshal/EmoRoBERTa"
        ).to(self.device)

        # ClinicalBERT for analysis. analyze_mental_health reads one logit per
        # entry of MENTAL_HEALTH_LABELS, but the config default is two labels,
        # so the head size must be requested explicitly (otherwise reading the
        # third score raises IndexError).
        # NOTE(review): Bio_ClinicalBERT appears to be a base encoder without a
        # fine-tuned classification head, so this head is presumably randomly
        # initialised and its scores are not meaningful — confirm, or swap in
        # a checkpoint fine-tuned for these labels.
        self.tokenizers['clinical'] = AutoTokenizer.from_pretrained(
            "emilyalsentzer/Bio_ClinicalBERT"
        )
        self.models['clinical'] = AutoModelForSequenceClassification.from_pretrained(
            "emilyalsentzer/Bio_ClinicalBERT",
            num_labels=len(self.MENTAL_HEALTH_LABELS),
        ).to(self.device)

    def _encode(self, key, text):
        """Tokenize *text* with the tokenizer stored under *key*, on ``self.device``."""
        return self.tokenizers[key](
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        ).to(self.device)

    def transcribe(self, audio_input, sampling_rate=WHISPER_SAMPLING_RATE):
        """Transcribe an audio clip with Whisper and return the first decoded text.

        Args:
            audio_input: waveform accepted by ``WhisperProcessor`` (presumably a
                1-D float array — confirm against callers).
            sampling_rate: sample rate of *audio_input* in Hz, forwarded to the
                processor. Defaults to 16000, the rate Whisper works at.

        Returns:
            The first decoded sequence, with special tokens stripped.
        """
        processor = self.processors['whisper']
        features = processor(
            audio_input,
            sampling_rate=sampling_rate,
            return_tensors="pt",
        ).input_features.to(self.device)
        generated_ids = self.models['whisper'].generate(features)
        return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    def analyze_emotions(self, text):
        """Return ``{label: probability}`` over every output of the emotion model.

        Probabilities are a softmax over the logits of the first (only) input.
        Label names come from the model config's ``id2label``. Zipping a fixed
        list of names against the logits would silently drop every output past
        the end of the list and could attach names in the wrong order.
        """
        model = self.models['emotion']
        inputs = self._encode('emotion', text)
        with torch.no_grad():  # inference only: skip autograd bookkeeping
            logits = model(**inputs).logits
        probs = torch.nn.functional.softmax(logits, dim=-1)[0].tolist()
        id2label = getattr(model.config, 'id2label', None) or {}
        return {id2label.get(i, f"label_{i}"): float(p) for i, p in enumerate(probs)}

    def analyze_mental_health(self, text):
        """Return independent sigmoid scores for each of MENTAL_HEALTH_LABELS.

        Raises:
            IndexError: if the clinical model yields fewer logits than there are
                labels.
        """
        inputs = self._encode('clinical', text)
        with torch.no_grad():  # inference only: skip autograd bookkeeping
            logits = self.models['clinical'](**inputs).logits
        scores = torch.sigmoid(logits)
        return {
            name: float(scores[0][i])
            for i, name in enumerate(self.MENTAL_HEALTH_LABELS)
        }