DrishtiSharma committed on
Commit
5ab2199
·
verified ·
1 Parent(s): 2375a67

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +208 -0
app.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # to-do: Enable downloading multiple patent PDFs via corresponding links
2
+ import sys
3
+ import os
4
+ import re
5
+ import shutil
6
+ import time
7
+ import streamlit as st
8
+ import nltk
9
+ import tempfile
10
+ import subprocess
11
+
12
# Pin NLTK to version 3.9.1.
# Streamlit re-executes this script on every user interaction, so only shell
# out to pip when the installed version actually differs from the pin.
REQUIRED_NLTK_VERSION = "3.9.1"
if getattr(nltk, "__version__", None) != REQUIRED_NLTK_VERSION:
    # NOTE: nltk has already been imported above, so a freshly installed
    # version only takes effect once the interpreter is restarted.
    subprocess.run([sys.executable, "-m", "pip", "install", f"nltk=={REQUIRED_NLTK_VERSION}"])

# Set up temporary directory for NLTK resources
nltk_data_path = os.path.join(tempfile.gettempdir(), "nltk_data")
os.makedirs(nltk_data_path, exist_ok=True)
# nltk.data.path is module state; avoid appending the same entry on every run.
if nltk_data_path not in nltk.data.path:
    nltk.data.path.append(nltk_data_path)

# Download 'punkt_tab' for compatibility
try:
    print("Ensuring NLTK 'punkt_tab' resource is downloaded...")
    nltk.download("punkt_tab", download_dir=nltk_data_path)
except Exception as e:
    print(f"Error downloading NLTK 'punkt_tab': {e}")
    # Bare raise re-raises the active exception with its traceback untouched.
    raise
28
+
29
+ sys.path.append(os.path.abspath("."))
30
+ from langchain.chains import ConversationalRetrievalChain
31
+ from langchain.memory import ConversationBufferMemory
32
+ from langchain.llms import OpenAI
33
+ from langchain.document_loaders import UnstructuredPDFLoader
34
+ from langchain.vectorstores import Chroma
35
+ from langchain.embeddings import HuggingFaceEmbeddings
36
+ from langchain.text_splitter import NLTKTextSplitter
37
+ from patent_downloader import PatentDownloader
38
+
39
# Directory that backs the Chroma vector store.
# Streamlit re-executes this script on every interaction; creating the
# directory unconditionally would leak a fresh temp directory per rerun and
# make a previously persisted index impossible to find again. Create it once
# per session and reuse it afterwards.
if "PERSISTED_DIRECTORY" not in st.session_state:
    st.session_state["PERSISTED_DIRECTORY"] = tempfile.mkdtemp()
PERSISTED_DIRECTORY = st.session_state["PERSISTED_DIRECTORY"]

# Fetch API key securely from the environment
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    # Covers both an unset variable (None) and an empty string.
    st.error("Critical Error: OpenAI API key not found in the environment variables. Please configure it.")
    st.stop()
46
+
47
def check_poppler_installed():
    """Verify that Poppler's ``pdfinfo`` executable can be found on PATH.

    Returns:
        None when ``pdfinfo`` is located.

    Raises:
        EnvironmentError: If ``shutil.which("pdfinfo")`` finds nothing.
    """
    pdfinfo_location = shutil.which("pdfinfo")
    if pdfinfo_location:
        return
    raise EnvironmentError(
        "Poppler is not installed or not in PATH. Install 'poppler-utils' for PDF processing."
    )
