def read_and_split_file(filename, chunk_size=1200, chunk_overlap=200): with open(filename, 'r') as f: text = f.read() text_splitter = RecursiveCharacterTextSplitter( chunk_size=chunk_size, chunk_overlap=chunk_overlap, length_function = len, separators=[" ", ",", "\n"] ) # st.write(f'Financial report char len: {len(text)}') texts = text_splitter.create_documents([text]) return texts def get_label_prediction(selected_predictor, texts): predicted_labels = [] replies = [] emdedding_model_name = predictors[selected_predictor]['embedding_model'] emdedding_model = SentenceTransformer(emdedding_model_name) texts_str = [text.page_content for text in texts] embeddings = emdedding_model.encode(texts_str, show_progress_bar=True).tolist() # dataset = load_dataset(predictors[selected_predictor]['dataset_name']) label_encoder = LabelEncoder() encoded_labels = label_encoder.fit_transform([label.upper() for label in labels]) input_size = predictors[selected_predictor]['embedding_dim'] hidden_size = 256 output_size = len(label_encoder.classes_) dropout_rate = 0.5 batch_size = 8 model = MLP(input_size, hidden_size, output_size, dropout_rate) load_model(model, predictors[selected_predictor]['mlp_model']) embeddings_tensor = torch.tensor(embeddings) data = TensorDataset(embeddings_tensor) dataloader = DataLoader(data, batch_size=batch_size, shuffle=True) with torch.no_grad(): model.eval() for inputs in dataloader: # st.write(inputs[0]) outputs = model(inputs[0]) # _, predicted = torch.max(outputs, 1) probabilities = F.softmax(outputs, dim=1) predicted_indices = torch.argmax(probabilities, dim=1).tolist() predicted_labels_list = label_encoder.inverse_transform(predicted_indices) for pred_label in predicted_labels_list: predicted_labels.append(pred_label) # st.write(pred_label) predicted_labels_counter = Counter(predicted_labels) predicted_label = predicted_labels_counter.most_common(1)[0][0] return predicted_label if __name__ == '__main__': # Comments and ideas to implement: # 1. Try sending list of inputs to the Inference API. from config import ( labels, headers_inference_api, headers_inference_endpoint, # summarization_prompt_template, prompt_template, # task_explain_for_predictor_model, summarizers, predictors, summary_scores_template, summarization_system_msg, summarization_user_prompt, prediction_user_prompt, prediction_system_msg, # prediction_prompt, chat_prompt, instruction_prompt ) import streamlit as st from sys import exit from pprint import pprint from collections import Counter from itertools import zip_longest from random import choice import requests from re import sub from rouge import Rouge from time import sleep, perf_counter import os from textwrap import wrap from multiprocessing import Pool, freeze_support from tqdm import tqdm from stqdm import stqdm from langchain.document_loaders import TextLoader from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.schema.document import Document # from langchain.schema import Document from langchain.chat_models import ChatOpenAI from langchain.llms import OpenAI from langchain.schema import AIMessage, HumanMessage, SystemMessage from langchain.prompts import PromptTemplate from datasets import Dataset, load_dataset from sklearn.preprocessing import LabelEncoder from test_models.train_classificator import MLP from safetensors.torch import load_model, save_model from sentence_transformers import SentenceTransformer from torch.utils.data import DataLoader, TensorDataset import torch.nn.functional as F import torch import torch.nn as nn import sys sys.path.append(os.path.abspath(os.path.join(os.getcwd(), 'test_models/'))) sys.path.append(os.path.abspath(os.path.join(os.getcwd(), 'test_models/financial-roberta'))) st.set_page_config( page_title="Financial advisor", page_icon="๐ณ๐ฐ", layout="wide", ) # st.session_state.summarized = False with st.sidebar: "# How to use๐" """ โจThis is a holiday version of the web-UI with the magic ๐, allowing you to unwrap label predictions for a company based on its financial report text! ๐โจ The prediction enchantment is performed using the sophisticated embedding classifier approach. ๐๐ฎ """ center_style = "