SANDRAMSC commited on
Commit
562987e
·
verified ·
1 Parent(s): 6a820a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +151 -213
app.py CHANGED
@@ -1,213 +1,151 @@
1
- import time
2
- import streamlit as st
3
- from PyPDF2 import PdfReader
4
- from langchain.text_splitter import CharacterTextSplitter
5
- from langchain.embeddings import OpenAIEmbeddings
6
- from langchain.vectorstores import FAISS
7
- from langchain.chat_models import ChatOpenAI
8
- from langchain.memory import ConversationBufferMemory
9
- from langchain.chains import ConversationalRetrievalChain
10
- import os
11
- import pickle
12
- from datetime import datetime
13
- from backend.generate_metadata import generate_metadata, ingest
14
-
15
-
16
- css = '''
17
- <style>
18
- .chat-message {
19
- padding: 1.5rem; border-radius: 0.5rem; margin-bottom: 1rem; display: flex
20
- }
21
- .chat-message.user {
22
- background-color: #2b313e
23
- }
24
- .chat-message.bot {
25
- background-color: #475063
26
- }
27
- .chat-message .avatar {
28
- width: 20%;
29
- }
30
- .chat-message .avatar img {
31
- max-width: 78px;
32
- max-height: 78px;
33
- border-radius: 50%;
34
- object-fit: cover;
35
- }
36
- .chat-message .message {
37
- width: 80%;
38
- padding: 0 1.5rem;
39
- color: #fff;
40
- }
41
- '''
42
- bot_template = '''
43
- <div class="chat-message bot">
44
- <div class="avatar">
45
- <img src="https://i.ibb.co/cN0nmSj/Screenshot-2023-05-28-at-02-37-21.png"
46
- style="max-height: 78px; max-width: 78px; border-radius: 50%; object-fit: cover;">
47
- </div>
48
- <div class="message">{{MSG}}</div>
49
- </div>
50
- '''
51
- user_template = '''
52
- <div class="chat-message user">
53
- <div class="avatar">
54
- <img src="https://i.ibb.co/rdZC7LZ/Photo-logo-1.png">
55
- </div>
56
- <div class="message">{{MSG}}</div>
57
- </div>
58
- '''
59
-
60
-
61
- def get_pdf_text(pdf_docs):
62
- text = ""
63
- for pdf in pdf_docs:
64
- pdf_reader = PdfReader(pdf)
65
- for page in pdf_reader.pages:
66
- text += page.extract_text()
67
- return text
68
-
69
-
70
- def get_text_chunks(text):
71
- text_splitter = CharacterTextSplitter(
72
- separator="\n",
73
- chunk_size=1000,
74
- chunk_overlap=200,
75
- length_function=len
76
- )
77
- chunks = text_splitter.split_text(text)
78
- return chunks
79
-
80
-
81
- def get_vectorstore(text_chunks):
82
- embeddings = OpenAIEmbeddings()
83
- # embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-xl")
84
- vectorstore = FAISS.from_texts(texts=text_chunks, embedding=embeddings)
85
- return vectorstore
86
-
87
-
88
- def get_conversation_chain(vectorstore):
89
- llm = ChatOpenAI()
90
- # llm = HuggingFaceHub(repo_id="google/flan-t5-xxl", model_kwargs={"temperature":0.5, "max_length":512})
91
-
92
- memory = ConversationBufferMemory(
93
- memory_key='chat_history', return_messages=True)
94
- conversation_chain = ConversationalRetrievalChain.from_llm(
95
- llm=llm,
96
- retriever=vectorstore.as_retriever(),
97
- memory=memory
98
- )
99
- return conversation_chain
100
-
101
-
102
- def handle_userinput(user_question):
103
- response = st.session_state.conversation({'question': user_question})
104
- st.session_state.chat_history = response['chat_history']
105
-
106
- for i, message in enumerate(st.session_state.chat_history):
107
- # Display user message
108
- if i % 2 == 0:
109
- st.write(user_template.replace("{{MSG}}", message.content), unsafe_allow_html=True)
110
- else:
111
- print(message)
112
- # Display AI response
113
- st.write(bot_template.replace("{{MSG}}", message.content), unsafe_allow_html=True)
114
-
115
-
116
- def safe_vec_store():
117
- # USE VECTARA INSTEAD
118
- os.makedirs('vectorstore', exist_ok=True)
119
- filename = 'vectors' + datetime.now().strftime('%Y%m%d%H%M') + '.pkl'
120
- file_path = os.path.join('vectorstore', filename)
121
- vector_store = st.session_state.vectorstore
122
-
123
- # Serialize and save the entire FAISS object using pickle
124
- with open(file_path, 'wb') as f:
125
- pickle.dump(vector_store, f)
126
-
127
-
128
- def main():
129
- st.set_page_config(page_title="Doc Verify RAG", page_icon=":mag:")
130
- st.write(css, unsafe_allow_html=True)
131
- st.header("Doc Verify RAG :mag:")
132
-
133
- if "openai_api_key" not in st.session_state:
134
- st.session_state.openai_api_key = False
135
- if "openai_org" not in st.session_state:
136
- st.session_state.openai_org = False
137
- if "classify" not in st.session_state:
138
- st.session_state.classify = False
139
-
140
- def set_pw():
141
- st.session_state.openai_api_key = True
142
-
143
- st.subheader("Your documents")
144
- OPENAI_API_KEY = st.text_input("OPENAI API KEY:", type="password",
145
- disabled=st.session_state.openai_api_key, on_change=set_pw)
146
- if st.session_state.classify:
147
- pdf_doc = st.file_uploader("Upload your PDFs here and click on 'Process'", accept_multiple_files=False)
148
- else:
149
- pdf_docs = st.file_uploader("Upload your PDFs here and click on 'Process'", accept_multiple_files=True)
150
- filenames = [file.name for file in pdf_docs if file is not None]
151
- if st.button("Process"):
152
- with st.spinner("Processing"):
153
- if st.session_state.classify:
154
- # THE CLASSIFICATION APP
155
- st.write("Classifying")
156
- plain_text_doc = ingest(pdf_doc.name)
157
- classification_result = generate_metadata(plain_text_doc)
158
- st.write(classification_result)
159
- else:
160
- # NORMAL RAG
161
- loaded_vec_store = None
162
- for filename in filenames:
163
- if ".pkl" in filename:
164
- file_path = os.path.join('vectorstore', filename)
165
- with open(file_path, 'rb') as f:
166
- loaded_vec_store = pickle.load(f)
167
- raw_text = get_pdf_text(pdf_docs)
168
- text_chunks = get_text_chunks(raw_text)
169
- vec = get_vectorstore(text_chunks)
170
- if loaded_vec_store:
171
- vec.merge_from(loaded_vec_store)
172
- st.warning("loaded vectorstore")
173
- if "vectorstore" in st.session_state:
174
- vec.merge_from(st.session_state.vectorstore)
175
- st.warning("merged to existing")
176
- st.session_state.vectorstore = vec
177
- st.session_state.conversation = get_conversation_chain(vec)
178
- st.success("data loaded")
179
-
180
- if "conversation" not in st.session_state:
181
- st.session_state.conversation = None
182
- if "chat_history" not in st.session_state:
183
- st.session_state.chat_history = None
184
-
185
- user_question = st.text_input("Ask a question about your documents:")
186
- if user_question:
187
- handle_userinput(user_question)
188
- with st.sidebar:
189
- st.subheader("Classification instructions")
190
- classifier_docs = st.file_uploader("Upload your instructions here and click on 'Process'",
191
- accept_multiple_files=True)
192
- filenames = [file.name for file in classifier_docs if file is not None]
193
-
194
- if st.button("Process Classification"):
195
- st.session_state.classify = True
196
- with st.spinner("Processing"):
197
- st.warning("set classify")
198
- time.sleep(3)
199
-
200
- if st.button("Save Embeddings"):
201
- if "vectorstore" in st.session_state:
202
- safe_vec_store()
203
- # st.session_state.vectorstore.save_local("faiss_index")
204
- st.sidebar.success("saved")
205
- else:
206
- st.sidebar.warning("No embeddings to save. Please process documents first.")
207
-
208
- if st.button("Load Embeddings"):
209
- st.warning("this function is not in use, just upload the vectorstore")
210
-
211
-
212
- if __name__ == '__main__':
213
- main()
 
