eogreen's picture
Upload app.py
dbe410e verified
raw
history blame
5.47 kB
# Import the necessary Libraries
from warnings import filterwarnings
filterwarnings('ignore')
import os
import uuid
import json
import gradio as gr
import pandas as pd
from huggingface_hub import CommitScheduler
from pathlib import Path
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.vectorstores import Chroma
from langchain.llms import OpenAI
# Create the LLM client.
# The credential and endpoint are applied with setdefault so that values
# supplied by the deployment environment (for example a configured secret)
# take precedence over the in-file fallbacks instead of being overwritten.
# NOTE(review): an API key is committed in source here -- move it to an
# environment secret, remove the fallback below and rotate the key.
os.environ.setdefault(
    'OPENAI_API_KEY',
    "gl-U2FsdGVkX1+0bNWD6YsVLZUYsn0m1WfLxUzrP0xUFbtWFAfk9Z1Cz+mD8u1yqKtV",
)
os.environ.setdefault(
    "OPENAI_BASE_URL",
    "https://aibe.mygreatlearning.com/openai/v1",
)
# NOTE(review): `OpenAI` is imported from langchain.llms at the top of this
# file, yet llm_query calls llm_client.chat.completions.create(...), which
# looks like the interface of the official `openai` client -- confirm which
# class is intended here.
llm_client = OpenAI()
# Sentence-transformer model used to embed incoming queries.
# NOTE(review): presumably this must be the same model that was used when the
# persisted collection was built -- confirm.
embedding_model = SentenceTransformerEmbeddings(
    model_name='thenlper/gte-large',
)

# Open the '10k_reports' Chroma collection persisted under '10k_reports_db'.
vectorstore_persisted = Chroma(
    embedding_function=embedding_model,
    persist_directory='10k_reports_db',
    collection_name='10k_reports',
)
#
##
#
# Logging set-up: every run of the app appends to its own uniquely named file
# inside the logs folder, and the scheduler periodically commits that folder
# to the dataset repository under the "data" path.
log_folder = Path("logs/")
log_file = log_folder / f"data_{uuid.uuid4()}.json"

scheduler = CommitScheduler(
    folder_path=log_folder,
    path_in_repo="data",
    repo_id="eric-green-rag-financial-analyst",
    repo_type="dataset",
    every=2,  # commit interval; see the CommitScheduler docs for the unit
)
# System prompt for the Q&A model. It fixes the assistant's role, explains the
# ###Context / ###Source / ###Question tokens that appear in the user message
# built by llm_query, states the refusal wording for off-topic or unanswerable
# questions, and prescribes the "Answer: ... Source: ..." response layout.
# The text is sent to the model verbatim, so edits here change model behavior.
qna_system_message = """
You are an assistant to a tech industry financial analyst. Your task is to provide relevant information about a set of companies AWS, Google, IBM, Meta, Microsoft.
User input will include the necessary context for you to answer their questions. This context will begin with the token: ###Context.
The context contains references to specific portions of documents relevant to the user's query, along with source links.
The source for a context will begin with the token ###Source.
When crafting your response:
1. Select only context relevant to answer the question.
2. Include the source links in your response.
3. User questions will begin with the token: ###Question.
4. If the question is irrelevant to financial report information for the 5 companies, respond with "I am unable to locate relevent information. I answer questions related to the financial performance of AWS, Google, IBM, Meta and Microsoft."
Please adhere to the following guidelines:
- Your response should only be about the question asked and nothing else.
- Answer only using the context provided.
- Do not mention anything about the context in your final answer.
- If the answer is not found in the context, it is very very important for you to respond with "I am unable to locate a relevent answer."
- Always quote the source when you use the context. Cite the relevant source at the end of your response under the section - Source:
- Do not make up sources. Use the links provided in the sources section of the context and nothing else. You are prohibited from providing other links/sources.
Here is an example of how to structure your response:
Answer:
[Answer]
Source:
[Source]
"""
# User-message template. llm_query fills {context} with the joined retrieved
# chunks (each followed by its ###Source marker) and {question} with the
# user's input via str.format; the ###Context and ###Question tokens match the
# ones described in the system message.
qna_user_message_template = """
###Context
{context}
###Question
{question}
"""
# Chat model requested from the OpenAI-compatible endpoint.
# NOTE(review): the original code referenced `model_name` without defining it
# anywhere in this file, so every request raised a NameError that the broad
# `except` below turned into an error reply. The default here is an
# assumption -- confirm which model the endpoint serves, or set the
# OPENAI_MODEL_NAME environment variable.
model_name = os.environ.get("OPENAI_MODEL_NAME", "gpt-3.5-turbo")


