lfoppiano committed on
Commit
41ad70e
·
1 Parent(s): f684be7

return embeddings from storage retrieval

Browse files
Files changed (3) hide show
  1. document_qa/document_qa_engine.py +295 -77
  2. requirements.txt +11 -11
  3. streamlit_app.py +24 -20
document_qa/document_qa_engine.py CHANGED
@@ -1,23 +1,43 @@
1
  import copy
2
  import os
3
  from pathlib import Path
4
- from typing import Union, Any
5
 
6
  import tiktoken
7
- from grobid_client.grobid_client import GrobidClient
8
  from langchain.chains import create_extraction_chain
9
  from langchain.chains.question_answering import load_qa_chain, stuff_prompt, refine_prompts, map_reduce_prompt, \
10
  map_rerank_prompt
11
  from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
12
  from langchain.retrievers import MultiQueryRetriever
13
  from langchain.schema import Document
14
- from langchain.vectorstores import Chroma
 
 
 
 
15
  from tqdm import tqdm
16
 
17
  from document_qa.grobid_processors import GrobidProcessor
18
 
19
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  class TextMerger:
 
 
 
 
 
21
  def __init__(self, model_name=None, encoding_name="gpt2"):
22
  if model_name is not None:
23
  self.enc = tiktoken.encoding_for_model(model_name)
@@ -85,52 +105,187 @@ class TextMerger:
85
 
86
  return new_passages_struct
87
 
88
- class DataStorage:
89
 
 
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
- class DocumentQAEngine:
93
- llm = None
94
- qa_chain_type = None
95
- embedding_function = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  embeddings_dict = {}
97
  embeddings_map_from_md5 = {}
98
  embeddings_map_to_md5 = {}
99
 
100
- default_prompts = {
101
- 'stuff': stuff_prompt,
102
- 'refine': refine_prompts,
103
- "map_reduce": map_reduce_prompt,
104
- "map_rerank": map_rerank_prompt
105
- }
106
-
107
- def __init__(self,
108
- llm,
109
- embedding_function,
110
- qa_chain_type="stuff",
111
- embeddings_root_path=None,
112
- grobid_url=None,
113
- memory=None
114
- ):
115
  self.embedding_function = embedding_function
116
- self.llm = llm
117
- self.memory = memory
118
- self.chain = load_qa_chain(llm, chain_type=qa_chain_type)
119
- self.text_merger = TextMerger()
120
 
121
- if embeddings_root_path is not None:
122
- self.embeddings_root_path = embeddings_root_path
123
- if not os.path.exists(embeddings_root_path):
124
- os.makedirs(embeddings_root_path)
125
  else:
126
  self.load_embeddings(self.embeddings_root_path)
127
 
128
- if grobid_url:
129
- self.grobid_processor = GrobidProcessor(grobid_url)
130
-
131
  def load_embeddings(self, embeddings_root_path: Union[str, Path]) -> None:
132
  """
133
- Load the embeddings assuming they are all persisted and stored in a single directory.
134
  The root path of the embeddings containing one data store for each document in each subdirectory
135
  """
136
 
@@ -141,8 +296,10 @@ class DocumentQAEngine:
141
  return
142
 
143
  for embedding_document_dir in embeddings_directories:
144
- self.embeddings_dict[embedding_document_dir.name] = Chroma(persist_directory=embedding_document_dir.path,
145
- embedding_function=self.embedding_function)
 
 
146
 
147
  filename_list = list(Path(embedding_document_dir).glob('*.storage_filename'))
148
  if filename_list:
@@ -161,9 +318,60 @@ class DocumentQAEngine:
161
  def get_filename_from_md5(self, md5):
162
  return self.embeddings_map_from_md5[md5]
163
 
164
- def query_document(self, query: str, doc_id, output_parser=None, context_size=4, extraction_schema=None,
165
- verbose=False) -> (
166
- Any, str):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  # self.load_embeddings(self.embeddings_root_path)
168
 
169
  if verbose:
@@ -192,16 +400,22 @@ class DocumentQAEngine:
192
  else:
193
  return None, response, coordinates
194
 
195
- def query_storage(self, query: str, doc_id, context_size=4):
196
- documents = self._get_context(doc_id, query, context_size)
 
 
 
197
 
198
  context_as_text = [doc.page_content for doc in documents]
199
- return context_as_text
200
 
201
  def query_storage_and_embeddings(self, query: str, doc_id, context_size=4):
