# Spaces: Paused  (page-header residue from the hosting site; not part of the script)
from create_setfit_model import model | |
from time import perf_counter | |
import os | |
import sys | |
from statistics import mean | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
import torch | |
from collections import Counter | |
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, confusion_matrix | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
from tqdm import tqdm | |
# Wall-clock reference point for the whole run; read again by the final
# elapsed-time report.
start = perf_counter()

# The dataset folder sits two levels above the current working directory.
# It is appended to sys.path so that `load_test_data` can be imported
# (presumably that module lives inside the dataset folder -- confirm).
dataset_dir = os.path.abspath(
    os.path.join(os.getcwd(), os.pardir, os.pardir, 'financial_dataset')
)
sys.path.append(dataset_dir)
from load_test_data import get_labels_df, get_texts  # noqa: E402

# Both loaders receive a directory path that ends with a slash.
labels_dir = f'{dataset_dir}/csvs/'
texts_dir = f'{dataset_dir}/txts/'
df = get_labels_df(labels_dir)
texts = get_texts(texts_dir)

# For a quick smoke test, slice `df` and `texts` to the same leading rows
# (e.g. the first 20) at this point.

# Sanity check: number of label rows vs. number of texts, then the mean
# text length in characters.
print(len(df), len(texts))
print(mean(len(text) for text in texts))
# Split each document into overlapping character windows (3200 characters
# with a 200-character overlap) so that long texts can be classified
# piece by piece.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=3200,
    chunk_overlap=200,
    length_function=len,
    separators=[" ", ",", "\n"],
)

labels = []       # ground-truth label of every evaluated document
pred_labels = []  # majority-vote prediction for the same documents, same order

# Inference only: switch the classification head to eval mode once instead of
# on every loop iteration.
model.model_head.eval()

# zip() stops at the shorter input and has no len(), so tqdm needs an explicit
# total to draw a real progress bar.
n_documents = min(len(texts), len(df))

# Each row of `df` is unpacked positionally as (year, label, company); only
# the label is needed here.
for text, (_, (_year, label, _company)) in tqdm(
    zip(texts, df.iterrows()), total=n_documents
):
    # Deliberately NOT named `texts`: rebinding the list that zip() is
    # iterating over is legal but obscures the loop and clobbers the corpus.
    chunks = [
        document.page_content
        for document in text_splitter.create_documents([text])
    ]

    with torch.no_grad():
        chunk_preds = model(chunks)

    # The model may hand back a torch.Tensor or numpy array (depends on the
    # head; not visible from this file). Tensor elements hash by identity, so
    # feeding a tensor straight into Counter would make every chunk its own
    # key and the "vote" would just be the first chunk's prediction.
    # Converting to plain Python values makes equal predictions count together.
    if hasattr(chunk_preds, 'tolist'):
        chunk_preds = chunk_preds.tolist()

    # Document-level prediction = most frequent chunk-level prediction.
    pred_label = Counter(chunk_preds).most_common(1)[0][0]

    labels.append(label)
    pred_labels.append(pred_label)
# Document-level scores. Precision, recall and F1 are all computed with
# average='weighted'.
_weighted = {'average': 'weighted'}
accuracy = accuracy_score(labels, pred_labels)
precision = precision_score(labels, pred_labels, **_weighted)
recall = recall_score(labels, pred_labels, **_weighted)
f1 = f1_score(labels, pred_labels, **_weighted)

# Confusion matrix computed with normalize='true'.
confusion_mat = confusion_matrix(labels, pred_labels, normalize='true')

# Report the scalar scores, one per line.
for _title, _score in (
    ("Accuracy:", accuracy),
    ("Precision:", precision),
    ("Recall:", recall),
    ("F1 Score:", f1),
):
    print(_title, _score)
# Draw the normalised confusion matrix as an annotated heatmap.
#
# `labels` is rebound here: from this point on it holds the tick names, not
# the ground-truth list.
# NOTE(review): the tick order is hard-coded. It assumes the row/column order
# of `confusion_mat` corresponds to hold, buy, sell -- confirm against the
# dataset's label encoding.
labels = ['hold', 'buy', 'sell']

plt.figure(figsize=(8, 6))
ax = sns.heatmap(
    confusion_mat,
    annot=True,
    fmt='.2%',
    cmap='Blues',
    xticklabels=labels,
    yticklabels=labels,
)
ax.set_xlabel('Predicted labels')
ax.set_ylabel('True labels')
ax.set_title('Confusion Matrix')
plt.show()
# Report the total wall-clock time of the run.
# The clock is read once so minutes and seconds come from the same
# measurement, and the value is rounded to whole seconds *before* splitting
# so the seconds field can never be displayed as "60".
elapsed_seconds = round(perf_counter() - start)
minutes, seconds = divmod(elapsed_seconds, 60)
print(f'It took me: {minutes} mins {seconds} secs')