# Captured from a "Spaces" page; status at capture time: Paused.
def read_and_split_file(filename, chunk_size=1200, chunk_overlap=200,
                        separators=None, encoding='utf-8'):
    """Read a text file and split its contents into overlapping chunks.

    Args:
        filename: Path of the text file to read.
        chunk_size: Maximum chunk length, measured with ``len`` (characters).
        chunk_overlap: Number of characters consecutive chunks may share.
        separators: Separators handed to ``RecursiveCharacterTextSplitter``;
            defaults to ``[" ", ",", "\\n"]``, the list previously hard-coded.
        encoding: Text encoding used to read ``filename``.

    Returns:
        Whatever ``RecursiveCharacterTextSplitter.create_documents`` returns
        for the file text (the caller in this file iterates it and reads
        ``.page_content`` from each item).
    """
    # Imported here so the function also works when this module is imported,
    # not only after the ``__main__`` block has imported the splitter.
    from langchain.text_splitter import RecursiveCharacterTextSplitter

    if separators is None:
        # NOTE(review): " " is listed before "\n" — confirm that preferring
        # word boundaries over line breaks is the intended split order.
        separators = [" ", ",", "\n"]

    with open(filename, 'r', encoding=encoding) as f:
        text = f.read()

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        separators=separators,
    )
    return text_splitter.create_documents([text])
if __name__ == '__main__': | |
# Comments and ideas to implement: | |
# 1. Try sending list of inputs to the Inference API. | |
import streamlit as st | |
from sys import exit | |
from pprint import pprint | |
from collections import Counter | |
from itertools import zip_longest | |
from random import choice | |
import requests | |
from re import sub | |
from rouge import Rouge | |
from time import sleep, perf_counter | |
import os | |
from textwrap import wrap | |
from multiprocessing import Pool, freeze_support | |
from tqdm import tqdm | |
from stqdm import stqdm | |
from langchain.document_loaders import TextLoader | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain.schema.document import Document | |
# from langchain.schema import Document | |
from langchain.chat_models import ChatOpenAI | |
from langchain.llms import OpenAI | |
from langchain.schema import AIMessage, HumanMessage, SystemMessage | |
from langchain.prompts import PromptTemplate | |
from datasets import Dataset, load_dataset | |
from sklearn.preprocessing import LabelEncoder | |
from test_models.train_classificator import MLP | |
from safetensors.torch import load_model, save_model | |
from sentence_transformers import SentenceTransformer | |
from torch.utils.data import DataLoader, TensorDataset | |
import torch.nn.functional as F | |
import torch | |
import torch.nn as nn | |
import sys | |
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), 'test_models/'))) | |
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), 'test_models/financial-roberta'))) | |
st.set_page_config( | |
page_title="Financial advisor", | |
page_icon="๐ณ๐ฐ", | |
layout="wide", | |
) | |
# st.session_state.summarized = False | |
with st.sidebar: | |
"# How to use๐" | |
""" | |
โจThis is a holiday version of the web-UI with the magic ๐, allowing you to unwrap | |
label predictions for a company based on its financial report text! ๐โจ The prediction | |
enchantment is performed using the sophisticated embedding classifier approach. ๐๐ฎ | |
""" | |
center_style = "<h3 style='text-align: center; color: black;'>{} </h3>" | |
st.markdown(center_style.format('Load the financial report'), unsafe_allow_html=True) | |
upload_types = ['Text input', 'File upload'] | |
upload_captions = ['Paste the text', 'Upload a text file'] | |
upload_type = st.radio('Select how to upload the financial report', upload_types, | |
captions=upload_captions) | |
match upload_type: | |
case 'Text input': | |
financial_report_text = st.text_area('Something', label_visibility='collapsed', | |
placeholder='Financial report as TEXT') | |
case 'File upload': | |
uploaded_files = st.file_uploader("Choose a a text file", type=['.txt', '.docx'], | |
label_visibility='collapsed', accept_multiple_files=True) | |
if not bool(uploaded_files): | |
st.stop() | |
financial_report_text = '' | |
for uploaded_file in uploaded_files: | |
if uploaded_file.name.endswith("docx"): | |
document = Document(uploaded_file) | |
document.save('./utils/texts/' + uploaded_file.name) | |
document = Document(uploaded_file.name) | |
financial_report_text += "".join([paragraph.text for paragraph in document.paragraphs]) + '\n' | |
else: | |
financial_report_text += "".join([line.decode() for line in uploaded_file]) + '\n' | |
# with open('./utils/texts/financial_report_text.txt', 'w') as file: | |
# file.write(financial_report_text) | |
if st.button('Get label'): | |
with st.spinner("Thinking..."): | |
text_splitter = RecursiveCharacterTextSplitter( | |
chunk_size=3200, chunk_overlap=200, | |
length_function = len, separators=[" ", ",", "\n"] | |
) | |
# st.write(f'Financial report char len: {len(financial_report_text)}') | |
documents = text_splitter.create_documents([financial_report_text]) | |
# st.write(f'Num chunks: {len(documents)}') | |
texts = [document.page_content for document in documents] | |
# st.write(f'Each chunk char length: {[len(text) for text in texts]}') | |
# predicted_label = get_label_prediction(texts) | |
from test_models.create_setfit_model import model | |
with torch.no_grad(): | |
model.model_head.eval() | |
predicted_labels = model(texts) | |
# st.write(predicted_labels) | |
predicted_labels_counter = Counter(predicted_labels) | |
predicted_label = predicted_labels_counter.most_common(1)[0][0] | |
font_style = 'The predicted label is<span style="font-size: 32px"> **{}**</span>.' | |
st.markdown(font_style.format(predicted_label), unsafe_allow_html=True) |