DrishtiSharma commited on
Commit
091b7b6
·
verified ·
1 Parent(s): e9c25b2

Update bad_app.py

Browse files
Files changed (1) hide show
  1. bad_app.py +237 -0
bad_app.py CHANGED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import re
4
+ import shutil
5
+ import time
6
+ import streamlit as st
7
+ import nltk
8
+ import tempfile
9
+ import subprocess
10
+ import base64 # For embedding PDF content
11
+
12
+ # Pin NLTK to version 3.9.1
13
+ REQUIRED_NLTK_VERSION = "3.9.1"
14
+ subprocess.run([sys.executable, "-m", "pip", "install", f"nltk=={REQUIRED_NLTK_VERSION}"])
15
+
16
+ # Set up temporary directory for NLTK resources
17
+ nltk_data_path = os.path.join(tempfile.gettempdir(), "nltk_data")
18
+ os.makedirs(nltk_data_path, exist_ok=True)
19
+ nltk.data.path.append(nltk_data_path)
20
+
21
+ # Download 'punkt_tab' for compatibility
22
+ try:
23
+ print("Ensuring NLTK 'punkt_tab' resource is downloaded...")
24
+ nltk.download("punkt_tab", download_dir=nltk_data_path)
25
+ except Exception as e:
26
+ print(f"Error downloading NLTK 'punkt_tab': {e}")
27
+ raise e
28
+
29
+ sys.path.append(os.path.abspath("."))
30
+ from langchain.chains import ConversationalRetrievalChain
31
+ from langchain.memory import ConversationBufferMemory
32
+ from langchain.llms import OpenAI
33
+ from langchain.document_loaders import UnstructuredPDFLoader
34
+ from langchain.vectorstores import Chroma
35
+ from langchain.embeddings import HuggingFaceEmbeddings
36
+ from langchain.text_splitter import NLTKTextSplitter
37
+ from patent_downloader import PatentDownloader
38
+
39
+ PERSISTED_DIRECTORY = tempfile.mkdtemp()
40
+
41
+ # Fetch API key securely from the environment
42
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
43
+ if not OPENAI_API_KEY:
44
+ st.error("Critical Error: OpenAI API key not found in the environment variables. Please configure it.")
45
+ st.stop()
46
+
47
+ def check_poppler_installed():
48
+ if not shutil.which("pdfinfo"):
49
+ raise EnvironmentError(
50
+ "Poppler is not installed or not in PATH. Install 'poppler-utils' for PDF processing."
51
+ )
52
+
53
+ check_poppler_installed()
54
+
55
+ def load_docs(document_path):
56
+ try:
57
+ loader = UnstructuredPDFLoader(
58
+ document_path,
59
+ mode="elements",
60
+ strategy="fast",
61
+ ocr_languages=None
62
+ )
63
+ documents = loader.load()
64
+ text_splitter = NLTKTextSplitter(chunk_size=1000)
65
+ split_docs = text_splitter.split_documents(documents)
66
+
67
+ # Filter metadata to only include str, int, float, or bool
68
+ for doc in split_docs:
69
+ if hasattr(doc, "metadata") and isinstance(doc.metadata, dict):
70
+ doc.metadata = {
71
+ k: v for k, v in doc.metadata.items()
72
+ if isinstance(v, (str, int, float, bool))
73
+ }
74
+ return split_docs
75
+ except Exception as e:
76
+ st.error(f"Failed to load and process PDF: {e}")
77
+ st.stop()
78
+
79
+ def already_indexed(vectordb, file_name):
80
+ indexed_sources = set(
81
+ x["source"] for x in vectordb.get(include=["metadatas"])["metadatas"]
82
+ )
83
+ return file_name in indexed_sources
84
+
85
+ def load_chain(file_name=None):
86
+ loaded_patent = st.session_state.get("LOADED_PATENT")
87
+
88
+ vectordb = Chroma(
89
+ persist_directory=PERSISTED_DIRECTORY,
90
+ embedding_function=HuggingFaceEmbeddings(),
91
+ )
92
+ if loaded_patent == file_name or already_indexed(vectordb, file_name):
93
+ st.write("✅ Already indexed.")
94
+ else:
95
+ vectordb.delete_collection()
96
+ docs = load_docs(file_name)
97
+ st.write("🔍 Number of Documents: ", len(docs))
98
+
99
+ vectordb = Chroma.from_documents(
100
+ docs, HuggingFaceEmbeddings(), persist_directory=PERSISTED_DIRECTORY
101
+ )
102
+ vectordb.persist()
103
+ st.session_state["LOADED_PATENT"] = file_name
104
+
105
+ memory = ConversationBufferMemory(
106
+ memory_key="chat_history",
107
+ return_messages=True,
108
+ input_key="question",
109
+ output_key="answer",
110
+ )
111
+ return ConversationalRetrievalChain.from_llm(
112
+ OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY),
113
+ vectordb.as_retriever(search_kwargs={"k": 3}),
114
+ return_source_documents=False,
115
+ memory=memory,
116
+ )
117
+
118
+ def extract_patent_number(url):
119
+ pattern = r"/patent/([A-Z]{2}\d+)"
120
+ match = re.search(pattern, url)
121
+ return match.group(1) if match else None
122
+
123
+ def download_pdf(patent_number):
124
+ try:
125
+ patent_downloader = PatentDownloader(verbose=True)
126
+ output_path = patent_downloader.download(patents=patent_number, output_path=tempfile.gettempdir())
127
+ return output_path[0]
128
+ except Exception as e:
129
+ st.error(f"Failed to download patent PDF: {e}")
130
+ st.stop()
131
+
132
+ def embed_pdf(file_path):
133
+ """Convert PDF file to base64 and embed it in an iframe."""
134
+ with open(file_path, "rb") as f:
135
+ base64_pdf = base64.b64encode(f.read()).decode("utf-8")
136
+ pdf_display = f"""
137
+ <iframe src="data:application/pdf;base64,{base64_pdf}" width="700" height="1000" style="border: none;"></iframe>
138
+ """
139
+ return pdf_display
140
+
141
+ if __name__ == "__main__":
142
+ st.set_page_config(
143
+ page_title="Patent Chat: Google Patents Chat Demo",
144
+ page_icon="📖",
145
+ layout="wide",
146
+ initial_sidebar_state="expanded",
147
+ )
148
+ st.header("���� Patent Chat: Google Patents Chat Demo")
149
+
150
+ # Fetch query parameters safely
151
+ query_params = st.query_params
152
+ default_patent_link = query_params.get("patent_link", "https://patents.google.com/patent/US8676427B1/en")
153
+
154
+ # Input for Google Patent Link
155
+ patent_link = st.text_area("Enter Google Patent Link:", value=default_patent_link, height=100)
156
+
157
+ # Button to start processing
158
+ if st.button("Load and Process Patent"):
159
+ if not patent_link:
160
+ st.warning("Please enter a Google patent link to proceed.")
161
+ st.stop()
162
+
163
+ patent_number = extract_patent_number(patent_link)
164
+ if not patent_number:
165
+ st.error("Invalid patent link format. Please provide a valid Google patent link.")
166
+ st.stop()
167
+
168
+ st.write(f"Patent number: **{patent_number}**")
169
+
170
+ # Define PDF path in temp directory
171
+ pdf_path = os.path.join(tempfile.gettempdir(), f"{patent_number}.pdf")
172
+ if os.path.isfile(pdf_path):
173
+ st.write("✅ File already downloaded.")
174
+ else:
175
+ st.write("📥 Downloading patent file...")
176
+ pdf_path = download_pdf(patent_number)
177
+ st.write(f"✅ File downloaded: {pdf_path}")
178
+
179
+ # Display a preview of the downloaded PDF
180
+ st.write("📄 Preview of the downloaded patent PDF:")
181
+ if os.path.isfile(pdf_path):
182
+ with open(pdf_path, "rb") as pdf_file:
183
+ st.download_button(
184
+ label="Download PDF",
185
+ data=pdf_file,
186
+ file_name=f"{patent_number}.pdf",
187
+ mime="application/pdf"
188
+ )
189
+ # Embed PDF content using base64
190
+ st.write("📋 PDF Content:")
191
+ pdf_view = embed_pdf(pdf_path)
192
+ st.components.v1.html(pdf_view, height=1000)
193
+
194
+ st.write("🔄 Loading document into the system...")
195
+
196
+ # Persist the chain in session state to prevent reloading
197
+ if "chain" not in st.session_state or st.session_state.get("loaded_file") != pdf_path:
198
+ st.session_state.chain = load_chain(pdf_path)
199
+ st.session_state.loaded_file = pdf_path
200
+ st.session_state.messages = [{"role": "assistant", "content": "Hello! How can I assist you with this patent?"}]
201
+
202
+ st.success("🚀 Document successfully loaded! You can now start asking questions.")
203
+
204
+ # Initialize messages if not already done
205
+ if "messages" not in st.session_state:
206
+ st.session_state.messages = [{"role": "assistant", "content": "Hello! How can I assist you with this patent?"}]
207
+
208
+ # Display previous chat messages
209
+ for message in st.session_state.messages:
210
+ with st.chat_message(message["role"]):
211
+ st.markdown(message["content"])
212
+
213
+ # Chat Input Section
214
+ if "chain" in st.session_state:
215
+ if user_input := st.chat_input("What is your question?"):
216
+ # Append user message
217
+ st.session_state.messages.append({"role": "user", "content": user_input})
218
+ with st.chat_message("user"):
219
+ st.markdown(user_input)
220
+
221
+ # Generate assistant response
222
+ with st.chat_message("assistant"):
223
+ message_placeholder = st.empty()
224
+ full_response = ""
225
+
226
+ with st.spinner("Generating response..."):
227
+ try:
228
+ assistant_response = st.session_state.chain({"question": user_input})
229
+ full_response = assistant_response["answer"]
230
+ except Exception as e:
231
+ full_response = f"An error occurred: {e}"
232
+
233
+ # Display assistant response
234
+ message_placeholder.markdown(full_response)
235
+ st.session_state.messages.append({"role": "assistant", "content": full_response})
236
+ else:
237
+ st.info("Press the 'Load and Process Patent' button to start processing.")