52
+
53
+ check_poppler_installed()
54
+
55
def load_docs(document_path):
    """Load a PDF, split it into chunks and strip non-scalar metadata.

    The PDF is read with ``UnstructuredPDFLoader`` (``mode="elements"``,
    ``strategy="fast"``, ``ocr_languages=None``) and split with
    ``NLTKTextSplitter(chunk_size=1000)``. Afterwards every chunk's metadata
    dict is reduced to entries whose values are ``str``, ``int``, ``float``
    or ``bool``.

    Args:
        document_path: Path of the PDF file to load.

    Returns:
        The list of split documents with filtered metadata.

    On any exception the error is shown with ``st.error`` and the script run
    is halted with ``st.stop()``.
    """
    scalar_types = (str, int, float, bool)
    try:
        pdf_loader = UnstructuredPDFLoader(
            document_path,
            mode="elements",
            strategy="fast",
            ocr_languages=None,
        )
        raw_documents = pdf_loader.load()
        splitter = NLTKTextSplitter(chunk_size=1000)
        chunks = splitter.split_documents(raw_documents)

        # Filter metadata to only include str, int, float, or bool
        for chunk in chunks:
            if not (hasattr(chunk, "metadata") and isinstance(chunk.metadata, dict)):
                continue
            kept = {}
            for key, value in chunk.metadata.items():
                if isinstance(value, scalar_types):
                    kept[key] = value
            chunk.metadata = kept
        return chunks
    except Exception as e:
        st.error(f"Failed to load and process PDF: {e}")
        st.stop()
78
+
79
def already_indexed(vectordb, file_name):
    """Return True if *file_name* is the ``source`` of an indexed document.

    Args:
        vectordb: Object exposing ``get(include=["metadatas"])`` that returns
            a mapping with a ``"metadatas"`` entry (a list of per-document
            metadata dicts).
        file_name: Source value to look for.

    Returns:
        bool: True when some stored metadata dict has ``source == file_name``.

    Metadata entries that are missing/empty or have no ``"source"`` key are
    skipped instead of raising, and a missing metadata list is treated as an
    empty index.
    """
    metadatas = vectordb.get(include=["metadatas"])["metadatas"] or []
    indexed_sources = {
        metadata["source"]
        for metadata in metadatas
        if metadata and "source" in metadata
    }
    return file_name in indexed_sources
84
+
85
def load_chain(file_name=None):
    """Build a conversational retrieval chain over the patent PDF *file_name*.

    Opens the Chroma store under ``PERSISTED_DIRECTORY``. If *file_name*
    equals the ``LOADED_PATENT`` session value or is already present in the
    store, the existing index is reused. Otherwise the collection is deleted,
    the PDF is loaded via ``load_docs``, a new index is built and persisted,
    and ``LOADED_PATENT`` is updated in ``st.session_state``.

    Args:
        file_name: Path of the PDF to index and chat about.

    Returns:
        A ``ConversationalRetrievalChain`` backed by an OpenAI LLM
        (temperature 0), a retriever returning the top 3 matches, and a
        ``ConversationBufferMemory`` keyed on ``question`` / ``answer``.
    """
    loaded_patent = st.session_state.get("LOADED_PATENT")

    # Build the embedding function once and share it between the lookup store
    # and the (re)indexing step instead of constructing it twice.
    embeddings = HuggingFaceEmbeddings()

    vectordb = Chroma(
        persist_directory=PERSISTED_DIRECTORY,
        embedding_function=embeddings,
    )
    if loaded_patent == file_name or already_indexed(vectordb, file_name):
        st.write("✅ Already indexed.")
    else:
        # Drop whatever was indexed before so the store only holds this patent.
        vectordb.delete_collection()
        docs = load_docs(file_name)
        st.write("🔍 Number of Documents: ", len(docs))

        vectordb = Chroma.from_documents(
            docs, embeddings, persist_directory=PERSISTED_DIRECTORY
        )
        vectordb.persist()
        st.session_state["LOADED_PATENT"] = file_name

    memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
        input_key="question",
        output_key="answer",
    )
    return ConversationalRetrievalChain.from_llm(
        OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY),
        vectordb.as_retriever(search_kwargs={"k": 3}),
        return_source_documents=False,
        memory=memory,
    )
117
+
118
def extract_patent_number(url):
    """Pull the patent number out of a Google Patents style URL.

    The number is the first ``/patent/<XX><digits>`` occurrence in *url*:
    two uppercase letters followed by one or more digits. Anything after the
    digits (for example a ``B1`` suffix) is not part of the result.

    Args:
        url: Text to search.

    Returns:
        The matched patent number, or None when the pattern is not found.
    """
    found = re.search(r"/patent/([A-Z]{2}\d+)", url)
    if found is None:
        return None
    return found.group(1)
