DrishtiSharma commited on
Commit
3238cf2
·
verified ·
1 Parent(s): 13627f6

Create app_w_patent_preview.py

Browse files
Files changed (1) hide show
  1. app_w_patent_preview.py +235 -0
app_w_patent_preview.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # to-do: Enable downloading multiple patent PDFs via corresponding links
2
+ import sys
3
+ import os
4
+ import re
5
+ import shutil
6
+ import time
7
+ import fitz
8
+ import streamlit as st
9
+ import nltk
10
+ import tempfile
11
+ import subprocess
12
+
13
+ # Pin NLTK to version 3.9.1
14
+ REQUIRED_NLTK_VERSION = "3.9.1"
15
+ subprocess.run([sys.executable, "-m", "pip", "install", f"nltk=={REQUIRED_NLTK_VERSION}"])
16
+
17
+ # Set up temporary directory for NLTK resources
18
+ nltk_data_path = os.path.join(tempfile.gettempdir(), "nltk_data")
19
+ os.makedirs(nltk_data_path, exist_ok=True)
20
+ nltk.data.path.append(nltk_data_path)
21
+
22
+ # Download 'punkt_tab' for compatibility
23
+ try:
24
+ print("Ensuring NLTK 'punkt_tab' resource is downloaded...")
25
+ nltk.download("punkt_tab", download_dir=nltk_data_path)
26
+ except Exception as e:
27
+ print(f"Error downloading NLTK 'punkt_tab': {e}")
28
+ raise e
29
+
30
+ sys.path.append(os.path.abspath("."))
31
+ from langchain.chains import ConversationalRetrievalChain
32
+ from langchain.memory import ConversationBufferMemory
33
+ from langchain.llms import OpenAI
34
+ from langchain.document_loaders import UnstructuredPDFLoader
35
+ from langchain.vectorstores import Chroma
36
+ from langchain.embeddings import HuggingFaceEmbeddings
37
+ from langchain.text_splitter import NLTKTextSplitter
38
+ from patent_downloader import PatentDownloader
39
+
40
+ PERSISTED_DIRECTORY = tempfile.mkdtemp()
41
+
42
+ # Fetch API key securely from the environment
43
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
44
+ if not OPENAI_API_KEY:
45
+ st.error("Critical Error: OpenAI API key not found in the environment variables. Please configure it.")
46
+ st.stop()
47
+
48
+ def check_poppler_installed():
49
+ if not shutil.which("pdfinfo"):
50
+ raise EnvironmentError(
51
+ "Poppler is not installed or not in PATH. Install 'poppler-utils' for PDF processing."
52
+ )
53
+
54
+ check_poppler_installed()
55
+
56
+ def load_docs(document_path):
57
+ try:
58
+ loader = UnstructuredPDFLoader(
59
+ document_path,
60
+ mode="elements",
61
+ strategy="fast",
62
+ ocr_languages=None
63
+ )
64
+ documents = loader.load()
65
+ text_splitter = NLTKTextSplitter(chunk_size=1000)
66
+ split_docs = text_splitter.split_documents(documents)
67
+
68
+ # Filter metadata to only include str, int, float, or bool
69
+ for doc in split_docs:
70
+ if hasattr(doc, "metadata") and isinstance(doc.metadata, dict):
71
+ doc.metadata = {
72
+ k: v for k, v in doc.metadata.items()
73
+ if isinstance(v, (str, int, float, bool))
74
+ }
75
+ return split_docs
76
+ except Exception as e:
77
+ st.error(f"Failed to load and process PDF: {e}")
78
+ st.stop()
79
+
80
+ def already_indexed(vectordb, file_name):
81
+ indexed_sources = set(
82
+ x["source"] for x in vectordb.get(include=["metadatas"])["metadatas"]
83
+ )
84
+ return file_name in indexed_sources
85
+
86
+ def load_chain(file_name=None):
87
+ loaded_patent = st.session_state.get("LOADED_PATENT")
88
+
89
+ vectordb = Chroma(
90
+ persist_directory=PERSISTED_DIRECTORY,
91
+ embedding_function=HuggingFaceEmbeddings(),
92
+ )
93
+ if loaded_patent == file_name or already_indexed(vectordb, file_name):
94
+ st.write("✅ Already indexed.")
95
+ else:
96
+ vectordb.delete_collection()
97
+ docs = load_docs(file_name)
98
+ st.write("🔍 Number of Documents: ", len(docs))
99
+
100
+ vectordb = Chroma.from_documents(
101
+ docs, HuggingFaceEmbeddings(), persist_directory=PERSISTED_DIRECTORY
102
+ )
103
+ vectordb.persist()
104
+ st.session_state["LOADED_PATENT"] = file_name
105
+
106
+ memory = ConversationBufferMemory(
107
+ memory_key="chat_history",
108
+ return_messages=True,
109
+ input_key="question",
110
+ output_key="answer",
111
+ )
112
+ return ConversationalRetrievalChain.from_llm(
113
+ OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY),
114
+ vectordb.as_retriever(search_kwargs={"k": 3}),
115
+ return_source_documents=False,
116
+ memory=memory,
117
+ )
118
+
119
+ def extract_patent_number(url):
120
+ pattern = r"/patent/([A-Z]{2}\d+)"
121
+ match = re.search(pattern, url)
122
+ return match.group(1) if match else None
123
+
124
+ def download_pdf(patent_number):
125
+ try:
126
+ patent_downloader = PatentDownloader(verbose=True)
127
+ output_path = patent_downloader.download(patents=patent_number, output_path=tempfile.gettempdir())
128
+ return output_path[0]
129
+ except Exception as e:
130
+ st.error(f"Failed to download patent PDF: {e}")
131
+ st.stop()
132
+
133
+ def preview_pdf(pdf_path):
134
+ """Generate and display the first page of the PDF as an image."""
135
+ try:
136
+ doc = fitz.open(pdf_path) # Open PDF
137
+ first_page = doc[0] # Extract the first page
138
+ pix = first_page.get_pixmap() # Render page to a Pixmap (image)
139
+ temp_image_path = os.path.join(tempfile.gettempdir(), "pdf_preview.png")
140
+ pix.save(temp_image_path) # Save the image temporarily
141
+ return temp_image_path
142
+ except Exception as e:
143
+ st.error(f"Error generating PDF preview: {e}")
144
+ return None
145
+
146
+ if __name__ == "__main__":
147
+ st.set_page_config(
148
+ page_title="Patent Chat: Google Patents Chat Demo",
149
+ page_icon="📖",
150
+ layout="wide",
151
+ initial_sidebar_state="expanded",
152
+ )
153
+ st.header("📖 Patent Chat: Google Patents Chat Demo")
154
+
155
+ # Fetch query parameters safely
156
+ query_params = st.query_params
157
+ default_patent_link = query_params.get("patent_link", "https://patents.google.com/patent/US8676427B1/en")
158
+
159
+ # Input for Google Patent Link
160
+ patent_link = st.text_area("Enter Google Patent Link:", value=default_patent_link, height=100)
161
+
162
+ # Button to start processing
163
+ if st.button("Load and Process Patent"):
164
+ if not patent_link:
165
+ st.warning("Please enter a Google patent link to proceed.")
166
+ st.stop()
167
+
168
+ # Extract patent number
169
+ patent_number = extract_patent_number(patent_link)
170
+ if not patent_number:
171
+ st.error("Invalid patent link format. Please provide a valid Google patent link.")
172
+ st.stop()
173
+
174
+ st.write(f"Patent number: **{patent_number}**")
175
+
176
+ # File download handling
177
+ pdf_path = os.path.join(tempfile.gettempdir(), f"{patent_number}.pdf")
178
+ if os.path.isfile(pdf_path):
179
+ st.write("✅ File already downloaded.")
180
+ else:
181
+ st.write("📥 Downloading patent file...")
182
+ pdf_path = download_pdf(patent_number)
183
+ st.write(f"✅ File downloaded: {pdf_path}")
184
+
185
+ # Generate and display PDF preview
186
+ st.write("🖼️ Generating PDF preview...")
187
+ preview_image_path = preview_pdf(pdf_path)
188
+
189
+ if preview_image_path:
190
+ st.image(preview_image_path, caption="First Page Preview", use_column_width=True)
191
+ else:
192
+ st.warning("Failed to generate a preview for this PDF.")
193
+
194
+ # Load the document into the system
195
+ st.write("🔄 Loading document into the system...")
196
+
197
+ # Persist the chain in session state to prevent reloading
198
+ if "chain" not in st.session_state or st.session_state.get("loaded_file") != pdf_path:
199
+ st.session_state.chain = load_chain(pdf_path)
200
+ st.session_state.loaded_file = pdf_path
201
+ st.session_state.messages = [{"role": "assistant", "content": "Hello! How can I assist you with this patent?"}]
202
+
203
+ st.success("🚀 Document successfully loaded! You can now start asking questions.")
204
+
205
+ # Initialize messages if not already done
206
+ if "messages" not in st.session_state:
207
+ st.session_state.messages = [{"role": "assistant", "content": "Hello! How can I assist you with this patent?"}]
208
+
209
+ # Display previous chat messages
210
+ for message in st.session_state.messages:
211
+ with st.chat_message(message["role"]):
212
+ st.markdown(message["content"])
213
+
214
+ # User input and chatbot response
215
+ if "chain" in st.session_state:
216
+ if user_input := st.chat_input("What is your question?"):
217
+ st.session_state.messages.append({"role": "user", "content": user_input})
218
+ with st.chat_message("user"):
219
+ st.markdown(user_input)
220
+
221
+ with st.chat_message("assistant"):
222
+ message_placeholder = st.empty()
223
+ full_response = ""
224
+
225
+ with st.spinner("Generating response..."):
226
+ try:
227
+ assistant_response = st.session_state.chain({"question": user_input})
228
+ full_response = assistant_response["answer"]
229
+ except Exception as e:
230
+ full_response = f"An error occurred: {e}"
231
+
232
+ message_placeholder.markdown(full_response)
233
+ st.session_state.messages.append({"role": "assistant", "content": full_response})
234
+ else:
235
+ st.info("Press the 'Load and Process Patent' button to start processing.")