eogreen committed
Commit 51d6cf1 · verified · 1 Parent(s): 0a989f9

Upload 2 files

Files changed (2)
  1. app.py +155 -0
  2. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,155 @@
+ # Import the necessary libraries
+ from warnings import filterwarnings
+ filterwarnings('ignore')
+ import os
+ import uuid
+ import json
+ import gradio as gr
+ from openai import OpenAI
+ from huggingface_hub import CommitScheduler
+ from pathlib import Path
+ from langchain.embeddings import SentenceTransformerEmbeddings
+ from langchain.vectorstores import Chroma
+
+ # Create the client (the OpenAI SDK reads OPENAI_API_KEY and OPENAI_BASE_URL from the environment)
+ os.environ['OPENAI_API_KEY'] = "gl-U2FsdGVkX1+0bNWD6YsVLZUYsn0m1WfLxUzrP0xUFbtWFAfk9Z1Cz+mD8u1yqKtV"
+ os.environ["OPENAI_BASE_URL"] = "https://aibe.mygreatlearning.com/openai/v1"
+
+ llm_client = OpenAI()
+ model_name = 'gpt-3.5-turbo'  # assumed model id; adjust to whatever the proxy serves
+
+ # Define the embedding model and the vector store
+ embedding_model = SentenceTransformerEmbeddings(model_name='thenlper/gte-large')
+
+ # Chroma loads the persisted collection from persist_directory automatically;
+ # no explicit load call is needed.
+ vectorstore_persisted = Chroma(
+     collection_name='10k_reports',
+     persist_directory='10k_reports_db',
+     embedding_function=embedding_model
+ )
+
+ # Prepare the logging functionality
+ log_file = Path("logs/") / f"data_{uuid.uuid4()}.json"
+ log_folder = log_file.parent
+
+ scheduler = CommitScheduler(
+     repo_id="eric-green-rag-financial-analyst",
+     repo_type="dataset",
+     folder_path=log_folder,
+     path_in_repo="data",
+     every=2
+ )
+
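+ # Note: CommitScheduler runs in a background thread and pushes the contents
+ # of log_folder to the dataset repo roughly every 2 minutes (the `every`
+ # parameter is in minutes), so logs accumulate on the Hub without blocking
+ # the app.
+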
+ # Define the Q&A system message for the LLM
+ qna_system_message = """
+ You are an assistant to a tech industry financial analyst. Your task is to provide relevant information about a set of companies: AWS, Google, IBM, Meta, and Microsoft.
+
+ User input will include the necessary context for you to answer their questions. This context will begin with the token: ###Context.
+ The context contains references to specific portions of documents relevant to the user's query, along with source links.
+ The source for a context will begin with the token ###Source.
+
+ When crafting your response:
+ 1. Select only context relevant to answer the question.
+ 2. Include the source links in your response.
+ 3. User questions will begin with the token: ###Question.
+ 4. If the question is irrelevant to financial report information for the 5 companies, respond with "I am unable to locate relevant information. I answer questions related to the financial performance of AWS, Google, IBM, Meta and Microsoft."
+
+ Please adhere to the following guidelines:
+ - Your response should only be about the question asked and nothing else.
+ - Answer only using the context provided.
+ - Do not mention anything about the context in your final answer.
+ - If the answer is not found in the context, it is very important that you respond with "I am unable to locate a relevant answer."
+ - Always quote the source when you use the context. Cite the relevant source at the end of your response under the section - Source:
+ - Do not make up sources. Use the links provided in the sources section of the context and nothing else. You are prohibited from providing other links/sources.
+
+ Here is an example of how to structure your response:
+
+ Answer:
+ [Answer]
+
+ Source:
+ [Source]
+ """
+
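+ # (The ###Source entries promised above are supplied at query time: llm_query
+ # appends each retrieved chunk's page number from its document metadata.)
+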
+ # Define the user message template
+ qna_user_message_template = """
+ ###Context
+ {context}
+
+ ###Question
+ {question}
+ """
+
+ # Define the llm_query function that runs when 'Submit' is clicked or when an API request is made
+ def llm_query(user_input, company):
+
+     # Restrict retrieval to the selected company's 10-K filing
+     filter_path = "dataset/" + company + "-10-k-2023.pdf"
+
+     relevant_document_chunks = vectorstore_persisted.similarity_search(
+         user_input, k=5, filter={"source": filter_path}
+     )
+
+     # 1 - Create context_for_query from the retrieved chunks and their page numbers
+     context_list = [
+         d.page_content + "\n ###Source: " + str(d.metadata['page']) + "\n\n "
+         for d in relevant_document_chunks
+     ]
+     context_for_query = ". ".join(context_list)
+
+     # 2 - Create messages
+     prompt = [
+         {'role': 'system', 'content': qna_system_message},
+         {'role': 'user', 'content': qna_user_message_template.format(
+             context=context_for_query,
+             question=user_input
+         )}
+     ]
+
+     # Get response from the LLM
+     try:
+         response = llm_client.chat.completions.create(
+             model=model_name,
+             messages=prompt,
+             temperature=0
+         )
+         prediction = response.choices[0].message.content.strip()
+     except Exception as e:
+         prediction = f'Sorry, I encountered the following error: \n {e}'
+
+     print(prediction)
+
+     # Log both the inputs and outputs to a local log file. Hold the commit
+     # scheduler's lock while writing so a scheduled push never sees a
+     # half-written file.
+     with scheduler.lock:
+         with log_file.open("a") as f:
+             f.write(json.dumps(
+                 {
+                     'user_input': user_input,
+                     'retrieved_context': context_for_query,
+                     'model_response': prediction
+                 }
+             ))
+             f.write("\n")
+
+     return prediction
+
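+ # Example invocation (hypothetical question; assumes the persisted vector DB
+ # under 10k_reports_db was built from the five 2023 10-K PDFs in dataset/):
+ # llm_query("What was Meta's total revenue in 2023?", "meta")
+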
+ # Set up the Gradio UI
+ company = gr.Radio(label='Company:', choices=["aws", "google", "ibm", "meta", "microsoft"])  # radio buttons for company selection
+ textbox = gr.Textbox(label='Question:')  # textbox for the user's question
+
+ # Create the Gradio interface: pass [textbox, company] as inputs and return the prediction as text
+ demo = gr.Interface(fn=llm_query, inputs=[textbox, company], outputs="text")
+
+ demo.queue()
+ demo.launch()
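+
+ # demo.queue() enables request queuing so concurrent users are handled in
+ # order, and demo.launch() starts the Gradio server.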
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ openai==1.23.2
+ tiktoken==0.6.0
+ pypdf==4.0.1
+ langchain==0.1.1
+ langchain-community==0.0.13
+ chromadb==0.4.22
+ sentence-transformers==2.3.1
+ gradio==3.23.0
+ huggingface_hub
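To run the app locally, install the pinned dependencies first (for example, `pip install -r requirements.txt`) and then launch it with `python app.py`; on a Hugging Face Space, the app starts automatically from app.py.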