# Define the llm_query function that runs when 'Submit' is clicked or when an
# API request is made
def llm_query(user_input: str, company: str) -> str:
    """Answer a question about one company's 2023 10-K report.

    Retrieves the five chunks most similar to ``user_input`` from the
    persisted vector store, restricted to the selected company's report,
    sends them together with the question to the LLM, appends the exchange to
    the local log file and returns the model's reply.

    Args:
        user_input: The question typed by the user.
        company: Company key used to build the source filter
            ``dataset/<company>-10-k-2023.pdf``.

    Returns:
        The model's answer with surrounding whitespace stripped, a
        "Sorry, I encountered the following error" message if the LLM call
        raised, or a short prompt to pick a company when none was selected.
    """
    if not company:
        # No radio option selected: concatenating None into the source path
        # below would raise a TypeError, so answer with a hint instead.
        return "Please select a company before submitting your question."

    # Restrict retrieval to the chunks that came from this company's report.
    source_path = "dataset/" + company + "-10-k-2023.pdf"
    relevant_document_chunks = vectorstore_persisted.similarity_search(
        user_input,
        k=5,
        filter={"source": source_path},
    )

    # 1 - Create context_for_query: each chunk's text followed by a ###Source
    # marker carrying the chunk's 'page' metadata value.
    context_list = [
        chunk.page_content + "\n ###Source: " + str(chunk.metadata['page']) + "\n\n "
        for chunk in relevant_document_chunks
    ]
    context_for_query = ". ".join(context_list)

    # 2 - Create messages
    prompt = [
        {'role': 'system', 'content': qna_system_message},
        {
            'role': 'user',
            'content': qna_user_message_template.format(
                context=context_for_query,
                question=user_input,
            ),
        },
    ]

    # Get response from the LLM. Any failure is reported back to the user as
    # text rather than raised, so the UI always receives a string.
    try:
        response = llm_client.chat.completions.create(
            model=model_name,
            messages=prompt,
            temperature=0,
        )
        prediction = response.choices[0].message.content.strip()
    except Exception as e:
        prediction = f'Sorry, I encountered the following error: \n {e}'

    print(prediction)

    # Log both the inputs and outputs to the local log file as one JSON object
    # per line. The commit scheduler's lock is held while writing so a commit
    # does not read the file in the middle of an append.
    with scheduler.lock:
        with log_file.open("a") as f:
            f.write(json.dumps(
                {
                    'user_input': user_input,
                    'retrieved_context': context_for_query,
                    'model_response': prediction,
                }
            ))
            f.write("\n")

    return prediction
# Set up the Gradio UI.
# The component keyword is lowercase `label`; the original passed `Label=`,
# which is not a recognised parameter (ignored or rejected depending on the
# Gradio version), so the inputs were never captioned as intended.
# Radio button for company selection; the choice values are interpolated into
# the source-file path inside llm_query.
company = gr.Radio(
    label='Company:',
    choices=["aws", "google", "ibm", "meta", "microsoft"],
)
# Textbox for the user's question.
textbox = gr.Textbox(label='Question:')

# Create the Gradio interface: inputs are passed to llm_query in the order
# [textbox, company], and its returned string is shown as text output.
demo = gr.Interface(fn=llm_query, inputs=[textbox, company], outputs="text")
demo.queue()
demo.launch()