# Spaces: Paused
def read_and_split_file(filename, chunk_size=1200, chunk_overlap=200, encoding='utf-8'):
    """Read a text file and split its contents into overlapping LangChain documents.

    Args:
        filename: Path of the text file to read.
        chunk_size: Maximum chunk length, measured in characters (``len``).
        chunk_overlap: Number of characters shared by consecutive chunks.
        encoding: Encoding used to decode the file. Defaults to UTF-8 so the
            result does not depend on the platform's locale encoding.

    Returns:
        The documents produced by ``RecursiveCharacterTextSplitter.create_documents``
        for the file's full text.
    """
    # Imported here because the file-level import only runs under the
    # ``__main__`` guard; without this, calling the helper from an importing
    # module raises NameError.
    from langchain.text_splitter import RecursiveCharacterTextSplitter

    with open(filename, 'r', encoding=encoding) as f:
        text = f.read()

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        # NOTE(review): the splitter tries separators in list order, so spaces
        # are preferred over newlines here — presumably intentional; confirm.
        separators=[" ", ",", "\n"],
    )
    return text_splitter.create_documents([text])
def get_label_prediction(selected_predictor, texts):
    """Predict one label for a report by majority vote over its text chunks.

    Each chunk is embedded with the SentenceTransformer configured for
    ``selected_predictor`` and classified by the matching MLP head loaded from
    safetensors. The most frequent predicted label is returned.

    Args:
        selected_predictor: Key into ``config.predictors``. The entry must
            provide ``'embedding_model'``, ``'embedding_dim'`` and ``'mlp_model'``.
        texts: Sequence of documents exposing ``page_content``, e.g. the
            output of ``read_and_split_file``.

    Returns:
        The majority label, decoded by a LabelEncoder fitted on the upper-cased
        ``config.labels``.

    Raises:
        ValueError: If ``texts`` is empty, since there is nothing to vote on.
    """
    # Imported locally: the file binds these names only inside its
    # ``__main__`` guard, so the function would fail when the module is imported.
    from collections import Counter

    import torch
    import torch.nn.functional as F
    from safetensors.torch import load_model
    from sentence_transformers import SentenceTransformer
    from sklearn.preprocessing import LabelEncoder
    from torch.utils.data import DataLoader, TensorDataset

    from config import labels, predictors
    from test_models.train_classificator import MLP

    texts_str = [text.page_content for text in texts]
    if not texts_str:
        # Fail fast, before loading any model, instead of an IndexError on the
        # empty vote below.
        raise ValueError('texts must contain at least one document')

    predictor = predictors[selected_predictor]
    embedding_model = SentenceTransformer(predictor['embedding_model'])
    embeddings = embedding_model.encode(texts_str, show_progress_bar=True).tolist()

    # Fit the encoder on the configured label set so that index -> label
    # decoding matches the classifier's output classes.
    # NOTE(review): assumes the MLP was trained with the same upper-cased
    # `labels` ordering — confirm.
    label_encoder = LabelEncoder()
    label_encoder.fit([label.upper() for label in labels])

    input_size = predictor['embedding_dim']
    hidden_size = 256
    output_size = len(label_encoder.classes_)
    dropout_rate = 0.5
    batch_size = 8

    model = MLP(input_size, hidden_size, output_size, dropout_rate)
    load_model(model, predictor['mlp_model'])
    model.eval()

    # shuffle=False: Counter.most_common breaks ties by first-seen order, so a
    # shuffled loader made the returned label nondeterministic on ties.
    dataloader = DataLoader(
        TensorDataset(torch.tensor(embeddings)),
        batch_size=batch_size,
        shuffle=False,
    )

    predicted_labels = []
    with torch.no_grad():
        for (batch,) in dataloader:
            outputs = model(batch)
            probabilities = F.softmax(outputs, dim=1)
            predicted_indices = torch.argmax(probabilities, dim=1).tolist()
            predicted_labels.extend(label_encoder.inverse_transform(predicted_indices))

    return Counter(predicted_labels).most_common(1)[0][0]
