lfoppiano commited on
Commit
c07b97b
1 Parent(s): aeb450e

add merger that preserve the coordinates and aggregate them meaningfully

Browse files
document_qa/document_qa_engine.py CHANGED
@@ -3,18 +3,89 @@ import os
3
  from pathlib import Path
4
  from typing import Union, Any
5
 
6
- from document_qa.grobid_processors import GrobidProcessor
7
  from grobid_client.grobid_client import GrobidClient
8
- from langchain.chains import create_extraction_chain, ConversationChain, ConversationalRetrievalChain
9
  from langchain.chains.question_answering import load_qa_chain, stuff_prompt, refine_prompts, map_reduce_prompt, \
10
  map_rerank_prompt
11
  from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
12
  from langchain.retrievers import MultiQueryRetriever
13
  from langchain.schema import Document
14
- from langchain.text_splitter import RecursiveCharacterTextSplitter
15
  from langchain.vectorstores import Chroma
16
  from tqdm import tqdm
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
 
20
  class DocumentQAEngine:
@@ -44,6 +115,7 @@ class DocumentQAEngine:
44
  self.llm = llm
45
  self.memory = memory
46
  self.chain = load_qa_chain(llm, chain_type=qa_chain_type)
 
47
 
48
  if embeddings_root_path is not None:
49
  self.embeddings_root_path = embeddings_root_path
@@ -157,7 +229,9 @@ class DocumentQAEngine:
157
 
158
  def _run_query(self, doc_id, query, context_size=4):
159
  relevant_documents = self._get_context(doc_id, query, context_size)
160
- relevant_document_coordinates = [doc.metadata['coordinates'].split(";") if 'coordinates' in doc.metadata else [] for doc in relevant_documents] #filter(lambda d: d['type'] == "sentence", relevant_documents)]
 
 
161
  response = self.chain.run(input_documents=relevant_documents,
162
  question=query)
163
 
@@ -196,7 +270,7 @@ class DocumentQAEngine:
196
  if verbose:
197
  print("File", pdf_file_path)
198
  filename = Path(pdf_file_path).stem
199
- coordinates = True if chunk_size == -1 else False
200
  structure = self.grobid_processor.process_structure(pdf_file_path, coordinates=coordinates)
201
 
202
  biblio = structure['biblio']
@@ -209,29 +283,25 @@ class DocumentQAEngine:
209
  metadatas = []
210
  ids = []
211
 
212
- if chunk_size < 0:
213
- for passage in structure['passages']:
214
- biblio_copy = copy.copy(biblio)
215
- if len(str.strip(passage['text'])) > 0:
216
- texts.append(passage['text'])
217
 
218
- biblio_copy['type'] = passage['type']
219
- biblio_copy['section'] = passage['section']
220
- biblio_copy['subSection'] = passage['subSection']
221
- biblio_copy['coordinates'] = passage['coordinates']
222
- metadatas.append(biblio_copy)
223
 
224
- ids.append(passage['passage_id'])
225
- else:
226
- document_text = " ".join([passage['text'] for passage in structure['passages']])
227
- # text_splitter = CharacterTextSplitter.from_tiktoken_encoder(
228
- text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
229
- chunk_size=chunk_size,
230
- chunk_overlap=chunk_size * perc_overlap
231
- )
232
- texts = text_splitter.split_text(document_text)
233
- metadatas = [biblio for _ in range(len(texts))]
234
- ids = [id for id, t in enumerate(texts)]
235
 
236
  return texts, metadatas, ids
237
 
 
3
  from pathlib import Path
4
  from typing import Union, Any
5
 
6
+ import tiktoken
7
  from grobid_client.grobid_client import GrobidClient
8
+ from langchain.chains import create_extraction_chain
9
  from langchain.chains.question_answering import load_qa_chain, stuff_prompt, refine_prompts, map_reduce_prompt, \
10
  map_rerank_prompt
11
  from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
