holiday_testing / app.py
svystun-taras's picture
tested the model on all dataset
501f2e5
raw
history blame
5.51 kB
def read_and_split_file(filename, chunk_size=1200, chunk_overlap=200, encoding='utf-8'):
    """Read a text file and split its contents into overlapping chunks.

    Args:
        filename: Path of the text file to read.
        chunk_size: Maximum chunk length, measured with ``len`` (characters).
        chunk_overlap: Number of characters shared by consecutive chunks.
        encoding: Text encoding used to decode the file. Given explicitly so
            the result does not depend on the host's locale default.

    Returns:
        The list of documents produced by
        ``RecursiveCharacterTextSplitter.create_documents`` for the file text.
    """
    # Imported locally: elsewhere in this file the splitter is only imported
    # under the ``__main__`` guard, so a module-level lookup would fail when
    # this function is called from an importing module.
    from langchain.text_splitter import RecursiveCharacterTextSplitter

    with open(filename, 'r', encoding=encoding) as f:
        text = f.read()

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        # Separator order is kept exactly as before; it determines where the
        # splitter prefers to cut.
        separators=[" ", ",", "\n"],
    )
    return text_splitter.create_documents([text])
if __name__ == '__main__':
# Comments and ideas to implement:
# 1. Try sending list of inputs to the Inference API.
import streamlit as st
from sys import exit
from pprint import pprint
from collections import Counter
from itertools import zip_longest
from random import choice
import requests
from re import sub
from rouge import Rouge
from time import sleep, perf_counter
import os
from textwrap import wrap
from multiprocessing import Pool, freeze_support
from tqdm import tqdm
from stqdm import stqdm
from langchain.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema.document import Document
# from langchain.schema import Document
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI
from langchain.schema import AIMessage, HumanMessage, SystemMessage
from langchain.prompts import PromptTemplate
from datasets import Dataset, load_dataset
from sklearn.preprocessing import LabelEncoder
from test_models.train_classificator import MLP
from safetensors.torch import load_model, save_model
from sentence_transformers import SentenceTransformer
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
import torch
import torch.nn as nn
import sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), 'test_models/')))
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), 'test_models/financial-roberta')))
st.set_page_config(
page_title="Financial advisor",
page_icon="๐Ÿ’ณ๐Ÿ’ฐ",
layout="wide",
)
# st.session_state.summarized = False
with st.sidebar:
"# How to use๐Ÿ”"
"""
โœจThis is a holiday version of the web-UI with the magic ๐ŸŒ, allowing you to unwrap
label predictions for a company based on its financial report text! ๐Ÿ“Šโœจ The prediction
enchantment is performed using the sophisticated embedding classifier approach. ๐Ÿš€๐Ÿ”ฎ
"""
center_style = "<h3 style='text-align: center; color: black;'>{} </h3>"
st.markdown(center_style.format('Load the financial report'), unsafe_allow_html=True)
upload_types = ['Text input', 'File upload']
upload_captions = ['Paste the text', 'Upload a text file']
upload_type = st.radio('Select how to upload the financial report', upload_types,
captions=upload_captions)
match upload_type:
case 'Text input':
financial_report_text = st.text_area('Something', label_visibility='collapsed',
placeholder='Financial report as TEXT')
case 'File upload':
uploaded_files = st.file_uploader("Choose a a text file", type=['.txt', '.docx'],
label_visibility='collapsed', accept_multiple_files=True)
if not bool(uploaded_files):
st.stop()
financial_report_text = ''
for uploaded_file in uploaded_files:
if uploaded_file.name.endswith("docx"):
document = Document(uploaded_file)
document.save('./utils/texts/' + uploaded_file.name)
document = Document(uploaded_file.name)
financial_report_text += "".join([paragraph.text for paragraph in document.paragraphs]) + '\n'
else:
financial_report_text += "".join([line.decode() for line in uploaded_file]) + '\n'
# with open('./utils/texts/financial_report_text.txt', 'w') as file:
# file.write(financial_report_text)
if st.button('Get label'):
with st.spinner("Thinking..."):
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=3200, chunk_overlap=200,
length_function = len, separators=[" ", ",", "\n"]
)
# st.write(f'Financial report char len: {len(financial_report_text)}')
documents = text_splitter.create_documents([financial_report_text])
# st.write(f'Num chunks: {len(documents)}')
texts = [document.page_content for document in documents]
# st.write(f'Each chunk char length: {[len(text) for text in texts]}')
# predicted_label = get_label_prediction(texts)
from test_models.create_setfit_model import model
with torch.no_grad():
model.model_head.eval()
predicted_labels = model(texts)
# st.write(predicted_labels)
predicted_labels_counter = Counter(predicted_labels)
predicted_label = predicted_labels_counter.most_common(1)[0][0]
font_style = 'The predicted label is<span style="font-size: 32px"> **{}**</span>.'
st.markdown(font_style.format(predicted_label), unsafe_allow_html=True)