1
+ import os
2
+ import io
3
+ import argparse
4
+ import json
5
+ import openai
6
+ import sys
7
+ from dotenv import load_dotenv
8
+ from langchain_community.document_loaders import TextLoader
9
+ from langchain_community.document_loaders import UnstructuredPDFLoader
10
+ from langchain_community.embeddings import HuggingFaceEmbeddings
11
+ from langchain_community.vectorstores import Vectara
12
+ from langchain_core.output_parsers import StrOutputParser
13
+ from langchain_core.prompts import ChatPromptTemplate
14
+ from langchain_core.runnables import RunnablePassthrough
15
+ from langchain.prompts import PromptTemplate
16
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
17
+
18
+
19
+ load_dotenv()
20
+
21
+ MODEL_NAME = "mistralai/Mixtral-8x7B-Instruct-v0.1"
22
+
23
+ vectara_customer_id = os.environ['VECTARA_CUSTOMER_ID']
24
+ vectara_corpus_id = os.environ['VECTARA_CORPUS_ID']
25
+ vectara_api_key = os.environ['VECTARA_API_KEY']
26
+
27
+ embeddings = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-large")
28
+
29
+ vectara = Vectara(vectara_customer_id=vectara_customer_id,
30
+ vectara_corpus_id=vectara_corpus_id,
31
+ vectara_api_key=vectara_api_key)
32
+
33
+
34
+ summary_config = {"is_enabled": True, "max_results": 3, "response_lang": "eng"}
35
+ retriever = vectara.as_retriever(
36
+ search_kwargs={"k": 3, "summary_config": summary_config}
37
+ )
38
+
39
+ template = """
40
+ passage: You are a helpful assistant that understands BIM building documents.
41
+ passage: You will analyze BIM document metadata composed of filename, description, and engineering discipline.
42
+ passage: The metadata is written in German.
43
+ passage: Filename: {filename}, Description: {description}, Engineering discipline: {discipline}.
44
+ query: Does the filename match other filenames within the same discipline?
45
+ query: Does the description match the engineering discipline?
46
+ query: How different is the metadata to your curated information?
47
+ query: Highligh any discrepancies and comment on wether or not the metadata is anomalous.
48
+ """
49
+
50
+ prompt = PromptTemplate(template=template, input_variables=['filename', 'description', 'discipline'])
51
+
52
+
53
+ def get_sources(documents):
54
+ return documents[:-1]
55
+
56
+ def get_summary(documents):
57
+ return documents[-1].page_content
58
+
59
+ def ingest(file_path):
60
+ extension = os.path.splitext(file_path)[1].lower()
61
+
62
+ if extension == '.pdf':
63
+ loader = UnstructuredPDFLoader(file_path)
64
+ elif extension == '.txt':
65
+ loader = TextLoader(file_path)
66
+ else:
67
+ raise NotImplementedError('Only .txt or .pdf files are supported')
68
+
69
+ # transform locally
70
+ documents = loader.load()
71
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0,
72
+ separators=[
73
+ "\n\n",
74
+ "\n",
75
+ " ",
76
+ ",",
77
+ "\uff0c", # Fullwidth comma
78
+ "\u3001", # Ideographic comma
79
+ "\uff0e", # Fullwidth full stop
80
+ # "\u200B", # Zero-width space (Asian languages)
81
+ # "\u3002", # Ideographic full stop (Asian languages)
82
+ "",
83
+ ])
84
+ docs = text_splitter.split_documents(documents)
85
+
86
+ return docs
87
+
88
+
89
+
90
+ def generate_metadata(docs):
91
+ prompt_template = """
92
+ BimDiscipline = ['plumbing', 'network', 'heating', 'electrical', 'ventilation', 'architecture']
93
+
94
+ You are a helpful assistant that understands BIM documents and engineering disciplines. Your answer should be in JSON format and only include the filename, a short description, and the engineering discipline the document belongs to, distinguishing between {[d.value for d in BimDiscipline]} based on the given document."
95
+
96
+ Analyze the provided document, which could be in either German or English. Extract the filename, its description, and infer the engineering discipline it belongs to. Document:
97
+ context="
98
+ """
99
+ # plain text
100
+ filepath = [doc.metadata for doc in docs][0]['source']
101
+ context = "".join(
102
+ [doc.page_content.replace('\n\n','').replace('..','') for doc in docs])
103
+
104
+ prompt = f'{prompt_template}{context}"\nFilepath:{filepath}'
105
+
106
+ #print(prompt)
107
+
108
+ # Create client
109
+ client = openai.OpenAI(
110
+ base_url="https://api.together.xyz/v1",
111
+ api_key=os.environ["TOGETHER_API_KEY"],
112
+ #api_key=userdata.get('TOGETHER_API_KEY'),
113
+ )
114
+
115
+ # Call the LLM with the JSON schema
116
+ chat_completion = client.chat.completions.create(
117
+ model=MODEL_NAME,
118
+ messages=[
119
+ {
120
+ "role": "system",
121
+ "content": f"You are a helpful assistant that responsds in JSON format"
122
+ },
123
+ {
124
+ "role": "user",
125
+ "content": prompt
126
+ }
127
+ ]
128
+ )
129
+
130
+ return json.loads(chat_completion.choices[0].message.content)
131
+
132
+
133
+ def analyze_metadata(filename, description, discipline):
134
+ formatted_prompt = prompt.format(filename=filename, description=description, discipline=discipline)
135
+ return (retriever | get_summary).invoke(formatted_prompt)
136
+
137
+
138
+ if __name__ == "__main__":
139
+ parser = argparse.ArgumentParser(description="Generate metadata for a BIM document")
140
+ parser.add_argument("document", metavar="FILEPATH", type=str,
141
+ help="Path to the BIM document")
142
+
143
+ args = parser.parse_args()
144
+
145
+ if not os.path.exists(args.document) or not os.path.isfile(args.document):
146
+ print("File '{}' not found or not accessible.".format(args.document))
147
+ sys.exit(-1)
148
+
149
+ docs = ingest(args.document)
150
+ metadata = generate_metadata(docs)
151
+ print(metadata)