202
- db = self.embeddings_dict[doc_id]
203
- retriever = db.as_retriever(search_kwargs={"k": context_size})
204
- relevant_documents = retriever.get_relevant_documents(query, include=["embeddings"])
 
 
 
205
 
206
  context_as_text = [doc.page_content for doc in relevant_documents]
207
  return context_as_text
@@ -229,11 +443,11 @@ class DocumentQAEngine:
229
 
230
  return parsed_output
231
 
232
- def _run_query(self, doc_id, query, context_size=4):
233
  relevant_documents = self._get_context(doc_id, query, context_size)
234
  relevant_document_coordinates = [doc.metadata['coordinates'].split(";") if 'coordinates' in doc.metadata else []
235
  for doc in
236
- relevant_documents] # filter(lambda d: d['type'] == "sentence", relevant_documents)]
237
  response = self.chain.run(input_documents=relevant_documents,
238
  question=query)
239
 
@@ -241,33 +455,40 @@ class DocumentQAEngine:
241
  self.memory.save_context({"input": query}, {"output": response})
242
  return response, relevant_document_coordinates
243
 
244
- def _get_context(self, doc_id, query, context_size=4):
245
- db = self.embeddings_dict[doc_id]
246
  retriever = db.as_retriever(search_kwargs={"k": context_size})
247
  relevant_documents = retriever.get_relevant_documents(query)
 
 
 
248
  if self.memory and len(self.memory.buffer_as_messages) > 0:
249
  relevant_documents.append(
250
  Document(
251
  page_content="""Following, the previous question and answers. Use these information only when in the question there are unspecified references:\n{}\n\n""".format(
252
  self.memory.buffer_as_str))
253
  )
254
- return relevant_documents
255
 
256
- def get_all_context_by_document(self, doc_id):
257
- """Return the full context from the document"""
258
- db = self.embeddings_dict[doc_id]
 
 
259
  docs = db.get()
260
  return docs['documents']
261
 
262
  def _get_context_multiquery(self, doc_id, query, context_size=4):
263
- db = self.embeddings_dict[doc_id].as_retriever(search_kwargs={"k": context_size})
264
  multi_query_retriever = MultiQueryRetriever.from_llm(retriever=db, llm=self.llm)
265
  relevant_documents = multi_query_retriever.get_relevant_documents(query)
266
  return relevant_documents
267
 
268
  def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, verbose=False):
269
  """
270
- Extract text from documents using Grobid, if chunk_size is < 0 it keeps each paragraph separately
 
 
271
  """
272
  if verbose:
273
  print("File", pdf_file_path)
@@ -307,7 +528,13 @@ class DocumentQAEngine:
307
 
308
  return texts, metadatas, ids
309
 
310
- def create_memory_embeddings(self, pdf_path, doc_id=None, chunk_size=500, perc_overlap=0.1):
 
 
 
 
 
 
311
  texts, metadata, ids = self.get_text_from_document(
312
  pdf_path,
313
  chunk_size=chunk_size,
@@ -317,25 +544,17 @@ class DocumentQAEngine:
317
  else:
318
  hash = metadata[0]['hash']
319
 
320
- if hash not in self.embeddings_dict.keys():
321
- self.embeddings_dict[hash] = Chroma.from_texts(texts,
322
- embedding=self.embedding_function,
323
- metadatas=metadata,
324
- collection_name=hash)
325
- else:
326
- # if 'documents' in self.embeddings_dict[hash].get() and len(self.embeddings_dict[hash].get()['documents']) == 0:
327
- # self.embeddings_dict[hash].delete(ids=self.embeddings_dict[hash].get()['ids'])
328
- self.embeddings_dict[hash].delete_collection()
329
- self.embeddings_dict[hash] = Chroma.from_texts(texts,
330
- embedding=self.embedding_function,
331
- metadatas=metadata,
332
- collection_name=hash)
333
-
334
- self.embeddings_root_path = None
335
 
336
  return hash
337
 
338
- def create_embeddings(self, pdfs_dir_path: Path, chunk_size=500, perc_overlap=0.1, include_biblio=False):
 
 
 
 
 
 
339
  input_files = []
340
  for root, dirs, files in os.walk(pdfs_dir_path, followlinks=False):
341
  for file_ in files:
@@ -347,17 +566,16 @@ class DocumentQAEngine:
347
  desc="Grobid + embeddings processing"):
348
 
349
  md5 = self.calculate_md5(input_file)
350
- data_path = os.path.join(self.embeddings_root_path, md5)
351
 
352
  if os.path.exists(data_path):
353
  print(data_path, "exists. Skipping it ")
