Koli98's picture
Committed required files
485deec verified
raw
history blame contribute delete
No virus
1.46 kB
from sklearn.metrics import precision_score, accuracy_score, recall_score, f1_score, classification_report
from sklearn.preprocessing import LabelEncoder
class TextClassifier:
    """Base class for text classifiers scored with scikit-learn metrics.

    Subclasses must implement :meth:`train` and :meth:`predict`. This base
    class only stores the train/test splits and provides :meth:`evaluate`,
    which scores ``predict`` output on the held-out test split.
    """

    def __init__(self, train_features, train_targets, test_features, test_targets):
        """Store the train/test splits and initialise the result slots.

        Args:
            train_features: Training inputs (not read by this base class;
                presumably consumed by a subclass's ``train``).
            train_targets: Training labels (not read by this base class).
            test_features: Held-out inputs passed to ``predict`` by ``evaluate``.
            test_targets: Held-out labels compared against those predictions.
        """
        self.train_features = train_features
        self.train_targets = train_targets
        self.test_features = test_features
        self.test_targets = test_targets
        # Never assigned by this base class; presumably set by a subclass's train().
        self.model = None
        # Metric results; all remain None until evaluate() is called.
        self.classification_report = None
        self.accuracy = None
        self.precision = None
        self.recall = None
        self.f1 = None

    def train(self) -> None:
        """Fit the underlying model. Must be implemented by subclasses."""
        raise NotImplementedError(f"{type(self).__name__} must implement train()")

    def predict(self, text_samples: list, inverse_transform: bool = True) -> list:
        """Predict labels for ``text_samples``. Must be implemented by subclasses.

        Args:
            text_samples: Inputs to classify.
            inverse_transform: Whether predictions should be mapped back out of
                the encoded label space (semantics defined by the subclass).
        """
        raise NotImplementedError(f"{type(self).__name__} must implement predict()")

    def evaluate(self, average: str = 'weighted') -> dict:
        """Score the model's predictions on the held-out test split.

        Calls ``self.predict(self.test_features, inverse_transform=False)`` and
        compares the result with ``self.test_targets``. Each metric is stored on
        the instance (``accuracy``, ``precision``, ``recall``, ``f1`` and the
        text ``classification_report``) and the scalar metrics are returned.

        Args:
            average: Averaging strategy forwarded to scikit-learn's
                ``precision_score``, ``recall_score`` and ``f1_score``. Defaults
                to ``'weighted'``, the value previously hard-coded.

        Returns:
            dict with keys ``'accuracy'``, ``'precision'``, ``'recall'`` and
            ``'f1'``.
        """
        # inverse_transform=False keeps predictions in the same space as
        # test_targets (presumably label-encoded) — confirm against subclasses.
        predictions = self.predict(self.test_features, inverse_transform=False)
        targets = self.test_targets

        self.accuracy = accuracy_score(targets, predictions)
        self.precision = precision_score(targets, predictions, average=average)
        self.recall = recall_score(targets, predictions, average=average)
        self.f1 = f1_score(targets, predictions, average=average)
        self.classification_report = classification_report(targets, predictions)

        # 'f1' was computed but previously dropped from the returned metrics.
        return {
            'accuracy': self.accuracy,
            'precision': self.precision,
            'recall': self.recall,
            'f1': self.f1,
        }