12
  from langchain.retrievers import MultiQueryRetriever
13
  from langchain.schema import Document
 
14
  from langchain.vectorstores import Chroma
15
  from tqdm import tqdm
16
 
17
+ from document_qa.grobid_processors import GrobidProcessor
18
+
19
+
20
+ class TextMerger:
21
+ def __init__(self, model_name=None, encoding_name="gpt2"):
22
+ if model_name is not None:
23
+ self.enc = tiktoken.encoding_for_model(model_name)
24
+ else:
25
+ self.enc = tiktoken.get_encoding(encoding_name)
26
+
27
+ def encode(self, text, allowed_special=set(), disallowed_special="all"):
28
+ return self.enc.encode(
29
+ text,
30
+ allowed_special=allowed_special,
31
+ disallowed_special=disallowed_special,
32
+ )
33
+
34
+ def merge_passages(self, passages, chunk_size, tolerance=0.2):
35
+ new_passages = []
36
+ new_coordinates = []
37
+ current_texts = []
38
+ current_coordinates = []
39
+ for idx, passage in enumerate(passages):
40
+ text = passage['text']
41
+ coordinates = passage['coordinates']
42
+ current_texts.append(text)
43
+ current_coordinates.append(coordinates)
44
+
45
+ accumulated_text = " ".join(current_texts)
46
+
47
+ encoded_accumulated_text = self.encode(accumulated_text)
48
+
49
+ if len(encoded_accumulated_text) > chunk_size + chunk_size * tolerance:
50
+ if len(current_texts) > 1:
51
+ new_passages.append(current_texts[:-1])
52
+ new_coordinates.append(current_coordinates[:-1])
53
+ current_texts = [current_texts[-1]]
54
+ current_coordinates = [current_coordinates[-1]]
55
+ else:
56
+ new_passages.append(current_texts)
57
+ new_coordinates.append(current_coordinates)
58
+ current_texts = []
59
+ current_coordinates = []
60
+
61
+ elif chunk_size <= len(encoded_accumulated_text) < chunk_size + chunk_size * tolerance:
62
+ new_passages.append(current_texts)
63
+ new_coordinates.append(current_coordinates)
64
+ current_texts = []
65
+ current_coordinates = []
66
+ else:
67
+ print("bao")
68
+
69
+ if len(current_texts) > 0:
70
+ new_passages.append(current_texts)
71
+ new_coordinates.append(current_coordinates)
72
+
73
+ new_passages_struct = []
74
+ for i, passages in enumerate(new_passages):
75
+ text = " ".join(passages)
76
+ coordinates = ";".join(new_coordinates[i])
77
+
78
+ new_passages_struct.append(
79
+ {
80
+ "text": text,
81
+ "coordinates": coordinates,
82
+ "type": "aggregated chunks",
83
+ "section": "mixed",
84
+ "subSection": "mixed"
85
+ }
86
+ )
87
+
88
+ return new_passages_struct
89
 
90
 
91
  class DocumentQAEngine:
 
115
  self.llm = llm
116
  self.memory = memory
117
  self.chain = load_qa_chain(llm, chain_type=qa_chain_type)
118
+ self.text_merger = TextMerger()
119
 
120
  if embeddings_root_path is not None:
121
  self.embeddings_root_path = embeddings_root_path
 
229
 
230
  def _run_query(self, doc_id, query, context_size=4):
231
  relevant_documents = self._get_context(doc_id, query, context_size)
232
+ relevant_document_coordinates = [doc.metadata['coordinates'].split(";") if 'coordinates' in doc.metadata else []
233
+ for doc in
234
+ relevant_documents] # filter(lambda d: d['type'] == "sentence", relevant_documents)]
235
  response = self.chain.run(input_documents=relevant_documents,
236
  question=query)
237
 
 
270
  if verbose:
271
  print("File", pdf_file_path)
272
  filename = Path(pdf_file_path).stem
