from langchain.text_splitter import RecursiveCharacterTextSplitter


def read_and_split_file(filename, chunk_size=1200, chunk_overlap=200):
    with open(filename, 'r') as f:
        text = f.read()

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap,
        length_function=len, separators=[" ", ",", "\n"]
    )

    # st.write(f'Financial report char len: {len(text)}')
    texts = text_splitter.create_documents([text])
    return texts
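
# Example usage (hypothetical; assumes a plain-text report saved at this path):
#   chunks = read_and_split_file('./utils/texts/financial_report_text.txt')
#   print(len(chunks), 'chunks;', len(chunks[0].page_content), 'chars in the first chunk')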


if __name__ == '__main__':
    # Comments and ideas to implement:
    # 1. Try sending list of inputs to the Inference API.


    import streamlit as st
    from sys import exit
    from pprint import pprint
    from collections import Counter
    from itertools import zip_longest
    from random import choice
    import requests
    from re import sub
    from rouge import Rouge
    from time import sleep, perf_counter
    import os
    from textwrap import wrap
    from multiprocessing import Pool, freeze_support
    from tqdm import tqdm
    from stqdm import stqdm
    from langchain.document_loaders import TextLoader
    from langchain.text_splitter import RecursiveCharacterTextSplitter
    from langchain.schema.document import Document
    # from langchain.schema import Document
    import docx  # python-docx, needed to read uploaded .docx reports
    from langchain.chat_models import ChatOpenAI
    from langchain.llms import OpenAI
    from langchain.schema import AIMessage, HumanMessage, SystemMessage
    from langchain.prompts import PromptTemplate
    from datasets import Dataset, load_dataset
    from sklearn.preprocessing import LabelEncoder
    from test_models.train_classificator import MLP
    from safetensors.torch import load_model, save_model
    from sentence_transformers import SentenceTransformer
    from torch.utils.data import DataLoader, TensorDataset
    import torch.nn.functional as F
    import torch
    import torch.nn as nn
    import sys
    
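    # Make the local test_models/ directory (and its financial-roberta subfolder) importable.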
    sys.path.append(os.path.abspath(os.path.join(os.getcwd(), 'test_models/')))
    sys.path.append(os.path.abspath(os.path.join(os.getcwd(), 'test_models/financial-roberta')))

    st.set_page_config(
        page_title="Financial advisor",
        page_icon="💳💰",
        layout="wide",
    )
    # st.session_state.summarized = False

    with st.sidebar:
        "# How to use 🔍"

        """
        ✨ This is a holiday edition of the web UI 🌐: paste or upload a company's financial
        report and it will predict a label for the company from the report text! 📊✨ The
        prediction is made with an embedding-based classifier. 🚀🔮
        """


    center_style = "<h3 style='text-align: center; color: black;'>{} </h3>"
    st.markdown(center_style.format('Load the financial report'), unsafe_allow_html=True)


    upload_types = ['Text input', 'File upload']
    upload_captions = ['Paste the text', 'Upload a text file']
    upload_type = st.radio('Select how to upload the financial report', upload_types,
                        captions=upload_captions)


    match upload_type:
        case 'Text input':
            financial_report_text = st.text_area('Financial report text', label_visibility='collapsed',
                                                 placeholder='Financial report as TEXT')

        case 'File upload':
            uploaded_files = st.file_uploader("Choose a text file", type=['.txt', '.docx'],
                                              label_visibility='collapsed', accept_multiple_files=True)

            if not uploaded_files:
                st.stop()

            financial_report_text = ''
            for uploaded_file in uploaded_files:
                if uploaded_file.name.endswith("docx"):
                    # python-docx reads directly from the uploaded file-like object.
                    document = docx.Document(uploaded_file)
                    financial_report_text += "\n".join(paragraph.text for paragraph in document.paragraphs) + '\n'
                else:
                    financial_report_text += "".join(line.decode() for line in uploaded_file) + '\n'

    # with open('./utils/texts/financial_report_text.txt', 'w') as file:
    #     file.write(financial_report_text)

    if st.button('Get label'):
        with st.spinner("Thinking..."):
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=3200, chunk_overlap=200,
                length_function=len, separators=[" ", ",", "\n"]
            )

            # st.write(f'Financial report char len: {len(financial_report_text)}')
            documents = text_splitter.create_documents([financial_report_text])
            # st.write(f'Num chunks: {len(documents)}')
            texts = [document.page_content for document in documents]
            # st.write(f'Each chunk char length: {[len(text) for text in texts]}')

            # predicted_label = get_label_prediction(texts)
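            # Lazy import: the SetFit model is only loaded once a prediction is requested.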
            from test_models.create_setfit_model import model
            
            with torch.no_grad():
                model.model_head.eval()
                predicted_labels = model(texts)
                # st.write(predicted_labels)
            
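            # Aggregate the per-chunk predictions by majority vote.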
            predicted_labels_counter = Counter(predicted_labels)
            predicted_label = predicted_labels_counter.most_common(1)[0][0]
            
            font_style = 'The predicted label is<span style="font-size: 32px"> **{}**</span>.'
            st.markdown(font_style.format(predicted_label), unsafe_allow_html=True)
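
    # To run the app locally (assuming Streamlit is installed):
    #   streamlit run <path to this script>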