Spaces:
Running
Running
add merger that preserve the coordinates and aggregate them meaningfully
Browse files
document_qa/document_qa_engine.py
CHANGED
@@ -3,18 +3,89 @@ import os
|
|
3 |
from pathlib import Path
|
4 |
from typing import Union, Any
|
5 |
|
6 |
-
|
7 |
from grobid_client.grobid_client import GrobidClient
|
8 |
-
from langchain.chains import create_extraction_chain
|
9 |
from langchain.chains.question_answering import load_qa_chain, stuff_prompt, refine_prompts, map_reduce_prompt, \
|
10 |
map_rerank_prompt
|
11 |
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
|
12 |
from langchain.retrievers import MultiQueryRetriever
|
13 |
from langchain.schema import Document
|
14 |
-
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
15 |
from langchain.vectorstores import Chroma
|
16 |
from tqdm import tqdm
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
|
20 |
class DocumentQAEngine:
|
@@ -44,6 +115,7 @@ class DocumentQAEngine:
|
|
44 |
self.llm = llm
|
45 |
self.memory = memory
|
46 |
self.chain = load_qa_chain(llm, chain_type=qa_chain_type)
|
|
|
47 |
|
48 |
if embeddings_root_path is not None:
|
49 |
self.embeddings_root_path = embeddings_root_path
|
@@ -157,7 +229,9 @@ class DocumentQAEngine:
|
|
157 |
|
158 |
def _run_query(self, doc_id, query, context_size=4):
|
159 |
relevant_documents = self._get_context(doc_id, query, context_size)
|
160 |
-
relevant_document_coordinates = [doc.metadata['coordinates'].split(";") if 'coordinates' in doc.metadata else []
|
|
|
|
|
161 |
response = self.chain.run(input_documents=relevant_documents,
|
162 |
question=query)
|
163 |
|
@@ -196,7 +270,7 @@ class DocumentQAEngine:
|
|
196 |
if verbose:
|
197 |
print("File", pdf_file_path)
|
198 |
filename = Path(pdf_file_path).stem
|
199 |
-
coordinates = True if chunk_size == -1 else False
|
200 |
structure = self.grobid_processor.process_structure(pdf_file_path, coordinates=coordinates)
|
201 |
|
202 |
biblio = structure['biblio']
|
@@ -209,29 +283,25 @@ class DocumentQAEngine:
|
|
209 |
metadatas = []
|
210 |
ids = []
|
211 |
|
212 |
-
if chunk_size
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
texts.append(passage['text'])
|
217 |
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
metadatas.append(biblio_copy)
|
223 |
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
metadatas = [biblio for _ in range(len(texts))]
|
234 |
-
ids = [id for id, t in enumerate(texts)]
|
235 |
|
236 |
return texts, metadatas, ids
|
237 |
|
|
|
3 |
from pathlib import Path
|
4 |
from typing import Union, Any
|
5 |
|
6 |
+
import tiktoken
|
7 |
from grobid_client.grobid_client import GrobidClient
|
8 |
+
from langchain.chains import create_extraction_chain
|
9 |
from langchain.chains.question_answering import load_qa_chain, stuff_prompt, refine_prompts, map_reduce_prompt, \
|
10 |
map_rerank_prompt
|
11 |
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
|
12 |
from langchain.retrievers import MultiQueryRetriever
|
13 |
from langchain.schema import Document
|
|
|
14 |
from langchain.vectorstores import Chroma
|
15 |
from tqdm import tqdm
|
16 |
|
17 |
+
from document_qa.grobid_processors import GrobidProcessor
|
18 |
+
|
19 |
+
|
20 |
+
class TextMerger:
|
21 |
+
def __init__(self, model_name=None, encoding_name="gpt2"):
|
22 |
+
if model_name is not None:
|
23 |
+
self.enc = tiktoken.encoding_for_model(model_name)
|
24 |
+
else:
|
25 |
+
self.enc = tiktoken.get_encoding(encoding_name)
|
26 |
+
|
27 |
+
def encode(self, text, allowed_special=set(), disallowed_special="all"):
|
28 |
+
return self.enc.encode(
|
29 |
+
text,
|
30 |
+
allowed_special=allowed_special,
|
31 |
+
disallowed_special=disallowed_special,
|
32 |
+
)
|
33 |
+
|
34 |
+
def merge_passages(self, passages, chunk_size, tolerance=0.2):
|
35 |
+
new_passages = []
|
36 |
+
new_coordinates = []
|
37 |
+
current_texts = []
|
38 |
+
current_coordinates = []
|
39 |
+
for idx, passage in enumerate(passages):
|
40 |
+
text = passage['text']
|
41 |
+
coordinates = passage['coordinates']
|
42 |
+
current_texts.append(text)
|
43 |
+
current_coordinates.append(coordinates)
|
44 |
+
|
45 |
+
accumulated_text = " ".join(current_texts)
|
46 |
+
|
47 |
+
encoded_accumulated_text = self.encode(accumulated_text)
|
48 |
+
|
49 |
+
if len(encoded_accumulated_text) > chunk_size + chunk_size * tolerance:
|
50 |
+
if len(current_texts) > 1:
|
51 |
+
new_passages.append(current_texts[:-1])
|
52 |
+
new_coordinates.append(current_coordinates[:-1])
|
53 |
+
current_texts = [current_texts[-1]]
|
54 |
+
current_coordinates = [current_coordinates[-1]]
|
55 |
+
else:
|
56 |
+
new_passages.append(current_texts)
|
57 |
+
new_coordinates.append(current_coordinates)
|
58 |
+
current_texts = []
|
59 |
+
current_coordinates = []
|
60 |
+
|
61 |
+
elif chunk_size <= len(encoded_accumulated_text) < chunk_size + chunk_size * tolerance:
|
62 |
+
new_passages.append(current_texts)
|
63 |
+
new_coordinates.append(current_coordinates)
|
64 |
+
current_texts = []
|
65 |
+
current_coordinates = []
|
66 |
+
else:
|
67 |
+
print("bao")
|
68 |
+
|
69 |
+
if len(current_texts) > 0:
|
70 |
+
new_passages.append(current_texts)
|
71 |
+
new_coordinates.append(current_coordinates)
|
72 |
+
|
73 |
+
new_passages_struct = []
|
74 |
+
for i, passages in enumerate(new_passages):
|
75 |
+
text = " ".join(passages)
|
76 |
+
coordinates = ";".join(new_coordinates[i])
|
77 |
+
|
78 |
+
new_passages_struct.append(
|
79 |
+
{
|
80 |
+
"text": text,
|
81 |
+
"coordinates": coordinates,
|
82 |
+
"type": "aggregated chunks",
|
83 |
+
"section": "mixed",
|
84 |
+
"subSection": "mixed"
|
85 |
+
}
|
86 |
+
)
|
87 |
+
|
88 |
+
return new_passages_struct
|
89 |
|
90 |
|
91 |
class DocumentQAEngine:
|
|
|
115 |
self.llm = llm
|
116 |
self.memory = memory
|
117 |
self.chain = load_qa_chain(llm, chain_type=qa_chain_type)
|
118 |
+
self.text_merger = TextMerger()
|
119 |
|
120 |
if embeddings_root_path is not None:
|
121 |
self.embeddings_root_path = embeddings_root_path
|
|
|
229 |
|
230 |
def _run_query(self, doc_id, query, context_size=4):
|
231 |
relevant_documents = self._get_context(doc_id, query, context_size)
|
232 |
+
relevant_document_coordinates = [doc.metadata['coordinates'].split(";") if 'coordinates' in doc.metadata else []
|
233 |
+
for doc in
|
234 |
+
relevant_documents] # filter(lambda d: d['type'] == "sentence", relevant_documents)]
|
235 |
response = self.chain.run(input_documents=relevant_documents,
|
236 |
question=query)
|
237 |
|
|
|
270 |
if verbose:
|
271 |
print("File", pdf_file_path)
|
272 |
filename = Path(pdf_file_path).stem
|
273 |
+
coordinates = True # if chunk_size == -1 else False
|
274 |
structure = self.grobid_processor.process_structure(pdf_file_path, coordinates=coordinates)
|
275 |
|
276 |
biblio = structure['biblio']
|
|
|
283 |
metadatas = []
|
284 |
ids = []
|
285 |
|
286 |
+
if chunk_size > 0:
|
287 |
+
new_passages = self.text_merger.merge_passages(structure['passages'], chunk_size=chunk_size)
|
288 |
+
else:
|
289 |
+
new_passages = structure['passages']
|
|
|
290 |
|
291 |
+
for passage in new_passages:
|
292 |
+
biblio_copy = copy.copy(biblio)
|
293 |
+
if len(str.strip(passage['text'])) > 0:
|
294 |
+
texts.append(passage['text'])
|
|
|
295 |
|
296 |
+
biblio_copy['type'] = passage['type']
|
297 |
+
biblio_copy['section'] = passage['section']
|
298 |
+
biblio_copy['subSection'] = passage['subSection']
|
299 |
+
biblio_copy['coordinates'] = passage['coordinates']
|
300 |
+
metadatas.append(biblio_copy)
|
301 |
+
|
302 |
+
# ids.append(passage['passage_id'])
|
303 |
+
|
304 |
+
ids = [id for id, t in enumerate(new_passages)]
|
|
|
|
|
305 |
|
306 |
return texts, metadatas, ids
|
307 |
|
tests/test_document_qa_engine.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from document_qa.document_qa_engine import TextMerger
|
2 |
+
|
3 |
+
|
4 |
+
def test_merge_passages_small_chunk():
|
5 |
+
merger = TextMerger()
|
6 |
+
|
7 |
+
passages = [
|
8 |
+
{
|
9 |
+
'text': "The quick brown fox jumps over the tree",
|
10 |
+
'coordinates': '1'
|
11 |
+
},
|
12 |
+
{
|
13 |
+
'text': "and went straight into the mouth of a bear.",
|
14 |
+
'coordinates': '2'
|
15 |
+
},
|
16 |
+
{
|
17 |
+
'text': "The color of the colors is a color with colors",
|
18 |
+
'coordinates': '3'
|
19 |
+
},
|
20 |
+
{
|
21 |
+
'text': "the main colors are not the colorw we show",
|
22 |
+
'coordinates': '4'
|
23 |
+
}
|
24 |
+
]
|
25 |
+
new_passages = merger.merge_passages(passages, chunk_size=10, tolerance=0)
|
26 |
+
|
27 |
+
assert len(new_passages) == 4
|
28 |
+
assert new_passages[0]['coordinates'] == "1"
|
29 |
+
assert new_passages[0]['text'] == "The quick brown fox jumps over the tree"
|
30 |
+
|
31 |
+
assert new_passages[1]['coordinates'] == "2"
|
32 |
+
assert new_passages[1]['text'] == "and went straight into the mouth of a bear."
|
33 |
+
|
34 |
+
assert new_passages[2]['coordinates'] == "3"
|
35 |
+
assert new_passages[2]['text'] == "The color of the colors is a color with colors"
|
36 |
+
|
37 |
+
assert new_passages[3]['coordinates'] == "4"
|
38 |
+
assert new_passages[3]['text'] == "the main colors are not the colorw we show"
|
39 |
+
|
40 |
+
|
41 |
+
def test_merge_passages_big_chunk():
|
42 |
+
merger = TextMerger()
|
43 |
+
|
44 |
+
passages = [
|
45 |
+
{
|
46 |
+
'text': "The quick brown fox jumps over the tree",
|
47 |
+
'coordinates': '1'
|
48 |
+
},
|
49 |
+
{
|
50 |
+
'text': "and went straight into the mouth of a bear.",
|
51 |
+
'coordinates': '2'
|
52 |
+
},
|
53 |
+
{
|
54 |
+
'text': "The color of the colors is a color with colors",
|
55 |
+
'coordinates': '3'
|
56 |
+
},
|
57 |
+
{
|
58 |
+
'text': "the main colors are not the colorw we show",
|
59 |
+
'coordinates': '4'
|
60 |
+
}
|
61 |
+
]
|
62 |
+
new_passages = merger.merge_passages(passages, chunk_size=20, tolerance=0)
|
63 |
+
|
64 |
+
assert len(new_passages) == 2
|
65 |
+
assert new_passages[0]['coordinates'] == "1;2"
|
66 |
+
assert new_passages[0][
|
67 |
+
'text'] == "The quick brown fox jumps over the tree and went straight into the mouth of a bear."
|
68 |
+
|
69 |
+
assert new_passages[1]['coordinates'] == "3;4"
|
70 |
+
assert new_passages[1][
|
71 |
+
'text'] == "The color of the colors is a color with colors the main colors are not the colorw we show"
|