273
+ coordinates = True # if chunk_size == -1 else False
274
  structure = self.grobid_processor.process_structure(pdf_file_path, coordinates=coordinates)
275
 
276
  biblio = structure['biblio']
 
283
  metadatas = []
284
  ids = []
285
 
286
+ if chunk_size > 0:
287
+ new_passages = self.text_merger.merge_passages(structure['passages'], chunk_size=chunk_size)
288
+ else:
289
+ new_passages = structure['passages']
 
290
 
291
+ for passage in new_passages:
292
+ biblio_copy = copy.copy(biblio)
293
+ if len(str.strip(passage['text'])) > 0:
294
+ texts.append(passage['text'])
 
295
 
296
+ biblio_copy['type'] = passage['type']
297
+ biblio_copy['section'] = passage['section']
298
+ biblio_copy['subSection'] = passage['subSection']
299
+ biblio_copy['coordinates'] = passage['coordinates']
300
+ metadatas.append(biblio_copy)
301
+
302
+ # ids.append(passage['passage_id'])
303
+
304
+ ids = [id for id, t in enumerate(new_passages)]
 
 
305
 
306
  return texts, metadatas, ids
307
 
tests/test_document_qa_engine.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from document_qa.document_qa_engine import TextMerger
2
+
3
+
4
+ def test_merge_passages_small_chunk():
5
+ merger = TextMerger()
6
+
7
+ passages = [
8
+ {
9
+ 'text': "The quick brown fox jumps over the tree",
10
+ 'coordinates': '1'
11
+ },
12
+ {
13
+ 'text': "and went straight into the mouth of a bear.",
14
+ 'coordinates': '2'
15
+ },
16
+ {
17
+ 'text': "The color of the colors is a color with colors",
18
+ 'coordinates': '3'
19
+ },
20
+ {
21
+ 'text': "the main colors are not the colorw we show",
22
+ 'coordinates': '4'
23
+ }
24
+ ]
25
+ new_passages = merger.merge_passages(passages, chunk_size=10, tolerance=0)
26
+
27
+ assert len(new_passages) == 4
28
+ assert new_passages[0]['coordinates'] == "1"
29
+ assert new_passages[0]['text'] == "The quick brown fox jumps over the tree"
30
+
31
+ assert new_passages[1]['coordinates'] == "2"
32
+ assert new_passages[1]['text'] == "and went straight into the mouth of a bear."
33
+
34
+ assert new_passages[2]['coordinates'] == "3"
35
+ assert new_passages[2]['text'] == "The color of the colors is a color with colors"
36
+
37
+ assert new_passages[3]['coordinates'] == "4"
38
+ assert new_passages[3]['text'] == "the main colors are not the colorw we show"
39
+
40
+
41
+ def test_merge_passages_big_chunk():
42
+ merger = TextMerger()
43
+
44
+ passages = [
45
+ {
46
+ 'text': "The quick brown fox jumps over the tree",
47
+ 'coordinates': '1'
48
+ },
49
+ {
50
+ 'text': "and went straight into the mouth of a bear.",
51
+ 'coordinates': '2'
52
+ },
53
+ {
54
+ 'text': "The color of the colors is a color with colors",
55
+ 'coordinates': '3'
56
+ },
57
+ {
58
+ 'text': "the main colors are not the colorw we show",
59
+ 'coordinates': '4'
60
+ }
61
+ ]
62
+ new_passages = merger.merge_passages(passages, chunk_size=20, tolerance=0)
63
+
64
+ assert len(new_passages) == 2
65
+ assert new_passages[0]['coordinates'] == "1;2"
66
+ assert new_passages[0][
67
+ 'text'] == "The quick brown fox jumps over the tree and went straight into the mouth of a bear."
68
+
69
+ assert new_passages[1]['coordinates'] == "3;4"
70
+ assert new_passages[1][
71
+ 'text'] == "The color of the colors is a color with colors the main colors are not the colorw we show"