if __name__ == '__main__': | |
# Comments and ideas to implement: | |
# 1. Try sending list of inputs to the Inference API. | |
from config import ( | |
labels, headers_inference_api, headers_inference_endpoint, | |
# summarization_prompt_template, | |
prompt_template, | |
# task_explain_for_predictor_model, | |
summarizers, predictors, summary_scores_template, | |
summarization_system_msg, summarization_user_prompt, prediction_user_prompt, prediction_system_msg, | |
# prediction_prompt, | |
chat_prompt, instruction_prompt | |
) | |
import streamlit as st | |
from sys import exit | |
from pprint import pprint | |
from collections import Counter | |
from itertools import zip_longest | |
from random import choice | |
import requests | |
from re import sub | |
from rouge import Rouge | |
from time import sleep, perf_counter | |
import os | |
from textwrap import wrap | |
from multiprocessing import Pool, freeze_support | |
from tqdm import tqdm | |
from stqdm import stqdm | |
from langchain.document_loaders import TextLoader | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain.schema.document import Document | |
# from langchain.schema import Document | |
from langchain.chat_models import ChatOpenAI | |
from langchain.llms import OpenAI | |
from langchain.schema import AIMessage, HumanMessage, SystemMessage | |
from langchain.prompts import PromptTemplate | |
from datasets import Dataset, load_dataset | |
from sklearn.preprocessing import LabelEncoder | |
from test_models.train_classificator import MLP | |
from safetensors.torch import load_model, save_model | |
from sentence_transformers import SentenceTransformer | |
from torch.utils.data import DataLoader, TensorDataset | |
import torch.nn.functional as F | |
import torch | |
import torch.nn as nn | |
import sys | |
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), 'test_models/'))) | |
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), 'test_models/financial-roberta'))) | |
st.set_page_config( | |
page_title="Financial advisor", | |
page_icon="๐ณ๐ฐ", | |
layout="wide", | |
) | |
# st.session_state.summarized = False | |
with st.sidebar: | |
"# How to use๐" | |
""" | |
โจThis is a holiday version of the web-UI with the magic ๐, allowing you to unwrap | |
label predictions for a company based on its financial report text! ๐โจ The prediction | |
enchantment is performed using the sophisticated embedding classifier approach. ๐๐ฎ | |
""" | |
center_style = "<h3 style='text-align: center; color: black;'>{} </h3>" | |
st.markdown(center_style.format('Load the financial report'), unsafe_allow_html=True) | |
upload_types = ['Text input', 'File upload'] | |
upload_captions = ['Paste the text', 'Upload a text file'] | |
upload_type = st.radio('Select how to upload the financial report', upload_types, | |
captions=upload_captions) | |
match upload_type: | |
case 'Text input': | |
financial_report_text = st.text_area('Something', label_visibility='collapsed', | |
placeholder='Financial report as TEXT') | |
case 'File upload': | |
uploaded_files = st.file_uploader("Choose a a text file", type=['.txt', '.docx'], | |
label_visibility='collapsed', accept_multiple_files=True) | |
if not bool(uploaded_files): | |
st.stop() | |
financial_report_text = '' | |
for uploaded_file in uploaded_files: | |
if uploaded_file.name.endswith("docx"): | |
document = Document(uploaded_file) | |
document.save('./utils/texts/' + uploaded_file.name) | |
document = Document(uploaded_file.name) | |
financial_report_text += "".join([paragraph.text for paragraph in document.paragraphs]) + '\n' | |
else: | |
financial_report_text += "".join([line.decode() for line in uploaded_file]) + '\n' | |
# with open('./utils/texts/financial_report_text.txt', 'w') as file: | |
# file.write(financial_report_text) | |
if st.button('Get label'): | |
with st.spinner("Thinking..."): | |
text_splitter = RecursiveCharacterTextSplitter( | |
chunk_size=3200, chunk_overlap=200, | |
length_function = len, separators=[" ", ",", "\n"] | |
) | |
# st.write(f'Financial report char len: {len(financial_report_text)}') | |
documents = text_splitter.create_documents([financial_report_text]) | |
# st.write(f'Num chunks: {len(documents)}') | |
texts = [document.page_content for document in documents] | |
# st.write(f'Each chunk char length: {[len(text) for text in texts]}') | |
# predicted_label = get_label_prediction(texts) | |
from test_models.create_setfit_model import model | |
with torch.no_grad(): | |
model.model_head.eval() | |
predicted_labels = model(texts) | |
# st.write(predicted_labels) | |
predicted_labels_counter = Counter(predicted_labels) | |
predicted_label = predicted_labels_counter.most_common(1)[0][0] | |
font_style = 'The predicted label is<span style="font-size: 32px"> **{}**</span>.' | |
st.markdown(font_style.format(predicted_label), unsafe_allow_html=True) |