354
  continue
355
- include = ["biblio"] if include_biblio else []
356
  texts, metadata, ids = self.get_text_from_document(
357
  input_file,
358
  chunk_size=chunk_size,
359
- perc_overlap=perc_overlap,
360
- include=include)
361
  filename = metadata[0]['filename']
362
 
363
  vector_db_document = Chroma.from_texts(texts,
 
1
  import copy
2
  import os
3
  from pathlib import Path
4
+ from typing import Union, Any, Optional, List, Dict, Tuple, ClassVar, Collection
5
 
6
  import tiktoken
 
7
  from langchain.chains import create_extraction_chain
8
  from langchain.chains.question_answering import load_qa_chain, stuff_prompt, refine_prompts, map_reduce_prompt, \
9
  map_rerank_prompt
10
  from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
11
  from langchain.retrievers import MultiQueryRetriever
12
  from langchain.schema import Document
13
+ from langchain_community.vectorstores.chroma import Chroma, DEFAULT_K
14
+ from langchain_community.vectorstores.faiss import FAISS
15
+ from langchain_core.callbacks import CallbackManagerForRetrieverRun
16
+ from langchain_core.utils import xor_args
17
+ from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
18
  from tqdm import tqdm
19
 
20
  from document_qa.grobid_processors import GrobidProcessor
21
 
22
 
23
def _results_to_docs_scores_and_embeddings(results: Any) -> List[Tuple[Document, float, List[float]]]:
    """
    Convert a raw collection query result into (document, distance, embedding) triples.

    Only the first entry of the "documents", "metadatas", "distances" and
    "embeddings" lists is read. A falsy metadata entry is replaced by an empty dict.
    """
    texts = results["documents"][0]
    metadatas = results["metadatas"][0]
    distances = results["distances"][0]
    vectors = results["embeddings"][0]

    triples = []
    for text, metadata, distance, vector in zip(texts, metadatas, distances, vectors):
        document = Document(page_content=text, metadata=metadata or {})
        triples.append((document, distance, vector))
    return triples
33
+
34
+
35
  class TextMerger:
36
+ """
37
+ This class tries to replicate the RecursiveTextSplitter from LangChain, to preserve and merge the
38
+ coordinate information from the PDF document.
39
+ """
40
+
41
  def __init__(self, model_name=None, encoding_name="gpt2"):
42
  if model_name is not None:
43
  self.enc = tiktoken.encoding_for_model(model_name)
 
105
 
106
  return new_passages_struct
107
 
 
108
 
109
class BaseRetrieval:
    """
    Minimal holder for the settings of a retrieval backend.

    It only records the persistence directory and the embedding function it is given.
    """

    def __init__(
            self,
            persist_directory: Path,
            embedding_function
    ):
        self.persist_directory = persist_directory
        self.embedding_function = embedding_function
118
+
119
+
120
class AdvancedVectorStoreRetriever(VectorStoreRetriever):
    """
    Vector store retriever with an additional "similarity_with_embeddings" search type.

    The extra search type calls ``advanced_similarity_search`` on the vector store and
    copies the returned score and embedding into each document's metadata, under the
    ``__similarity`` and ``__embeddings`` keys (existing values are left untouched).
    """

    allowed_search_types: ClassVar[Collection[str]] = (
        "similarity",
        "similarity_score_threshold",
        "mmr",
        "similarity_with_embeddings"
    )

    def _get_relevant_documents(
            self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """
        Return the documents matching ``query`` according to ``self.search_type``.

        Raises ``ValueError`` when the search type is not one of the handled ones.
        """
        search_type = self.search_type
        store = self.vectorstore

        if search_type == "similarity":
            return store.similarity_search(query, **self.search_kwargs)

        if search_type == "similarity_score_threshold":
            scored_docs = store.similarity_search_with_relevance_scores(
                query, **self.search_kwargs
            )
            # Expose the relevance score to callers without overwriting an existing value
            for doc, similarity in scored_docs:
                doc.metadata.setdefault('__similarity', similarity)
            return [doc for doc, _ in scored_docs]

        if search_type == "mmr":
            return store.max_marginal_relevance_search(query, **self.search_kwargs)

        if search_type == "similarity_with_embeddings":
            enriched_docs = store.advanced_similarity_search(query, **self.search_kwargs)
            for doc, score, embeddings in enriched_docs:
                doc.metadata.setdefault('__embeddings', embeddings)
                doc.metadata.setdefault('__similarity', score)
            return [doc for doc, _, _ in enriched_docs]

        raise ValueError(f"search_type of {self.search_type} not allowed.")
165
+
166
+
167
class AdvancedVectorStore(VectorStore):
    """Vector store whose ``as_retriever`` builds an ``AdvancedVectorStoreRetriever``."""

    def as_retriever(self, **kwargs: Any) -> AdvancedVectorStoreRetriever:
        """
        Build a retriever bound to this store.

        The caller-supplied ``tags`` list (if any) is extended in place with the store's
        own retriever tags; all other keyword arguments are forwarded to the retriever.
        """
        retriever_tags = kwargs.pop("tags", None) or []
        retriever_tags.extend(self._get_retriever_tags())
        return AdvancedVectorStoreRetriever(vectorstore=self, tags=retriever_tags, **kwargs)
172
+
173
+
174
class ChromaAdvancedRetrieval(Chroma, AdvancedVectorStore):
    """
    Chroma vector store that can also return the distance and the stored embedding
    of every retrieved document.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @xor_args(("query_texts", "query_embeddings"))
    def __query_collection(
            self,
            query_texts: Optional[List[str]] = None,
            query_embeddings: Optional[List[List[float]]] = None,
            n_results: int = 4,
            where: Optional[Dict[str, str]] = None,
            where_document: Optional[Dict[str, str]] = None,
            **kwargs: Any,
    ) -> Any:
        """
        Query the chroma collection and return its raw result.

        Exactly one of ``query_texts`` / ``query_embeddings`` must be given (enforced by
        ``xor_args``). Extra keyword arguments (e.g. ``include``) are forwarded to the
        collection query.

        Raises ``ValueError`` when the ``chromadb`` package cannot be imported.
        """
        try:
            import chromadb  # noqa: F401
        except ImportError as err:
            # ValueError is kept for callers that already catch it; the cause is chained.
            raise ValueError(
                "Could not import chromadb python package. "
                "Please install it with `pip install chromadb`."
            ) from err
        return self._collection.query(
            query_texts=query_texts,
            query_embeddings=query_embeddings,
            n_results=n_results,
            where=where,
            where_document=where_document,
            **kwargs,
        )

    def advanced_similarity_search(
            self,
            query: str,
            k: int = DEFAULT_K,
            filter: Optional[Dict[str, str]] = None,
            **kwargs: Any,
    ) -> List[Tuple[Document, float, List[float]]]:
        """
        Entry point used by the "similarity_with_embeddings" retriever search type.

        Returns (document, distance, embedding) triples for the ``k`` best matches.
        Additional keyword arguments are accepted but not used.
        """
        return self.similarity_search_with_scores_and_embeddings(query, k, filter=filter)

    def similarity_search_with_scores_and_embeddings(
            self,
            query: str,
            k: int = DEFAULT_K,
            filter: Optional[Dict[str, str]] = None,
            where_document: Optional[Dict[str, str]] = None,
            **kwargs: Any,
    ) -> List[Tuple[Document, float, List[float]]]:
        """
        Search the collection and return (document, distance, embedding) triples.

        When no embedding function is configured the query text is sent as-is,
        otherwise the query is embedded first and the vector is sent instead.
        """
        if self._embedding_function is None:
            query_args = {"query_texts": [query]}
        else:
            query_args = {"query_embeddings": [self._embedding_function.embed_query(query)]}

        results = self.__query_collection(
            n_results=k,
            where=filter,
            where_document=where_document,
            include=['metadatas', 'documents', 'embeddings', 'distances'],
            **query_args,
        )

        return _results_to_docs_scores_and_embeddings(results)
243
+
244
+
245
class FAISSAdvancedRetrieval(FAISS):
    """FAISS-backed vector store; currently adds nothing on top of ``FAISS``."""
247
+
248
+
249
class NER_Retrival(VectorStore):
    """
    Retrieval based on NER models.

    It is meant as an alternative to embedding-based retrieval that relies on
    extracted entities. No behaviour is defined here yet.
    """
255
+
256
+
257
# Maps an engine name to the vector store class implementing it
engines = dict(
    chroma=ChromaAdvancedRetrieval,
    faiss=FAISSAdvancedRetrieval,
    ner=NER_Retrival,
)
262
+
263
+
264
+ class DataStorage:
265
  embeddings_dict = {}
266
  embeddings_map_from_md5 = {}
267
  embeddings_map_to_md5 = {}
268
 
269
def __init__(
        self,
        embedding_function,
        root_path: Path = None,
        engine=ChromaAdvancedRetrieval,
) -> None:
    """
    Set up the storage of the per-document vector stores.

    :param embedding_function: embedding function handed to every vector store
    :param root_path: optional directory holding the persisted stores. It is created
        when missing; when it already exists the stores it contains are loaded.
    :param engine: vector store class used to create and load the stores
    """
    self.root_path = root_path
    self.engine = engine
    self.embedding_function = embedding_function

    # Always define the attribute (None for in-memory usage), so that code reading
    # it does not fail with AttributeError when no root path was supplied.
    self.embeddings_root_path = root_path

    if root_path is None:
        return

    if os.path.exists(root_path):
        self.load_embeddings(self.embeddings_root_path)
    else:
        os.makedirs(root_path)
285
 
 
 
 
286
  def load_embeddings(self, embeddings_root_path: Union[str, Path]) -> None:
287
  """
288
+ Load the vector storage assuming they are all persisted and stored in a single directory.
289
  The root path of the embeddings containing one data store for each document in each subdirectory
290
  """
291
 
 
296
  return
297
 
298
  for embedding_document_dir in embeddings_directories:
299
+ self.embeddings_dict[embedding_document_dir.name] = self.engine(
300
+ persist_directory=embedding_document_dir.path,
301
+ embedding_function=self.embedding_function
302
+ )
303
 
304
  filename_list = list(Path(embedding_document_dir).glob('*.storage_filename'))
305
  if filename_list:
 
318
  def get_filename_from_md5(self, md5):
319
  return self.embeddings_map_from_md5[md5]
320
 
321
def embed_document(self, doc_id, texts, metadatas):
    """
    (Re)build the vector store of a document and register it under ``doc_id``.

    If a store is already registered for ``doc_id`` its collection is deleted first,
    then a new store is created from ``texts`` and ``metadatas`` with the configured
    engine. ``embeddings_root_path`` is reset to None afterwards.
    """
    if doc_id in self.embeddings_dict:
        # Workaround Chroma (?) breaking change
        self.embeddings_dict[doc_id].delete_collection()

    self.embeddings_dict[doc_id] = self.engine.from_texts(
        texts,
        embedding=self.embedding_function,
        metadatas=metadatas,
        collection_name=doc_id,
    )

    self.embeddings_root_path = None
336
+
337
+
338
+ class DocumentQAEngine:
339
+ llm = None
340
+ qa_chain_type = None
341
+
342
+ default_prompts = {
343
+ 'stuff': stuff_prompt,
344
+ 'refine': refine_prompts,
345
+ "map_reduce": map_reduce_prompt,
346
+ "map_rerank": map_rerank_prompt
347
+ }
348
+
349
def __init__(self,
             llm,
             data_storage: DataStorage,
             qa_chain_type="stuff",
             grobid_url=None,
             memory=None
             ):
    """
    Wire together the LLM, the QA chain, the text merger and the data storage.

    :param llm: language model used by the QA chain
    :param data_storage: storage holding the per-document vector stores
    :param qa_chain_type: chain type passed to ``load_qa_chain``
    :param grobid_url: when truthy, a ``GrobidProcessor`` is created for this URL
    :param memory: optional conversation memory
    """
    self.data_storage = data_storage
    self.memory = memory
    self.llm = llm

    self.chain = load_qa_chain(llm, chain_type=qa_chain_type)
    self.text_merger = TextMerger()

    # grobid_processor is only defined when a Grobid URL is supplied
    if grobid_url:
        self.grobid_processor = GrobidProcessor(grobid_url)
365
+
366
+ def query_document(
367
+ self,
368
+ query: str,
369
+ doc_id,
370
+ output_parser=None,
371
+ context_size=4,
372
+ extraction_schema=None,
373
+ verbose=False
374
+ ) -> (Any, str):
375
  # self.load_embeddings(self.embeddings_root_path)
376
 
377
  if verbose:
 
400
  else:
401
  return None, response, coordinates
402
 
403
def query_storage(self, query: str, doc_id, context_size=4) -> Tuple[List[str], list]:
    """
    Returns the context related to a given query.

    :return: the text of the context documents, and the coordinates returned
        alongside them by ``_get_context``

    Note: ``_get_context`` appends the conversation memory as an extra document after
    computing the coordinates, so the text list can be one item longer than the
    coordinates list.
    """
    documents, coordinates = self._get_context(doc_id, query, context_size)

    context_as_text = [doc.page_content for doc in documents]
    return context_as_text, coordinates
411
 
412
  def query_storage_and_embeddings(self, query: str, doc_id, context_size=4):
413
+ """
414
+ Returns both the context and the embedding information from a given query
415
+ """
416
+ db = self.data_storage.embeddings_dict[doc_id]
417
+ retriever = db.as_retriever(search_kwargs={"k": context_size}, search_type="similarity_with_embeddings")
418
+ relevant_documents = retriever.get_relevant_documents(query)
419
 
420
  context_as_text = [doc.page_content for doc in relevant_documents]
421
  return context_as_text
 
443
 
444
  return parsed_output
445
 
446
+ def _run_query(self, doc_id, query, context_size=4) -> (List[Document], list):
447
  relevant_documents = self._get_context(doc_id, query, context_size)
448
  relevant_document_coordinates = [doc.metadata['coordinates'].split(";") if 'coordinates' in doc.metadata else []
449
  for doc in
450
+ relevant_documents]
451
  response = self.chain.run(input_documents=relevant_documents,
452
  question=query)
453
 
 
455
  self.memory.save_context({"input": query}, {"output": response})
456
  return response, relevant_document_coordinates
457
 
458
def _get_context(self, doc_id, query, context_size=4) -> Tuple[List[Document], list]:
    """
    Retrieve the documents relevant to ``query`` from the store of ``doc_id``.

    :return: the relevant documents and, for each retrieved document, the list of its
        coordinates (the ``coordinates`` metadata split on ";", or an empty list when
        that metadata is absent)

    When the conversation memory is not empty, an additional document carrying the
    previous questions and answers is appended to the documents; it has no entry in
    the coordinates list.
    """
    db = self.data_storage.embeddings_dict[doc_id]
    retriever = db.as_retriever(search_kwargs={"k": context_size})
    relevant_documents = retriever.get_relevant_documents(query)

    # Computed before the memory document is appended: only retrieved passages have coordinates
    relevant_document_coordinates = [
        doc.metadata['coordinates'].split(";") if 'coordinates' in doc.metadata else []
        for doc in relevant_documents
    ]

    if self.memory and len(self.memory.buffer_as_messages) > 0:
        relevant_documents.append(
            Document(
                page_content="""Following, the previous question and answers. Use these information only when in the question there are unspecified references:\n{}\n\n""".format(
                    self.memory.buffer_as_str))
        )
    return relevant_documents, relevant_document_coordinates
472
 
473
def get_full_context_by_document(self, doc_id):
    """
    Return the 'documents' entry of everything stored for the document ``doc_id``.
    """
    store = self.data_storage.embeddings_dict[doc_id]
    return store.get()['documents']
480
 
481
  def _get_context_multiquery(self, doc_id, query, context_size=4):
482
+ db = self.data_storage.embeddings_dict[doc_id].as_retriever(search_kwargs={"k": context_size})
483
  multi_query_retriever = MultiQueryRetriever.from_llm(retriever=db, llm=self.llm)
484
  relevant_documents = multi_query_retriever.get_relevant_documents(query)
485
  return relevant_documents
486
 
487
  def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, verbose=False):
488
  """
489
+ Extract text from documents using Grobid.
490
+ - if chunk_size is < 0, keeps each paragraph separately
491
+ - if chunk_size > 0, aggregate all paragraphs and split them again using an approximate chunk size
492
  """
493
  if verbose:
494
  print("File", pdf_file_path)
 
528
 
529
  return texts, metadatas, ids
530
 
531
+ def create_memory_embeddings(
532
+ self,
533
+ pdf_path,
534
+ doc_id=None,
535
+ chunk_size=500,
536
+ perc_overlap=0.1
537
+ ):
538
  texts, metadata, ids = self.get_text_from_document(
539
  pdf_path,
540
  chunk_size=chunk_size,
 
544
  else:
545
  hash = metadata[0]['hash']
546
 
547
+ self.data_storage.embed_document(hash, texts, metadata)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
548
 
549
  return hash
550
 
551
+ def create_embeddings(
552
+ self,
553
+ pdfs_dir_path: Path,
554
+ chunk_size=500,
555
+ perc_overlap=0.1,
556
+ include_biblio=False
557
+ ):
558
  input_files = []
559
  for root, dirs, files in os.walk(pdfs_dir_path, followlinks=False):
560
  for file_ in files:
 
566
  desc="Grobid + embeddings processing"):
567
 
568
  md5 = self.calculate_md5(input_file)
569
+ data_path = os.path.join(self.data_storage.embeddings_root_path, md5)
570
 
571
  if os.path.exists(data_path):
572
  print(data_path, "exists. Skipping it ")
573
  continue
574
+ # include = ["biblio"] if include_biblio else []
575
  texts, metadata, ids = self.get_text_from_document(
576
  input_file,
577
  chunk_size=chunk_size,
578
+ perc_overlap=perc_overlap)
 
579
  filename = metadata[0]['filename']
580
 
581
  vector_db_document = Chroma.from_texts(texts,
requirements.txt CHANGED
@@ -4,10 +4,10 @@ grobid-client-python==0.0.7
4
  grobid_tei_xml==0.1.3
5
 
6
  # Utils
7
- tqdm==4.66.1
8
  pyyaml==6.0.1
9
- pytest==7.4.3
10
- streamlit==1.29.0
11
  lxml
12
  Beautifulsoup4
13
  python-dotenv
@@ -15,13 +15,13 @@ watchdog
15
  dateparser
16
 
17
  # LLM
18
- chromadb==0.4.19
19
- tiktoken==0.4.0
20
- openai==0.27.7
21
- langchain==0.0.350
22
- langchain-core==0.1.0
23
  typing-inspect==0.9.0
24
- typing_extensions==4.8.0
25
- pydantic==2.4.2
26
- sentence_transformers==2.2.2
27
  streamlit-pdf-viewer
 
4
  grobid_tei_xml==0.1.3
5
 
6
  # Utils
7
+ tqdm==4.66.2
8
  pyyaml==6.0.1
9
+ pytest==8.1.1
10
+ streamlit==1.33.0
11
  lxml
12
  Beautifulsoup4
13
  python-dotenv
 
15
  dateparser
16
 
17
  # LLM
18
+ chromadb==0.4.24
19
+ tiktoken==0.6.0
20
+ openai==1.16.2
21
+ langchain==0.1.14
22
+ langchain-core==0.1.40
23
  typing-inspect==0.9.0
24
+ typing_extensions==4.11.0
25
+ pydantic==2.6.4
26
+ sentence_transformers==2.6.1
27
  streamlit-pdf-viewer
streamlit_app.py CHANGED
@@ -9,15 +9,16 @@ from langchain.llms.huggingface_hub import HuggingFaceHub
9
  from langchain.memory import ConversationBufferWindowMemory
10
  from streamlit_pdf_viewer import pdf_viewer
11
 
 
 
12
  dotenv.load_dotenv(override=True)
13
 
14
  import streamlit as st
15
  from langchain.chat_models import ChatOpenAI
16
  from langchain.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings
17
 
18
- from document_qa.document_qa_engine import DocumentQAEngine
19
  from document_qa.grobid_processors import GrobidAggregationProcessor, decorate_text_with_annotations
20
- from grobid_client_generic import GrobidClientGeneric
21
 
22
  OPENAI_MODELS = ['gpt-3.5-turbo',
23
  "gpt-4",
@@ -168,14 +169,15 @@ def init_qa(model, api_key=None):
168
  st.stop()
169
  return
170
 
171
- return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'], memory=st.session_state['memory'])
 
172
 
173
 
174
  @st.cache_resource
175
  def init_ner():
176
  quantities_client = QuantitiesAPI(os.environ['GROBID_QUANTITIES_URL'], check_server=True)
177
 
178
- materials_client = GrobidClientGeneric(ping=True)
179
  config_materials = {
180
  'grobid': {
181
  "server": os.environ['GROBID_MATERIALS_URL'],
@@ -190,10 +192,8 @@ def init_ner():
190
 
191
  materials_client.set_config(config_materials)
192
 
193
- gqa = GrobidAggregationProcessor(None,
194
- grobid_quantities_client=quantities_client,
195
- grobid_superconductors_client=materials_client
196
- )
197
  return gqa
198
 
199
 
@@ -340,9 +340,12 @@ with st.sidebar:
340
 
341
  st.session_state['pdf_rendering'] = st.radio(
342
  "PDF rendering mode",
343
- {"PDF.JS", "Native browser engine"},
344
- index=1,
345
  disabled=not uploaded_file,
 
 
 
346
  )
347
 
348
  st.divider()
@@ -441,7 +444,8 @@ with right_column:
441
  text_response = None
442
  if mode == "Embeddings":
443
  with st.spinner("Generating LLM response..."):
444
- text_response = st.session_state['rqa'][model].query_storage_and_embeddings(question, st.session_state.doc_id,
 
445
  context_size=context_size)
446
  elif mode == "LLM":
447
  with st.spinner("Generating response..."):
@@ -449,14 +453,14 @@ with right_column:
449
  st.session_state.doc_id,
450
  context_size=context_size)
451
 
452
- annotations = [[GrobidAggregationProcessor.box_to_dict([cs for cs in c.split(",")]) for c in coord_doc]
453
- for coord_doc in coordinates]
454
- gradients = generate_color_gradient(len(annotations))
455
- for i, color in enumerate(gradients):
456
- for annotation in annotations[i]:
457
- annotation['color'] = color
458
- st.session_state['annotations'] = [annotation for annotation_doc in annotations for annotation in
459
- annotation_doc]
460
 
461
  if not text_response:
462
  st.error("Something went wrong. Contact Luca Foppiano (Foppiano.Luca@nims.co.jp) to report the issue.")
@@ -486,5 +490,5 @@ with left_column:
486
  height=800,
487
  annotation_outline_size=1,
488
  annotations=st.session_state['annotations'],
489
- rendering='unwrap' if st.session_state['pdf_rendering'] == 'PDF.JS' else 'legacy_embed'
490
  )
 
9
  from langchain.memory import ConversationBufferWindowMemory
10
  from streamlit_pdf_viewer import pdf_viewer
11
 
12
+ from document_qa.ner_client_generic import NERClientGeneric
13
+
14
  dotenv.load_dotenv(override=True)
15
 
16
  import streamlit as st
17
  from langchain.chat_models import ChatOpenAI
18
  from langchain.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings
19
 
20
+ from document_qa.document_qa_engine import DocumentQAEngine, DataStorage
21
  from document_qa.grobid_processors import GrobidAggregationProcessor, decorate_text_with_annotations
 
22
 
23
  OPENAI_MODELS = ['gpt-3.5-turbo',
24
  "gpt-4",
 
169
  st.stop()
170
  return
171
 
172
+ storage = DataStorage(embeddings)
173
+ return DocumentQAEngine(chat, storage, grobid_url=os.environ['GROBID_URL'], memory=st.session_state['memory'])
174
 
175
 
176
  @st.cache_resource
177
  def init_ner():
178
  quantities_client = QuantitiesAPI(os.environ['GROBID_QUANTITIES_URL'], check_server=True)
179
 
180
+ materials_client = NERClientGeneric(ping=True)
181
  config_materials = {
182
  'grobid': {
183
  "server": os.environ['GROBID_MATERIALS_URL'],
 
192
 
193
  materials_client.set_config(config_materials)
194
 
195
+ gqa = GrobidAggregationProcessor(grobid_quantities_client=quantities_client,
196
+ grobid_superconductors_client=materials_client)
 
 
197
  return gqa
198
 
199
 
 
340
 
341
  st.session_state['pdf_rendering'] = st.radio(
342
  "PDF rendering mode",
343
+ ("unwrap", "legacy_embed"),
344
+ index=0,
345
  disabled=not uploaded_file,
346
+ help="PDF rendering engine."
347
+ "Note: The Legacy PDF viewer does not support annotations and might not work on Chrome.",
348
+ format_func=lambda q: "Legacy PDF Viewer" if q == "legacy_embed" else "Streamlit PDF Viewer (Pdf.js)"
349
  )
350
 
351
  st.divider()
 
444
  text_response = None
445
  if mode == "Embeddings":
446
  with st.spinner("Generating LLM response..."):
447
+ text_response, coordinates = st.session_state['rqa'][model].query_storage(question,
448
+ st.session_state.doc_id,
449
  context_size=context_size)
450
  elif mode == "LLM":
451
  with st.spinner("Generating response..."):
 
453
  st.session_state.doc_id,
454
  context_size=context_size)
455
 
456
+ annotations = [[GrobidAggregationProcessor.box_to_dict([cs for cs in c.split(",")]) for c in coord_doc]
457
+ for coord_doc in coordinates]
458
+ gradients = generate_color_gradient(len(annotations))
459
+ for i, color in enumerate(gradients):
460
+ for annotation in annotations[i]:
461
+ annotation['color'] = color
462
+ st.session_state['annotations'] = [annotation for annotation_doc in annotations for annotation in
463
+ annotation_doc]
464
 
465
  if not text_response:
466
  st.error("Something went wrong. Contact Luca Foppiano (Foppiano.Luca@nims.co.jp) to report the issue.")
 
490
  height=800,
491
  annotation_outline_size=1,
492
  annotations=st.session_state['annotations'],
493
+ rendering=st.session_state['pdf_rendering']
494
  )