"""Evaluate the classifier from ``create_setfit_model`` on the financial test set.

Each test document is split into overlapping character chunks, every chunk is
classified by the model, and the document-level prediction is the majority
vote over its chunks. Accuracy, weighted precision/recall/F1 and a
row-normalised confusion matrix are then reported.
"""
from create_setfit_model import model
from time import perf_counter
import os
import sys
from statistics import mean
from langchain.text_splitter import RecursiveCharacterTextSplitter
import torch
from collections import Counter
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm


def _majority_label(chunk_predictions):
    """Return the most frequent label among the per-chunk predictions.

    ``chunk_predictions`` may be a plain sequence or an array/tensor-like
    object exposing ``tolist()``. Array-likes are converted to Python scalars
    first: iterating a ``torch.Tensor`` yields 0-d tensors that hash by
    identity, so ``Counter`` would otherwise treat every chunk as a distinct
    class and the "vote" would just return the first chunk's prediction.

    Ties are resolved in favour of the label encountered first.

    Raises:
        ValueError: if there are no predictions to vote on.
    """
    if hasattr(chunk_predictions, 'tolist'):
        chunk_predictions = chunk_predictions.tolist()
    votes = Counter(chunk_predictions)
    if not votes:
        raise ValueError('cannot take a majority vote over zero chunk predictions')
    return votes.most_common(1)[0][0]


# Timer starts after the model import, so model creation is not included.
start = perf_counter()

# The dataset lives two directories above the current working directory; its
# folder is put on sys.path so the loader module inside it can be imported.
dataset_dir = os.path.abspath(os.path.join(os.getcwd(), '..', '..', 'financial_dataset'))
sys.path.append(dataset_dir)

from load_test_data import get_labels_df, get_texts

labels_dir = dataset_dir + '/csvs/'
df = get_labels_df(labels_dir)

texts_dir = dataset_dir + '/txts/'
texts = get_texts(texts_dir)

# For a quick smoke run, subset both inputs here, e.g.
# ``df = df.iloc[:20, :]`` together with ``texts = texts[:20]``.

# Sanity output: number of label rows vs. documents, and mean document length.
print(len(df), len(texts))
print(mean(list(map(len, texts))))

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=3200,
    chunk_overlap=200,
    length_function=len,
    separators=[" ", ",", "\n"],
)

labels = []
pred_labels = []

# Switching the head to eval mode is loop-invariant, so do it once up front.
model.model_head.eval()

# Texts and label rows are paired positionally; zip stops at the shorter one.
# Each row is expected to unpack into exactly (year, label, company).
for text, (idx, (year, label, company)) in tqdm(
    zip(texts, df.iterrows()), total=min(len(texts), len(df))
):
    documents = text_splitter.create_documents([text])
    # Named ``chunks`` so the ``texts`` list being iterated is not rebound.
    chunks = [document.page_content for document in documents]

    with torch.no_grad():
        chunk_pred_labels = model(chunks)

    labels.append(label)
    pred_labels.append(_majority_label(chunk_pred_labels))

accuracy = accuracy_score(labels, pred_labels)
precision = precision_score(labels, pred_labels, average='weighted')
recall = recall_score(labels, pred_labels, average='weighted')
f1 = f1_score(labels, pred_labels, average='weighted')
# normalize='true' makes every row (true class) sum to 1.
confusion_mat = confusion_matrix(labels, pred_labels, normalize='true')

print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
print("F1 Score:", f1)

# Tick names are kept separate from ``labels`` so the ground truth is not lost.
# NOTE(review): confusion_matrix is called without ``labels=``, so its rows and
# columns follow the sorted unique label values; this tick order presumably
# matches the dataset's label encoding -- TODO confirm.
class_names = ['hold', 'buy', 'sell']
plt.figure(figsize=(8, 6))
sns.heatmap(
    confusion_mat,
    annot=True,
    fmt='.2%',
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
)
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.title('Confusion Matrix')
plt.show()

# Sample the clock once and round before splitting, so the output can never
# read "60 secs".
elapsed_minutes, elapsed_seconds = divmod(round(perf_counter() - start), 60)
print(f'It took me: {elapsed_minutes:.0f} mins {elapsed_seconds:.0f} secs')