Spaces:
Runtime error
Runtime error
from base_model import TextClassifier | |
import torch | |
from transformers import pipeline | |
class PretrainedSentimentAnalyzer(TextClassifier):
    """Sentiment classifier backed by a pretrained Hugging Face model.

    Wraps a ``transformers`` ``text-classification`` pipeline. ``train`` is a
    no-op because the pretrained model is used as-is. ``predict`` either returns
    the raw pipeline output, or maps each prediction to a label and falls back
    to ``'neutral'`` when the model's score does not exceed ``min_threshold``.
    """

    def __init__(self, train_features, train_targets, test_features, test_targets,
                 min_threshold=0.7,
                 model_name="cardiffnlp/twitter-roberta-base-sentiment-latest",
                 batch_size=128, max_length=512):
        """Load the pipeline and store the prediction settings.

        Args:
            train_features, train_targets, test_features, test_targets:
                Forwarded unchanged to ``TextClassifier.__init__``.
            min_threshold: A prediction keeps its label only when its score is
                strictly greater than this value; otherwise ``'neutral'`` is used.
            model_name: Hugging Face model id loaded into the pipeline.
            batch_size: Number of texts per forward pass in ``predict``.
            max_length: Token limit used when truncating long inputs. The default
                of 512 assumes a RoBERTa-base encoder. Adjust it if ``model_name``
                points at a model with a different limit.
        """
        super().__init__(train_features, train_targets, test_features, test_targets)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = pipeline("text-classification",
                              model=model_name,
                              device=device)
        # Identity mapping from the model's label names to output labels. It is
        # kept as an attribute so labels can be renamed without touching predict().
        self.prediction_map = {
            'positive': 'positive',
            'negative': 'negative',
            'neutral': 'neutral',
        }
        self.threshold = min_threshold
        self.batch_size = batch_size
        self.max_length = max_length

    def train(self):
        """Do nothing: the pretrained model is used without fine-tuning."""

    def predict(self, text_samples: list, inverse_transform: bool, proba: bool = True) -> list:
        """Classify ``text_samples`` with the pretrained pipeline.

        Args:
            text_samples: Texts to classify.
            inverse_transform: Not used by this implementation. It is presumably
                accepted for signature compatibility with other ``TextClassifier``
                subclasses; confirm against ``base_model``.
            proba: If True, return the raw pipeline output: one dict per text
                with at least ``'label'`` and ``'score'`` keys. If False, return
                one label string per text.

        Returns:
            The raw pipeline dicts when ``proba`` is True. Otherwise a list of
            labels in which any prediction whose score is not strictly above
            ``self.threshold`` is replaced by ``'neutral'``.

        Raises:
            KeyError: If ``proba`` is False and the model emits a label that is
                missing from ``self.prediction_map``.
        """
        # Truncation keeps inputs longer than the encoder's positional limit from
        # crashing inside the model. max_length is passed explicitly because the
        # tokenizer's configured default limit may be unset.
        predictions = self.model(
            text_samples,
            batch_size=self.batch_size,
            truncation=True,
            max_length=self.max_length,
        )
        if proba:
            return predictions
        return [
            self.prediction_map[pred['label']] if pred['score'] > self.threshold else 'neutral'
            for pred in predictions
        ]