# Early_Depression_Detection_V2 / sentiment_model.py
# Committed by Koli98: "Committed required files" (commit 485deec, verified; virus scan clean; 1.24 kB).
from base_model import TextClassifier
import torch
from transformers import pipeline
class PretrainedSentimentAnalyzer(TextClassifier):
    """Sentiment classifier backed by a pretrained Hugging Face pipeline.

    No training is performed: ``train`` is a no-op, and predictions come
    directly from a ``text-classification`` pipeline (by default
    ``cardiffnlp/twitter-roberta-base-sentiment-latest``).
    """

    def __init__(
        self,
        train_features,
        train_targets,
        test_features,
        test_targets,
        min_threshold=0.7,
        model_name="cardiffnlp/twitter-roberta-base-sentiment-latest",
        batch_size=128,
    ):
        """Load the sentiment pipeline and store the confidence threshold.

        Args:
            train_features, train_targets, test_features, test_targets:
                Forwarded unchanged to ``TextClassifier.__init__``.
            min_threshold: In ``predict(..., proba=False)``, a prediction
                whose score is not strictly greater than this value is
                reported as ``'neutral'``.
            model_name: Hugging Face model id loaded into the pipeline.
            batch_size: Batch size used when running the pipeline.
        """
        super().__init__(train_features, train_targets, test_features, test_targets)
        # Use the GPU when one is visible to torch, otherwise run on CPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = pipeline("text-classification",
                              model=model_name,
                              device=device)
        # Maps pipeline labels to the labels this classifier reports. It is
        # currently the identity; an unexpected label raises KeyError.
        self.prediction_map = {'positive': 'positive',
                               'negative': 'negative',
                               'neutral': 'neutral'}
        self.threshold = min_threshold
        self.batch_size = batch_size

    def train(self):
        """No-op: the underlying model is already pretrained."""
        pass

    def predict(self, text_samples: list, inverse_transform: bool, proba: bool = True) -> list:
        """Classify the sentiment of each text.

        Args:
            text_samples: Texts to classify.
            inverse_transform: Not used by this implementation; accepted
                presumably for interface compatibility with TextClassifier
                (verify against the base class).
            proba: If True, return the raw pipeline outputs (one dict per
                text with ``'label'`` and ``'score'`` keys). If False, return
                one label string per text.

        Returns:
            A list with one entry per input text, as described for ``proba``.
        """
        # truncation=True: RoBERTa-base accepts at most 512 tokens, and longer
        # inputs would otherwise make the model call fail.
        predictions = self.model(text_samples,
                                 batch_size=self.batch_size,
                                 truncation=True)
        if proba:
            return predictions
        # Low-confidence predictions fall back to 'neutral'.
        return [self.prediction_map[prediction['label']] if prediction['score'] > self.threshold else 'neutral'
                for prediction in predictions]