122
+
123
def download_pdf(patent_number):
    """Download the PDF for *patent_number* into the system temp directory.

    Args:
        patent_number: Patent identifier handed to ``PatentDownloader``.

    Returns:
        The first element of the value returned by
        ``PatentDownloader.download`` (presumably the saved file's path;
        verify against ``patent_downloader``).

    On any exception the error is shown with ``st.error`` and the script run
    is halted with ``st.stop()``.
    """
    target_dir = tempfile.gettempdir()
    try:
        downloader = PatentDownloader(verbose=True)
        saved_paths = downloader.download(output_path=target_dir, patents=patent_number)
        return saved_paths[0]
    except Exception as e:
        st.error(f"Failed to download patent PDF: {e}")
        st.stop()
131
+
132
if __name__ == "__main__":
    # Page chrome.
    st.set_page_config(
        page_title="Patent Chat: Google Patents Chat Demo",
        page_icon="📖",
        layout="wide",
        initial_sidebar_state="expanded",
    )
    st.header("📖 Patent Chat: Google Patents Chat Demo")

    # The patent link may be pre-filled through the ``patent_link`` query
    # parameter; otherwise a sample patent is suggested.
    default_patent_link = st.query_params.get(
        "patent_link", "https://patents.google.com/patent/US8676427B1/en"
    )

    # Input for Google Patent Link
    patent_link = st.text_area("Enter Google Patent Link:", value=default_patent_link, height=100)

    # Button to start processing
    if st.button("Load and Process Patent"):
        if not patent_link:
            st.warning("Please enter a Google patent link to proceed.")
            st.stop()

        patent_number = extract_patent_number(patent_link)
        if not patent_number:
            st.error("Invalid patent link format. Please provide a valid Google patent link.")
            st.stop()

        st.write(f"Patent number: **{patent_number}**")

        # Reuse a previously downloaded copy from the temp directory if present.
        pdf_path = os.path.join(tempfile.gettempdir(), f"{patent_number}.pdf")
        if not os.path.isfile(pdf_path):
            st.write("📥 Downloading patent file...")
            pdf_path = download_pdf(patent_number)
            st.write(f"✅ File downloaded: {pdf_path}")
        else:
            st.write("✅ File already downloaded.")

        st.write("🔄 Loading document into the system...")

        # Persist the chain in session state to prevent reloading: rebuild
        # only when no chain exists yet or a different file was requested.
        chain_is_current = (
            "chain" in st.session_state
            and st.session_state.get("loaded_file") == pdf_path
        )
        if not chain_is_current:
            st.session_state.chain = load_chain(pdf_path)
            st.session_state.loaded_file = pdf_path
            st.session_state.messages = [{"role": "assistant", "content": "Hello! How can I assist you with this patent?"}]

        st.success("🚀 Document successfully loaded! You can now start asking questions.")

    # Initialize messages if not already done
    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "Hello! How can I assist you with this patent?"}]

    # Replay the conversation so far.
    for entry in st.session_state.messages:
        with st.chat_message(entry["role"]):
            st.markdown(entry["content"])

    if "chain" in st.session_state:
        if user_input := st.chat_input("What is your question?"):
            st.session_state.messages.append({"role": "user", "content": user_input})
            with st.chat_message("user"):
                st.markdown(user_input)

            with st.chat_message("assistant"):
                answer_slot = st.empty()

                with st.spinner("Generating response..."):
                    try:
                        chain_output = st.session_state.chain({"question": user_input})
                        reply_text = chain_output["answer"]
                    except Exception as e:
                        # Surface the failure in the chat instead of crashing.
                        reply_text = f"An error occurred: {e}"

                answer_slot.markdown(reply_text)
            st.session_state.messages.append({"role": "assistant", "content": reply_text})
    else:
        st.info("Press the 'Load and Process Patent' button to start processing.")