lfoppiano committed
Commit eaa93c8
2 Parent(s): b3584a6 cbdc1a4

Merge branch 'main' into pdf-render

README.md CHANGED
@@ -16,11 +16,14 @@ license: apache-2.0
 
 ## Introduction
 
-Question/Answering on scientific documents using LLMs (OpenAI, Mistral, ~~LLama2,~~ etc..).
-This application is the frontend for testing the RAG (Retrieval Augmented Generation) on scientific documents, that we are developing at NIMS.
-Differently to most of the project, we focus on scientific articles. We target only the full-text using [Grobid](https://github.com/kermitt2/grobid) that provide and cleaner results than the raw PDF2Text converter (which is comparable with most of other solutions).
+Question/Answering on scientific documents using LLMs: ChatGPT-3.5-turbo, Mistral-7b-instruct and Zephyr-7b-beta.
+The Streamlit application demonstrates the implementation of RAG (Retrieval Augmented Generation) on scientific documents that we are developing at NIMS (National Institute for Materials Science) in Tsukuba, Japan.
+Unlike most projects, we focus on scientific articles.
+We target only the full text, using [Grobid](https://github.com/kermitt2/grobid), which provides cleaner results than raw PDF2Text conversion (the approach used by most other solutions).
 
-**NER in LLM response**: The responses from the LLMs are post-processed to extract <span stype="color:yellow">physical quantities, measurements</span> (with [grobid-quantities](https://github.com/kermitt2/grobid-quantities)) and <span stype="color:blue">materials</span> mentions (with [grobid-superconductors](https://github.com/lfoppiano/grobid-superconductors)).
+Additionally, this frontend visualises named entities in the LLM responses, extracting <span style="color:yellow">physical quantities, measurements</span> (with [grobid-quantities](https://github.com/kermitt2/grobid-quantities)) and <span style="color:blue">materials</span> mentions (with [grobid-superconductors](https://github.com/lfoppiano/grobid-superconductors)).
+
+The conversation is backed by a sliding window memory (the 4 most recent messages) that helps refer to information previously discussed in the chat.
 
 **Demos**:
 - (on HuggingFace spaces): https://lfoppiano-document-qa.hf.space/
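
Note: the README above relies on Grobid for full-text extraction. As a minimal sketch (not part of this commit) of the underlying REST call, assuming a local Grobid server on its default port 8070:

```python
# Sketch only: send a PDF to Grobid's processFulltextDocument endpoint
# and get back TEI XML with structured sections and paragraphs.
import requests

GROBID_URL = "http://localhost:8070"  # assumption: a local Grobid instance


def pdf_to_tei(pdf_path: str) -> str:
    """Return the TEI XML produced by Grobid for the given PDF."""
    with open(pdf_path, "rb") as pdf:
        response = requests.post(
            f"{GROBID_URL}/api/processFulltextDocument",
            files={"input": pdf},  # "input" is the multipart field Grobid expects
            timeout=120,
        )
    response.raise_for_status()
    return response.text


# Example usage (hypothetical file name):
# print(pdf_to_tei("paper.pdf")[:500])
```
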
document_qa/document_qa_engine.py CHANGED
@@ -23,7 +23,13 @@ class DocumentQAEngine:
     embeddings_map_from_md5 = {}
     embeddings_map_to_md5 = {}
 
-    def __init__(self, llm, embedding_function, qa_chain_type="stuff", embeddings_root_path=None, grobid_url=None):
+    def __init__(self,
+                 llm,
+                 embedding_function,
+                 qa_chain_type="stuff",
+                 embeddings_root_path=None,
+                 grobid_url=None,
+                 ):
         self.embedding_function = embedding_function
         self.llm = llm
         self.chain = load_qa_chain(llm, chain_type=qa_chain_type)
@@ -81,14 +87,14 @@ class DocumentQAEngine:
         return self.embeddings_map_from_md5[md5]
 
     def query_document(self, query: str, doc_id, output_parser=None, context_size=4, extraction_schema=None,
-                       verbose=False) -> (
+                       verbose=False, memory=None) -> (
             Any, str):
         # self.load_embeddings(self.embeddings_root_path)
 
         if verbose:
             print(query)
 
-        response = self._run_query(doc_id, query, context_size=context_size)
+        response = self._run_query(doc_id, query, context_size=context_size, memory=memory)
         response = response['output_text'] if 'output_text' in response else response
 
         if verbose:
@@ -138,9 +144,15 @@ class DocumentQAEngine:
 
         return parsed_output
 
-    def _run_query(self, doc_id, query, context_size=4):
+    def _run_query(self, doc_id, query, memory=None, context_size=4):
         relevant_documents = self._get_context(doc_id, query, context_size)
-        return self.chain.run(input_documents=relevant_documents, question=query)
+        if memory:
+            return self.chain.run(input_documents=relevant_documents,
+                                  question=query,
+                                  memory=memory)
+        else:
+            return self.chain.run(input_documents=relevant_documents,
+                                  question=query)
         # return self.chain({"input_documents": relevant_documents, "question": prompt_chat_template}, return_only_outputs=True)
 
     def _get_context(self, doc_id, query, context_size=4):
@@ -150,6 +162,7 @@ class DocumentQAEngine:
         return relevant_documents
 
     def get_all_context_by_document(self, doc_id):
+        """Return the full context from the document"""
        db = self.embeddings_dict[doc_id]
        docs = db.get()
        return docs['documents']
@@ -161,6 +174,7 @@ class DocumentQAEngine:
         return relevant_documents
 
     def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, verbose=False):
+        """Extract text from a document using Grobid; if chunk_size < 0, each paragraph is kept as a separate chunk"""
         if verbose:
             print("File", pdf_file_path)
         filename = Path(pdf_file_path).stem
@@ -209,18 +223,17 @@ class DocumentQAEngine:
 
         if hash not in self.embeddings_dict.keys():
             self.embeddings_dict[hash] = Chroma.from_texts(texts, embedding=self.embedding_function, metadatas=metadata,
-                                                           collection_name=hash)
+                                                            collection_name=hash)
         else:
             self.embeddings_dict[hash].delete(ids=self.embeddings_dict[hash].get()['ids'])
             self.embeddings_dict[hash] = Chroma.from_texts(texts, embedding=self.embedding_function, metadatas=metadata,
                                                             collection_name=hash)
 
-
         self.embeddings_root_path = None
 
         return hash
 
-    def create_embeddings(self, pdfs_dir_path: Path):
+    def create_embeddings(self, pdfs_dir_path: Path, chunk_size=500, perc_overlap=0.1):
         input_files = []
         for root, dirs, files in os.walk(pdfs_dir_path, followlinks=False):
             for file_ in files:
@@ -238,7 +251,8 @@ class DocumentQAEngine:
                 print(data_path, "exists. Skipping it ")
                 continue
 
-            texts, metadata, ids = self.get_text_from_document(input_file, chunk_size=500, perc_overlap=0.1)
+            texts, metadata, ids = self.get_text_from_document(input_file, chunk_size=chunk_size,
+                                                               perc_overlap=perc_overlap)
            filename = metadata[0]['filename']
 
            vector_db_document = Chroma.from_texts(texts,
 
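The engine changes above thread a `memory` object into `_run_query`. As an illustrative sketch (not part of the commit) of how the LangChain sliding-window memory imported in this commit behaves, the example questions and answers below are made up:

```python
# Sketch only: a k=4 window keeps the 4 most recent exchanges and drops older ones.
from langchain.memory import ConversationBufferWindowMemory

memory = ConversationBufferWindowMemory(k=4)
memory.save_context({"input": "Which material is studied?"}, {"output": "MgB2 thin films."})
memory.save_context({"input": "What is its critical temperature?"}, {"output": "About 39 K."})

# The window's content is what can be injected alongside the retrieved chunks.
print(memory.load_memory_variables({})["history"])
```
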
pyproject.toml CHANGED
@@ -3,7 +3,7 @@ requires = ["setuptools", "setuptools-scm"]
 build-backend = "setuptools.build_meta"
 
 [tool.bumpversion]
-current_version = "0.2.1"
+current_version = "0.3.0"
 commit = "true"
 tag = "true"
 tag_name = "v{new_version}"
streamlit_app.py CHANGED
@@ -7,6 +7,7 @@ from tempfile import NamedTemporaryFile
 import dotenv
 from grobid_quantities.quantities import QuantitiesAPI
 from langchain.llms.huggingface_hub import HuggingFaceHub
+from langchain.memory import ConversationBufferWindowMemory
 
 dotenv.load_dotenv(override=True)
 
@@ -52,6 +53,9 @@ if 'ner_processing' not in st.session_state:
 if 'uploaded' not in st.session_state:
     st.session_state['uploaded'] = False
 
+if 'memory' not in st.session_state:
+    st.session_state['memory'] = ConversationBufferWindowMemory(k=4)
+
 if 'binary' not in st.session_state:
     st.session_state['binary'] = None
 
@@ -82,6 +86,11 @@ def new_file():
     st.session_state['loaded_embeddings'] = None
     st.session_state['doc_id'] = None
     st.session_state['uploaded'] = True
+    st.session_state['memory'].clear()
+
+
+def clear_memory():
+    st.session_state['memory'].clear()
 
 
 # @st.cache_resource
@@ -112,6 +121,7 @@ def init_qa(model, api_key=None):
     else:
         st.error("The model was not loaded properly. Try reloading. ")
         st.stop()
+        return
 
     return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'])
 
@@ -183,7 +193,7 @@ with st.sidebar:
                              disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded'])
 
     st.markdown(
-        ":warning: Mistral and Zephyr are free to use, however requests might hit limits of the huggingface free API and fail. :warning: ")
+        ":warning: Mistral and Zephyr are **FREE** to use. Requests might fail anytime. Use at your own risk. :warning: ")
 
    if (model == 'mistral-7b-instruct-v0.1' or model == 'zephyr-7b-beta') and model not in st.session_state['api_keys']:
        if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ:
@@ -219,6 +229,12 @@ with st.sidebar:
             st.session_state['rqa'][model] = init_qa(model)
     # else:
     #     is_api_key_provided = st.session_state['api_key']
+
+    st.button(
+        'Reset chat memory.',
+        on_click=clear_memory,
+        help="Clear the conversational memory. Currently implemented to retain the 4 most recent messages.")
+
 left_column, right_column = st.columns([1, 1])
 
 with right_column:
@@ -349,7 +365,8 @@ with right_column:
        elif mode == "LLM":
            with st.spinner("Generating response..."):
                _, text_response = st.session_state['rqa'][model].query_document(question, st.session_state.doc_id,
-                                                                                 context_size=context_size)
+                                                                                 context_size=context_size,
+                                                                                 memory=st.session_state.memory)
 
            if not text_response:
                st.error("Something went wrong. Contact Luca Foppiano (Foppiano.Luca@nims.co.jp) to report the issue.")
@@ -368,5 +385,11 @@ with right_column:
                st.write(text_response)
                st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})
 
+        for id in range(0, len(st.session_state.messages), 2):
+            question = st.session_state.messages[id]['content']
+            if len(st.session_state.messages) > id + 1:
+                answer = st.session_state.messages[id + 1]['content']
+                st.session_state.memory.save_context({"input": question}, {"output": answer})
+
elif st.session_state.loaded_embeddings and st.session_state.doc_id:
    play_old_messages()
 
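The last hunk rebuilds the window memory from the chat history by pairing user and assistant messages. A minimal sketch of that pairing logic as a standalone function (not part of the commit; the helper name and the example history are made up):

```python
# Sketch only: replay question/answer pairs from a chat history into a fresh window memory.
from langchain.memory import ConversationBufferWindowMemory


def rebuild_memory(messages, k=4):
    """Pair messages (user, assistant, user, assistant, ...) and save each exchange."""
    memory = ConversationBufferWindowMemory(k=k)
    for i in range(0, len(messages), 2):
        question = messages[i]["content"]
        if i + 1 < len(messages):
            answer = messages[i + 1]["content"]
            memory.save_context({"input": question}, {"output": answer})
    return memory


# Example with a hypothetical chat history:
history = [
    {"role": "user", "content": "Which material is studied?"},
    {"role": "assistant", "content": "MgB2 thin films."},
]
print(rebuild_memory(history).load_memory_variables({})["history"])
```
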