import sys
import os
import json
import warnings
from datetime import datetime
from pprint import pprint
from statistics import mean

from langchain.text_splitter import RecursiveCharacterTextSplitter

warnings.filterwarnings("ignore")
# make the parent package and the dataset helpers importable
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
# sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..', 'financial_dataset')))
dataset_dir = os.path.abspath(os.path.join(os.getcwd(), '..', '..', 'financial_dataset'))
sys.path.append(dataset_dir)

from load_test_data import get_labels_df, get_texts
from app import (
    summarize,
    read_and_split_file,
    get_label_prediction
)
from config import (
    labels, headers_inference_api, headers_inference_endpoint,
    # summarization_prompt_template,
    prompt_template,
    # task_explain_for_predictor_model,
    summarizers, predictors, summary_scores_template,
    summarization_system_msg, summarization_user_prompt, prediction_user_prompt, prediction_system_msg,
    # prediction_prompt,
    chat_prompt, instruction_prompt
)
def split_text(text, chunk_size=1200, chunk_overlap=200):
    """Split raw text into overlapping chunks sized for the downstream models."""
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        separators=[" ", ",", "\n"],
    )
    return text_splitter.create_documents([text])
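# Example usage (a sketch with made-up input): create_documents returns
# langchain Document objects, so each chunk's text lives on .page_content.
# chunks = split_text('Revenue grew 12% year over year. ' * 200, chunk_size=500)
# print(len(chunks), chunks[0].page_content[:60])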
predictions = {
    # method: {model_name: [predicted_label, ...]}
    'summarization+classification': {
        'bart-pegasus+gpt': [],  # one predicted label per document
        'gpt+gpt': [],
    },
    'chunk_classification': {},
    'embedding_classification': {},
    'zero-shot_classification': {},
    'full_text_classification': {},
    'QA_classification': {}
}
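# Scoring helper (a minimal sketch, not part of the original pipeline): compare
# the predicted labels collected above against the ground-truth labels, assuming
# each method's prediction list is aligned index-for-index with the dataframe rows.
def score_predictions(preds, true_labels):
    """Return {'method/model': accuracy} for every non-empty prediction list."""
    scores = {}
    for method, models in preds.items():
        for model_name, pred_labels in models.items():
            if not pred_labels:
                continue
            correct = sum(p == t for p, t in zip(pred_labels, true_labels))
            scores[f'{method}/{model_name}'] = correct / len(pred_labels)
    return scores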
# if __name__ == '__main__':
labels_dir = dataset_dir + '/csvs/'
df = get_labels_df(labels_dir)
texts_dir = dataset_dir + '/txts/'
texts = get_texts(texts_dir)
# print(len(df), len(texts))
# print(mean(list(map(len, texts))))
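# The loops below zip texts with df rows one-to-one, which assumes get_labels_df
# and get_texts return items in the same order.
assert len(df) == len(texts), f'label rows ({len(df)}) != text files ({len(texts)})'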
# summarization+classification
# for selected_summarizer in summarizers:
#     print(selected_summarizer)
#     # for selected_predictor in predictors:
#     #     predictions['summarization+classification'][selected_summarizer + '+' + selected_predictor] = []
#     for text, (idx, (year, label, company)) in zip(texts, df.iterrows()):
#         print(year, label, company)
#         # summary_filename = f'./texts/{year}_{company}_{selected_summarizer}_summary.txt'
#         summary_filename = f'./texts/{company}_{year}_{selected_summarizer}_summary.txt'
#         if os.path.isfile(summary_filename):
#             print('Loading summary from the cache')
#             with open(summary_filename, 'r') as f:
#                 summary = f.read()
#         else:
#             print(f'Making request to {selected_summarizer} to summarize {company}, {year}')
#             text_chunks = split_text(text,
#                                      chunk_size=summarizers[selected_summarizer]['chunk_size'],
#                                      chunk_overlap=100)
#             summary, summary_score = summarize(selected_summarizer, text_chunks)
#             with open(summary_filename, 'w') as f:
#                 f.write(summary)
#         print('-' * 50)
#         # break
#         # summary_chunks = split_text(summary, chunk_size=3_600)
#         # predicted_label = get_label_prediction(selected_predictor, summary_chunks)
#         # if predicted_label in labels:
#         #     predictions['summarization+classification'][selected_summarizer + '+' + selected_predictor].append(predicted_label)
#     print()
#     break
# chunk_classification
# for selected_predictor in predictors:
#     predictions['chunk_classification'][selected_predictor] = []
#     for text, (idx, (year, label, company)) in zip(texts, df.iterrows()):
#         print(year, label, company)
#         text_chunks = split_text(text, chunk_size=3600)
#         predicted_label = get_label_prediction(selected_predictor, text_chunks)
#         if predicted_label in labels:
#             predictions['chunk_classification'][selected_predictor].append(predicted_label)
#         print('-' * 50)
# with open(f'predictions/predictions_{datetime.now().strftime("%Y-%m-%d_%H-%M")}.json', 'w') as json_file:
#     json.dump(predictions, json_file, indent=4)
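# Once any of the loops above have populated `predictions`, per-method accuracy
# can be reported with the score_predictions sketch defined earlier (this assumes
# the ground-truth column in df is named 'label', per the (year, label, company)
# unpacking above):
# pprint(score_predictions(predictions, df['label'].tolist()))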