Kevin Hu
commited on
Commit
·
47ec63e
1
Parent(s):
93f905e
Light GraphRAG (#4585)
Browse files### What problem does this PR solve?
#4543
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
This view is limited to 50 files because it contains too many changes.
See raw diff
- api/apps/chunk_app.py +12 -3
- api/apps/conversation_app.py +1 -1
- api/apps/kb_app.py +35 -1
- api/apps/sdk/dify_retrieval.py +13 -3
- api/apps/sdk/doc.py +13 -4
- api/db/init_data.py +1 -1
- api/db/services/dialog_service.py +9 -2
- api/db/services/document_service.py +46 -16
- api/db/services/file_service.py +1 -1
- api/db/services/task_service.py +20 -10
- api/utils/api_utils.py +1 -1
- conf/infinity_mapping.json +11 -1
- graphrag/description_summary.py +0 -146
- graphrag/entity_resolution.py +44 -40
- graphrag/extractor.py +0 -34
- graphrag/general/__init__.py +0 -0
- graphrag/{claim_extractor.py → general/claim_extractor.py} +2 -2
- graphrag/{claim_prompt.py → general/claim_prompt.py} +0 -0
- graphrag/{community_report_prompt.py → general/community_report_prompt.py} +0 -0
- graphrag/{community_reports_extractor.py → general/community_reports_extractor.py} +25 -14
- graphrag/{entity_embedding.py → general/entity_embedding.py} +1 -1
- graphrag/general/extractor.py +245 -0
- graphrag/general/graph_extractor.py +154 -0
- graphrag/{graph_prompt.py → general/graph_prompt.py} +16 -1
- graphrag/general/index.py +197 -0
- graphrag/{leiden.py → general/leiden.py} +2 -1
- graphrag/{mind_map_extractor.py → general/mind_map_extractor.py} +2 -2
- graphrag/{mind_map_prompt.py → general/mind_map_prompt.py} +0 -0
- graphrag/general/smoke.py +63 -0
- graphrag/graph_extractor.py +0 -322
- graphrag/index.py +0 -153
- graphrag/light/__init__.py +0 -0
- graphrag/light/graph_extractor.py +127 -0
- graphrag/light/graph_prompt.py +255 -0
- graphrag/{smoke.py → light/smoke.py} +28 -25
- graphrag/query_analyze_prompt.py +218 -0
- graphrag/search.py +301 -78
- graphrag/utils.py +386 -0
- pyproject.toml +2 -1
- rag/app/book.py +3 -3
- rag/app/email.py +1 -1
- rag/app/knowledge_graph.py +0 -48
- rag/app/laws.py +3 -3
- rag/app/manual.py +3 -3
- rag/app/naive.py +6 -3
- rag/app/one.py +3 -3
- rag/app/paper.py +1 -1
- rag/app/presentation.py +3 -3
- rag/llm/chat_model.py +1 -0
- rag/nlp/search.py +8 -1
api/apps/chunk_app.py
CHANGED
@@ -155,7 +155,7 @@ def set():
|
|
155 |
r"[\n\t]",
|
156 |
req["content_with_weight"]) if len(t) > 1]
|
157 |
q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
|
158 |
-
d = beAdoc(d,
|
159 |
[rag_tokenizer.is_chinese(t) for t in q + a]))
|
160 |
|
161 |
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
|
@@ -270,6 +270,7 @@ def retrieval_test():
|
|
270 |
doc_ids = req.get("doc_ids", [])
|
271 |
similarity_threshold = float(req.get("similarity_threshold", 0.0))
|
272 |
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
|
|
|
273 |
top = int(req.get("top_k", 1024))
|
274 |
tenant_ids = []
|
275 |
|
@@ -301,12 +302,20 @@ def retrieval_test():
|
|
301 |
question += keyword_extraction(chat_mdl, question)
|
302 |
|
303 |
labels = label_question(question, [kb])
|
304 |
-
|
305 |
-
ranks = retr.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
|
306 |
similarity_threshold, vector_similarity_weight, top,
|
307 |
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"),
|
308 |
rank_feature=labels
|
309 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
310 |
for c in ranks["chunks"]:
|
311 |
c.pop("vector", None)
|
312 |
ranks["labels"] = labels
|
|
|
155 |
r"[\n\t]",
|
156 |
req["content_with_weight"]) if len(t) > 1]
|
157 |
q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
|
158 |
+
d = beAdoc(d, q, a, not any(
|
159 |
[rag_tokenizer.is_chinese(t) for t in q + a]))
|
160 |
|
161 |
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
|
|
|
270 |
doc_ids = req.get("doc_ids", [])
|
271 |
similarity_threshold = float(req.get("similarity_threshold", 0.0))
|
272 |
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
|
273 |
+
use_kg = req.get("use_kg", False)
|
274 |
top = int(req.get("top_k", 1024))
|
275 |
tenant_ids = []
|
276 |
|
|
|
302 |
question += keyword_extraction(chat_mdl, question)
|
303 |
|
304 |
labels = label_question(question, [kb])
|
305 |
+
ranks = settings.retrievaler.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
|
|
|
306 |
similarity_threshold, vector_similarity_weight, top,
|
307 |
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"),
|
308 |
rank_feature=labels
|
309 |
)
|
310 |
+
if use_kg:
|
311 |
+
ck = settings.kg_retrievaler.retrieval(question,
|
312 |
+
tenant_ids,
|
313 |
+
kb_ids,
|
314 |
+
embd_mdl,
|
315 |
+
LLMBundle(kb.tenant_id, LLMType.CHAT))
|
316 |
+
if ck["content_with_weight"]:
|
317 |
+
ranks["chunks"].insert(0, ck)
|
318 |
+
|
319 |
for c in ranks["chunks"]:
|
320 |
c.pop("vector", None)
|
321 |
ranks["labels"] = labels
|
api/apps/conversation_app.py
CHANGED
@@ -31,7 +31,7 @@ from api.db.services.llm_service import LLMBundle, TenantService
|
|
31 |
from api import settings
|
32 |
from api.utils.api_utils import get_json_result
|
33 |
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
34 |
-
from graphrag.mind_map_extractor import MindMapExtractor
|
35 |
|
36 |
|
37 |
@manager.route('/set', methods=['POST']) # noqa: F821
|
|
|
31 |
from api import settings
|
32 |
from api.utils.api_utils import get_json_result
|
33 |
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
34 |
+
from graphrag.general.mind_map_extractor import MindMapExtractor
|
35 |
|
36 |
|
37 |
@manager.route('/set', methods=['POST']) # noqa: F821
|
api/apps/kb_app.py
CHANGED
@@ -13,6 +13,8 @@
|
|
13 |
# See the License for the specific language governing permissions and
|
14 |
# limitations under the License.
|
15 |
#
|
|
|
|
|
16 |
from flask import request
|
17 |
from flask_login import login_required, current_user
|
18 |
|
@@ -272,4 +274,36 @@ def rename_tags(kb_id):
|
|
272 |
{"remove": {"tag_kwd": req["from_tag"].strip()}, "add": {"tag_kwd": req["to_tag"]}},
|
273 |
search.index_name(kb.tenant_id),
|
274 |
kb_id)
|
275 |
-
return get_json_result(data=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
# See the License for the specific language governing permissions and
|
14 |
# limitations under the License.
|
15 |
#
|
16 |
+
import json
|
17 |
+
|
18 |
from flask import request
|
19 |
from flask_login import login_required, current_user
|
20 |
|
|
|
274 |
{"remove": {"tag_kwd": req["from_tag"].strip()}, "add": {"tag_kwd": req["to_tag"]}},
|
275 |
search.index_name(kb.tenant_id),
|
276 |
kb_id)
|
277 |
+
return get_json_result(data=True)
|
278 |
+
|
279 |
+
|
280 |
+
@manager.route('/<kb_id>/knowledge_graph', methods=['GET']) # noqa: F821
|
281 |
+
@login_required
|
282 |
+
def knowledge_graph(kb_id):
|
283 |
+
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
284 |
+
return get_json_result(
|
285 |
+
data=False,
|
286 |
+
message='No authorization.',
|
287 |
+
code=settings.RetCode.AUTHENTICATION_ERROR
|
288 |
+
)
|
289 |
+
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
290 |
+
req = {
|
291 |
+
"kb_id": [kb_id],
|
292 |
+
"knowledge_graph_kwd": ["graph"]
|
293 |
+
}
|
294 |
+
sres = settings.retrievaler.search(req, search.index_name(kb.tenant_id), [kb_id])
|
295 |
+
obj = {"graph": {}, "mind_map": {}}
|
296 |
+
for id in sres.ids[:1]:
|
297 |
+
ty = sres.field[id]["knowledge_graph_kwd"]
|
298 |
+
try:
|
299 |
+
content_json = json.loads(sres.field[id]["content_with_weight"])
|
300 |
+
except Exception:
|
301 |
+
continue
|
302 |
+
|
303 |
+
obj[ty] = content_json
|
304 |
+
|
305 |
+
if "nodes" in obj["graph"]:
|
306 |
+
obj["graph"]["nodes"] = sorted(obj["graph"]["nodes"], key=lambda x: x.get("pagerank", 0), reverse=True)[:256]
|
307 |
+
if "edges" in obj["graph"]:
|
308 |
+
obj["graph"]["edges"] = sorted(obj["graph"]["edges"], key=lambda x: x.get("weight", 0), reverse=True)[:128]
|
309 |
+
return get_json_result(data=obj)
|
api/apps/sdk/dify_retrieval.py
CHANGED
@@ -15,7 +15,7 @@
|
|
15 |
#
|
16 |
from flask import request, jsonify
|
17 |
|
18 |
-
from api.db import LLMType
|
19 |
from api.db.services.dialog_service import label_question
|
20 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
21 |
from api.db.services.llm_service import LLMBundle
|
@@ -30,6 +30,7 @@ def retrieval(tenant_id):
|
|
30 |
req = request.json
|
31 |
question = req["query"]
|
32 |
kb_id = req["knowledge_id"]
|
|
|
33 |
retrieval_setting = req.get("retrieval_setting", {})
|
34 |
similarity_threshold = float(retrieval_setting.get("score_threshold", 0.0))
|
35 |
top = int(retrieval_setting.get("top_k", 1024))
|
@@ -45,8 +46,7 @@ def retrieval(tenant_id):
|
|
45 |
|
46 |
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
47 |
|
48 |
-
|
49 |
-
ranks = retr.retrieval(
|
50 |
question,
|
51 |
embd_mdl,
|
52 |
kb.tenant_id,
|
@@ -58,6 +58,16 @@ def retrieval(tenant_id):
|
|
58 |
top=top,
|
59 |
rank_feature=label_question(question, [kb])
|
60 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
records = []
|
62 |
for c in ranks["chunks"]:
|
63 |
c.pop("vector", None)
|
|
|
15 |
#
|
16 |
from flask import request, jsonify
|
17 |
|
18 |
+
from api.db import LLMType
|
19 |
from api.db.services.dialog_service import label_question
|
20 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
21 |
from api.db.services.llm_service import LLMBundle
|
|
|
30 |
req = request.json
|
31 |
question = req["query"]
|
32 |
kb_id = req["knowledge_id"]
|
33 |
+
use_kg = req.get("use_kg", False)
|
34 |
retrieval_setting = req.get("retrieval_setting", {})
|
35 |
similarity_threshold = float(retrieval_setting.get("score_threshold", 0.0))
|
36 |
top = int(retrieval_setting.get("top_k", 1024))
|
|
|
46 |
|
47 |
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
48 |
|
49 |
+
ranks = settings.retrievaler.retrieval(
|
|
|
50 |
question,
|
51 |
embd_mdl,
|
52 |
kb.tenant_id,
|
|
|
58 |
top=top,
|
59 |
rank_feature=label_question(question, [kb])
|
60 |
)
|
61 |
+
|
62 |
+
if use_kg:
|
63 |
+
ck = settings.kg_retrievaler.retrieval(question,
|
64 |
+
[tenant_id],
|
65 |
+
[kb_id],
|
66 |
+
embd_mdl,
|
67 |
+
LLMBundle(kb.tenant_id, LLMType.CHAT))
|
68 |
+
if ck["content_with_weight"]:
|
69 |
+
ranks["chunks"].insert(0, ck)
|
70 |
+
|
71 |
records = []
|
72 |
for c in ranks["chunks"]:
|
73 |
c.pop("vector", None)
|
api/apps/sdk/doc.py
CHANGED
@@ -1297,15 +1297,15 @@ def retrieval_test(tenant_id):
|
|
1297 |
kb_ids = req["dataset_ids"]
|
1298 |
if not isinstance(kb_ids, list):
|
1299 |
return get_error_data_result("`dataset_ids` should be a list")
|
1300 |
-
kbs = KnowledgebaseService.get_by_ids(kb_ids)
|
1301 |
for id in kb_ids:
|
1302 |
if not KnowledgebaseService.accessible(kb_id=id, user_id=tenant_id):
|
1303 |
return get_error_data_result(f"You don't own the dataset {id}.")
|
|
|
1304 |
embd_nms = list(set([kb.embd_id for kb in kbs]))
|
1305 |
if len(embd_nms) != 1:
|
1306 |
return get_result(
|
1307 |
message='Datasets use different embedding models."',
|
1308 |
-
code=settings.RetCode.
|
1309 |
)
|
1310 |
if "question" not in req:
|
1311 |
return get_error_data_result("`question` is required.")
|
@@ -1313,6 +1313,7 @@ def retrieval_test(tenant_id):
|
|
1313 |
size = int(req.get("page_size", 30))
|
1314 |
question = req["question"]
|
1315 |
doc_ids = req.get("document_ids", [])
|
|
|
1316 |
if not isinstance(doc_ids, list):
|
1317 |
return get_error_data_result("`documents` should be a list")
|
1318 |
doc_ids_list = KnowledgebaseService.list_documents_by_ids(kb_ids)
|
@@ -1342,8 +1343,7 @@ def retrieval_test(tenant_id):
|
|
1342 |
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
1343 |
question += keyword_extraction(chat_mdl, question)
|
1344 |
|
1345 |
-
|
1346 |
-
ranks = retr.retrieval(
|
1347 |
question,
|
1348 |
embd_mdl,
|
1349 |
kb.tenant_id,
|
@@ -1358,6 +1358,15 @@ def retrieval_test(tenant_id):
|
|
1358 |
highlight=highlight,
|
1359 |
rank_feature=label_question(question, kbs)
|
1360 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1361 |
for c in ranks["chunks"]:
|
1362 |
c.pop("vector", None)
|
1363 |
|
|
|
1297 |
kb_ids = req["dataset_ids"]
|
1298 |
if not isinstance(kb_ids, list):
|
1299 |
return get_error_data_result("`dataset_ids` should be a list")
|
|
|
1300 |
for id in kb_ids:
|
1301 |
if not KnowledgebaseService.accessible(kb_id=id, user_id=tenant_id):
|
1302 |
return get_error_data_result(f"You don't own the dataset {id}.")
|
1303 |
+
kbs = KnowledgebaseService.get_by_ids(kb_ids)
|
1304 |
embd_nms = list(set([kb.embd_id for kb in kbs]))
|
1305 |
if len(embd_nms) != 1:
|
1306 |
return get_result(
|
1307 |
message='Datasets use different embedding models."',
|
1308 |
+
code=settings.RetCode.DATA_ERROR,
|
1309 |
)
|
1310 |
if "question" not in req:
|
1311 |
return get_error_data_result("`question` is required.")
|
|
|
1313 |
size = int(req.get("page_size", 30))
|
1314 |
question = req["question"]
|
1315 |
doc_ids = req.get("document_ids", [])
|
1316 |
+
use_kg = req.get("use_kg", False)
|
1317 |
if not isinstance(doc_ids, list):
|
1318 |
return get_error_data_result("`documents` should be a list")
|
1319 |
doc_ids_list = KnowledgebaseService.list_documents_by_ids(kb_ids)
|
|
|
1343 |
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
1344 |
question += keyword_extraction(chat_mdl, question)
|
1345 |
|
1346 |
+
ranks = settings.retrievaler.retrieval(
|
|
|
1347 |
question,
|
1348 |
embd_mdl,
|
1349 |
kb.tenant_id,
|
|
|
1358 |
highlight=highlight,
|
1359 |
rank_feature=label_question(question, kbs)
|
1360 |
)
|
1361 |
+
if use_kg:
|
1362 |
+
ck = settings.kg_retrievaler.retrieval(question,
|
1363 |
+
[k.tenant_id for k in kbs],
|
1364 |
+
kb_ids,
|
1365 |
+
embd_mdl,
|
1366 |
+
LLMBundle(kb.tenant_id, LLMType.CHAT))
|
1367 |
+
if ck["content_with_weight"]:
|
1368 |
+
ranks["chunks"].insert(0, ck)
|
1369 |
+
|
1370 |
for c in ranks["chunks"]:
|
1371 |
c.pop("vector", None)
|
1372 |
|
api/db/init_data.py
CHANGED
@@ -133,7 +133,7 @@ def init_llm_factory():
|
|
133 |
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"})
|
134 |
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "cohere"], {"llm_factory": "Cohere"})
|
135 |
TenantService.filter_update([1 == 1], {
|
136 |
-
"parser_ids": "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,
|
137 |
## insert openai two embedding models to the current openai user.
|
138 |
# print("Start to insert 2 OpenAI embedding models...")
|
139 |
tenant_ids = set([row["tenant_id"] for row in TenantLLMService.get_openai_models()])
|
|
|
133 |
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"})
|
134 |
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "cohere"], {"llm_factory": "Cohere"})
|
135 |
TenantService.filter_update([1 == 1], {
|
136 |
+
"parser_ids": "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag"})
|
137 |
## insert openai two embedding models to the current openai user.
|
138 |
# print("Start to insert 2 OpenAI embedding models...")
|
139 |
tenant_ids = set([row["tenant_id"] for row in TenantLLMService.get_openai_models()])
|
api/db/services/dialog_service.py
CHANGED
@@ -197,8 +197,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|
197 |
|
198 |
embedding_model_name = embedding_list[0]
|
199 |
|
200 |
-
|
201 |
-
retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler
|
202 |
|
203 |
questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
|
204 |
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
|
@@ -275,6 +274,14 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|
275 |
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
|
276 |
rank_feature=label_question(" ".join(questions), kbs)
|
277 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
278 |
|
279 |
retrieval_ts = timer()
|
280 |
|
|
|
197 |
|
198 |
embedding_model_name = embedding_list[0]
|
199 |
|
200 |
+
retriever = settings.retrievaler
|
|
|
201 |
|
202 |
questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
|
203 |
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
|
|
|
274 |
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
|
275 |
rank_feature=label_question(" ".join(questions), kbs)
|
276 |
)
|
277 |
+
if prompt_config.get("use_kg"):
|
278 |
+
ck = settings.kg_retrievaler.retrieval(" ".join(questions),
|
279 |
+
tenant_ids,
|
280 |
+
dialog.kb_ids,
|
281 |
+
embd_mdl,
|
282 |
+
LLMBundle(dialog.tenant_id, LLMType.CHAT))
|
283 |
+
if ck["content_with_weight"]:
|
284 |
+
kbinfos["chunks"].insert(0, ck)
|
285 |
|
286 |
retrieval_ts = timer()
|
287 |
|
api/db/services/document_service.py
CHANGED
@@ -28,7 +28,7 @@ from peewee import fn
|
|
28 |
from api.db.db_utils import bulk_insert_into_db
|
29 |
from api import settings
|
30 |
from api.utils import current_timestamp, get_format_time, get_uuid
|
31 |
-
from graphrag.mind_map_extractor import MindMapExtractor
|
32 |
from rag.settings import SVR_QUEUE_NAME
|
33 |
from rag.utils.storage_factory import STORAGE_IMPL
|
34 |
from rag.nlp import search, rag_tokenizer
|
@@ -105,8 +105,19 @@ class DocumentService(CommonService):
|
|
105 |
@classmethod
|
106 |
@DB.connection_context()
|
107 |
def remove_document(cls, doc, tenant_id):
|
108 |
-
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
109 |
cls.clear_chunk_num(doc.id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
return cls.delete_by_id(doc.id)
|
111 |
|
112 |
@classmethod
|
@@ -142,7 +153,7 @@ class DocumentService(CommonService):
|
|
142 |
@DB.connection_context()
|
143 |
def get_unfinished_docs(cls):
|
144 |
fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg,
|
145 |
-
cls.model.run]
|
146 |
docs = cls.model.select(*fields) \
|
147 |
.where(
|
148 |
cls.model.status == StatusEnum.VALID.value,
|
@@ -295,9 +306,9 @@ class DocumentService(CommonService):
|
|
295 |
Tenant.asr_id,
|
296 |
Tenant.llm_id,
|
297 |
)
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
)
|
302 |
configs = configs.dicts()
|
303 |
if not configs:
|
@@ -365,6 +376,12 @@ class DocumentService(CommonService):
|
|
365 |
@classmethod
|
366 |
@DB.connection_context()
|
367 |
def update_progress(cls):
|
|
|
|
|
|
|
|
|
|
|
|
|
368 |
docs = cls.get_unfinished_docs()
|
369 |
for d in docs:
|
370 |
try:
|
@@ -390,15 +407,27 @@ class DocumentService(CommonService):
|
|
390 |
prg = -1
|
391 |
status = TaskStatus.FAIL.value
|
392 |
elif finished:
|
393 |
-
|
394 |
-
|
395 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
396 |
prg = 0.98 * len(tsks) / (len(tsks) + 1)
|
397 |
-
msg.append("------ RAPTOR -------")
|
398 |
else:
|
399 |
status = TaskStatus.DONE.value
|
400 |
|
401 |
-
msg = "\n".join(msg)
|
402 |
info = {
|
403 |
"process_duation": datetime.timestamp(
|
404 |
datetime.now()) -
|
@@ -430,7 +459,7 @@ class DocumentService(CommonService):
|
|
430 |
return False
|
431 |
|
432 |
|
433 |
-
def
|
434 |
chunking_config = DocumentService.get_chunking_config(doc["id"])
|
435 |
hasher = xxhash.xxh64()
|
436 |
for field in sorted(chunking_config.keys()):
|
@@ -443,15 +472,16 @@ def queue_raptor_tasks(doc):
|
|
443 |
"doc_id": doc["id"],
|
444 |
"from_page": 100000000,
|
445 |
"to_page": 100000000,
|
446 |
-
"progress_msg":
|
447 |
}
|
448 |
|
449 |
task = new_task()
|
450 |
for field in ["doc_id", "from_page", "to_page"]:
|
451 |
hasher.update(str(task.get(field, "")).encode("utf-8"))
|
|
|
452 |
task["digest"] = hasher.hexdigest()
|
453 |
bulk_insert_into_db(Task, [task], True)
|
454 |
-
task["
|
455 |
assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status."
|
456 |
|
457 |
|
@@ -489,7 +519,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
|
489 |
ParserType.AUDIO.value: audio,
|
490 |
ParserType.EMAIL.value: email
|
491 |
}
|
492 |
-
parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize":
|
493 |
exe = ThreadPoolExecutor(max_workers=12)
|
494 |
threads = []
|
495 |
doc_nm = {}
|
@@ -592,4 +622,4 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
|
592 |
DocumentService.increment_chunk_num(
|
593 |
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
|
594 |
|
595 |
-
return [d["id"] for d, _ in files]
|
|
|
28 |
from api.db.db_utils import bulk_insert_into_db
|
29 |
from api import settings
|
30 |
from api.utils import current_timestamp, get_format_time, get_uuid
|
31 |
+
from graphrag.general.mind_map_extractor import MindMapExtractor
|
32 |
from rag.settings import SVR_QUEUE_NAME
|
33 |
from rag.utils.storage_factory import STORAGE_IMPL
|
34 |
from rag.nlp import search, rag_tokenizer
|
|
|
105 |
@classmethod
|
106 |
@DB.connection_context()
|
107 |
def remove_document(cls, doc, tenant_id):
|
|
|
108 |
cls.clear_chunk_num(doc.id)
|
109 |
+
try:
|
110 |
+
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
111 |
+
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "community_report"], "source_id": doc.id},
|
112 |
+
{"remove": {"source_id": doc.id}},
|
113 |
+
search.index_name(tenant_id), doc.kb_id)
|
114 |
+
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
|
115 |
+
{"removed_kwd": "Y"},
|
116 |
+
search.index_name(tenant_id), doc.kb_id)
|
117 |
+
settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "community_report"], "must_not": {"exists": "source_id"}},
|
118 |
+
search.index_name(tenant_id), doc.kb_id)
|
119 |
+
except Exception:
|
120 |
+
pass
|
121 |
return cls.delete_by_id(doc.id)
|
122 |
|
123 |
@classmethod
|
|
|
153 |
@DB.connection_context()
|
154 |
def get_unfinished_docs(cls):
|
155 |
fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg,
|
156 |
+
cls.model.run, cls.model.parser_id]
|
157 |
docs = cls.model.select(*fields) \
|
158 |
.where(
|
159 |
cls.model.status == StatusEnum.VALID.value,
|
|
|
306 |
Tenant.asr_id,
|
307 |
Tenant.llm_id,
|
308 |
)
|
309 |
+
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
|
310 |
+
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
|
311 |
+
.where(cls.model.id == doc_id)
|
312 |
)
|
313 |
configs = configs.dicts()
|
314 |
if not configs:
|
|
|
376 |
@classmethod
|
377 |
@DB.connection_context()
|
378 |
def update_progress(cls):
|
379 |
+
MSG = {
|
380 |
+
"raptor": "Start RAPTOR (Recursive Abstractive Processing for Tree-Organized Retrieval).",
|
381 |
+
"graphrag": "Start Graph Extraction",
|
382 |
+
"graph_resolution": "Start Graph Resolution",
|
383 |
+
"graph_community": "Start Graph Community Reports Generation"
|
384 |
+
}
|
385 |
docs = cls.get_unfinished_docs()
|
386 |
for d in docs:
|
387 |
try:
|
|
|
407 |
prg = -1
|
408 |
status = TaskStatus.FAIL.value
|
409 |
elif finished:
|
410 |
+
m = "\n".join(sorted(msg))
|
411 |
+
if d["parser_config"].get("raptor", {}).get("use_raptor") and m.find(MSG["raptor"]) < 0:
|
412 |
+
queue_raptor_o_graphrag_tasks(d, "raptor", MSG["raptor"])
|
413 |
+
prg = 0.98 * len(tsks) / (len(tsks) + 1)
|
414 |
+
elif d["parser_config"].get("graphrag", {}).get("use_graphrag") and m.find(MSG["graphrag"]) < 0:
|
415 |
+
queue_raptor_o_graphrag_tasks(d, "graphrag", MSG["graphrag"])
|
416 |
+
prg = 0.98 * len(tsks) / (len(tsks) + 1)
|
417 |
+
elif d["parser_config"].get("graphrag", {}).get("use_graphrag") \
|
418 |
+
and d["parser_config"].get("graphrag", {}).get("resolution") \
|
419 |
+
and m.find(MSG["graph_resolution"]) < 0:
|
420 |
+
queue_raptor_o_graphrag_tasks(d, "graph_resolution", MSG["graph_resolution"])
|
421 |
+
prg = 0.98 * len(tsks) / (len(tsks) + 1)
|
422 |
+
elif d["parser_config"].get("graphrag", {}).get("use_graphrag") \
|
423 |
+
and d["parser_config"].get("graphrag", {}).get("community") \
|
424 |
+
and m.find(MSG["graph_community"]) < 0:
|
425 |
+
queue_raptor_o_graphrag_tasks(d, "graph_community", MSG["graph_community"])
|
426 |
prg = 0.98 * len(tsks) / (len(tsks) + 1)
|
|
|
427 |
else:
|
428 |
status = TaskStatus.DONE.value
|
429 |
|
430 |
+
msg = "\n".join(sorted(msg))
|
431 |
info = {
|
432 |
"process_duation": datetime.timestamp(
|
433 |
datetime.now()) -
|
|
|
459 |
return False
|
460 |
|
461 |
|
462 |
+
def queue_raptor_o_graphrag_tasks(doc, ty, msg):
|
463 |
chunking_config = DocumentService.get_chunking_config(doc["id"])
|
464 |
hasher = xxhash.xxh64()
|
465 |
for field in sorted(chunking_config.keys()):
|
|
|
472 |
"doc_id": doc["id"],
|
473 |
"from_page": 100000000,
|
474 |
"to_page": 100000000,
|
475 |
+
"progress_msg": datetime.now().strftime("%H:%M:%S") + " " + msg
|
476 |
}
|
477 |
|
478 |
task = new_task()
|
479 |
for field in ["doc_id", "from_page", "to_page"]:
|
480 |
hasher.update(str(task.get(field, "")).encode("utf-8"))
|
481 |
+
hasher.update(ty.encode("utf-8"))
|
482 |
task["digest"] = hasher.hexdigest()
|
483 |
bulk_insert_into_db(Task, [task], True)
|
484 |
+
task["task_type"] = ty
|
485 |
assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status."
|
486 |
|
487 |
|
|
|
519 |
ParserType.AUDIO.value: audio,
|
520 |
ParserType.EMAIL.value: email
|
521 |
}
|
522 |
+
parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"}
|
523 |
exe = ThreadPoolExecutor(max_workers=12)
|
524 |
threads = []
|
525 |
doc_nm = {}
|
|
|
622 |
DocumentService.increment_chunk_num(
|
623 |
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
|
624 |
|
625 |
+
return [d["id"] for d, _ in files]
|
api/db/services/file_service.py
CHANGED
@@ -401,7 +401,7 @@ class FileService(CommonService):
|
|
401 |
ParserType.AUDIO.value: audio,
|
402 |
ParserType.EMAIL.value: email
|
403 |
}
|
404 |
-
parser_config = {"chunk_token_num": 16096, "delimiter": "\n!?;。;!?", "layout_recognize":
|
405 |
exe = ThreadPoolExecutor(max_workers=12)
|
406 |
threads = []
|
407 |
for file in file_objs:
|
|
|
401 |
ParserType.AUDIO.value: audio,
|
402 |
ParserType.EMAIL.value: email
|
403 |
}
|
404 |
+
parser_config = {"chunk_token_num": 16096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"}
|
405 |
exe = ThreadPoolExecutor(max_workers=12)
|
406 |
threads = []
|
407 |
for file in file_objs:
|
api/db/services/task_service.py
CHANGED
@@ -16,7 +16,6 @@
|
|
16 |
import os
|
17 |
import random
|
18 |
import xxhash
|
19 |
-
import bisect
|
20 |
from datetime import datetime
|
21 |
|
22 |
from api.db.db_utils import bulk_insert_into_db
|
@@ -183,7 +182,7 @@ class TaskService(CommonService):
|
|
183 |
if os.environ.get("MACOS"):
|
184 |
if info["progress_msg"]:
|
185 |
task = cls.model.get_by_id(id)
|
186 |
-
progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"],
|
187 |
cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
|
188 |
if "progress" in info:
|
189 |
cls.model.update(progress=info["progress"]).where(
|
@@ -194,7 +193,7 @@ class TaskService(CommonService):
|
|
194 |
with DB.lock("update_progress", -1):
|
195 |
if info["progress_msg"]:
|
196 |
task = cls.model.get_by_id(id)
|
197 |
-
progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"],
|
198 |
cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
|
199 |
if "progress" in info:
|
200 |
cls.model.update(progress=info["progress"]).where(
|
@@ -210,12 +209,12 @@ def queue_tasks(doc: dict, bucket: str, name: str):
|
|
210 |
|
211 |
if doc["type"] == FileType.PDF.value:
|
212 |
file_bin = STORAGE_IMPL.get(bucket, name)
|
213 |
-
do_layout = doc["parser_config"].get("layout_recognize",
|
214 |
pages = PdfParser.total_page_number(doc["name"], file_bin)
|
215 |
page_size = doc["parser_config"].get("task_page_size", 12)
|
216 |
if doc["parser_id"] == "paper":
|
217 |
page_size = doc["parser_config"].get("task_page_size", 22)
|
218 |
-
if doc["parser_id"] in ["one", "knowledge_graph"] or
|
219 |
page_size = 10 ** 9
|
220 |
page_ranges = doc["parser_config"].get("pages") or [(1, 10 ** 5)]
|
221 |
for s, e in page_ranges:
|
@@ -243,6 +242,10 @@ def queue_tasks(doc: dict, bucket: str, name: str):
|
|
243 |
for task in parse_task_array:
|
244 |
hasher = xxhash.xxh64()
|
245 |
for field in sorted(chunking_config.keys()):
|
|
|
|
|
|
|
|
|
246 |
hasher.update(str(chunking_config[field]).encode("utf-8"))
|
247 |
for field in ["doc_id", "from_page", "to_page"]:
|
248 |
hasher.update(str(task.get(field, "")).encode("utf-8"))
|
@@ -276,20 +279,27 @@ def queue_tasks(doc: dict, bucket: str, name: str):
|
|
276 |
|
277 |
|
278 |
def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config: dict):
|
279 |
-
idx =
|
280 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
281 |
if idx >= len(prev_tasks):
|
282 |
return 0
|
283 |
prev_task = prev_tasks[idx]
|
284 |
-
if prev_task["progress"] < 1.0 or
|
285 |
return 0
|
286 |
task["chunk_ids"] = prev_task["chunk_ids"]
|
287 |
task["progress"] = 1.0
|
288 |
-
if "from_page" in task and "to_page" in task:
|
289 |
task["progress_msg"] = f"Page({task['from_page']}~{task['to_page']}): "
|
290 |
else:
|
291 |
task["progress_msg"] = ""
|
292 |
-
task["progress_msg"]
|
|
|
293 |
prev_task["chunk_ids"] = ""
|
294 |
|
295 |
return len(task["chunk_ids"].split())
|
|
|
16 |
import os
|
17 |
import random
|
18 |
import xxhash
|
|
|
19 |
from datetime import datetime
|
20 |
|
21 |
from api.db.db_utils import bulk_insert_into_db
|
|
|
182 |
if os.environ.get("MACOS"):
|
183 |
if info["progress_msg"]:
|
184 |
task = cls.model.get_by_id(id)
|
185 |
+
progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000)
|
186 |
cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
|
187 |
if "progress" in info:
|
188 |
cls.model.update(progress=info["progress"]).where(
|
|
|
193 |
with DB.lock("update_progress", -1):
|
194 |
if info["progress_msg"]:
|
195 |
task = cls.model.get_by_id(id)
|
196 |
+
progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000)
|
197 |
cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
|
198 |
if "progress" in info:
|
199 |
cls.model.update(progress=info["progress"]).where(
|
|
|
209 |
|
210 |
if doc["type"] == FileType.PDF.value:
|
211 |
file_bin = STORAGE_IMPL.get(bucket, name)
|
212 |
+
do_layout = doc["parser_config"].get("layout_recognize", "DeepDOC")
|
213 |
pages = PdfParser.total_page_number(doc["name"], file_bin)
|
214 |
page_size = doc["parser_config"].get("task_page_size", 12)
|
215 |
if doc["parser_id"] == "paper":
|
216 |
page_size = doc["parser_config"].get("task_page_size", 22)
|
217 |
+
if doc["parser_id"] in ["one", "knowledge_graph"] or do_layout != "DeepDOC":
|
218 |
page_size = 10 ** 9
|
219 |
page_ranges = doc["parser_config"].get("pages") or [(1, 10 ** 5)]
|
220 |
for s, e in page_ranges:
|
|
|
242 |
for task in parse_task_array:
|
243 |
hasher = xxhash.xxh64()
|
244 |
for field in sorted(chunking_config.keys()):
|
245 |
+
if field == "parser_config":
|
246 |
+
for k in ["raptor", "graphrag"]:
|
247 |
+
if k in chunking_config[field]:
|
248 |
+
del chunking_config[field][k]
|
249 |
hasher.update(str(chunking_config[field]).encode("utf-8"))
|
250 |
for field in ["doc_id", "from_page", "to_page"]:
|
251 |
hasher.update(str(task.get(field, "")).encode("utf-8"))
|
|
|
279 |
|
280 |
|
281 |
def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config: dict):
|
282 |
+
idx = 0
|
283 |
+
while idx < len(prev_tasks):
|
284 |
+
prev_task = prev_tasks[idx]
|
285 |
+
if prev_task.get("from_page", 0) == task.get("from_page", 0) \
|
286 |
+
and prev_task.get("digest", 0) == task.get("digest", ""):
|
287 |
+
break
|
288 |
+
idx += 1
|
289 |
+
|
290 |
if idx >= len(prev_tasks):
|
291 |
return 0
|
292 |
prev_task = prev_tasks[idx]
|
293 |
+
if prev_task["progress"] < 1.0 or not prev_task["chunk_ids"]:
|
294 |
return 0
|
295 |
task["chunk_ids"] = prev_task["chunk_ids"]
|
296 |
task["progress"] = 1.0
|
297 |
+
if "from_page" in task and "to_page" in task and int(task['to_page']) - int(task['from_page']) >= 10 ** 6:
|
298 |
task["progress_msg"] = f"Page({task['from_page']}~{task['to_page']}): "
|
299 |
else:
|
300 |
task["progress_msg"] = ""
|
301 |
+
task["progress_msg"] = " ".join(
|
302 |
+
[datetime.now().strftime("%H:%M:%S"), task["progress_msg"], "Reused previous task's chunks."])
|
303 |
prev_task["chunk_ids"] = ""
|
304 |
|
305 |
return len(task["chunk_ids"].split())
|
api/utils/api_utils.py
CHANGED
@@ -355,7 +355,7 @@ def get_parser_config(chunk_method, parser_config):
|
|
355 |
if not chunk_method:
|
356 |
chunk_method = "naive"
|
357 |
key_mapping = {
|
358 |
-
"naive": {"chunk_token_num": 128, "delimiter": "\\n!?;。;!?", "html4excel": False, "layout_recognize":
|
359 |
"raptor": {"use_raptor": False}},
|
360 |
"qa": {"raptor": {"use_raptor": False}},
|
361 |
"tag": None,
|
|
|
355 |
if not chunk_method:
|
356 |
chunk_method = "naive"
|
357 |
key_mapping = {
|
358 |
+
"naive": {"chunk_token_num": 128, "delimiter": "\\n!?;。;!?", "html4excel": False, "layout_recognize": "DeepDOC",
|
359 |
"raptor": {"use_raptor": False}},
|
360 |
"qa": {"raptor": {"use_raptor": False}},
|
361 |
"tag": None,
|
conf/infinity_mapping.json
CHANGED
@@ -25,9 +25,19 @@
|
|
25 |
"weight_int": {"type": "integer", "default": 0},
|
26 |
"weight_flt": {"type": "float", "default": 0.0},
|
27 |
"rank_int": {"type": "integer", "default": 0},
|
|
|
28 |
"available_int": {"type": "integer", "default": 1},
|
29 |
"knowledge_graph_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
30 |
"entities_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
31 |
"pagerank_fea": {"type": "integer", "default": 0},
|
32 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
}
|
|
|
25 |
"weight_int": {"type": "integer", "default": 0},
|
26 |
"weight_flt": {"type": "float", "default": 0.0},
|
27 |
"rank_int": {"type": "integer", "default": 0},
|
28 |
+
"rank_flt": {"type": "float", "default": 0},
|
29 |
"available_int": {"type": "integer", "default": 1},
|
30 |
"knowledge_graph_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
31 |
"entities_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
32 |
"pagerank_fea": {"type": "integer", "default": 0},
|
33 |
+
"tag_feas": {"type": "integer", "default": 0},
|
34 |
+
|
35 |
+
"important_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
36 |
+
"from_entity_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
37 |
+
"to_entity_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
38 |
+
"entity_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
39 |
+
"entity_type_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
40 |
+
"source_id": {"type": "varchar", "default": ""},
|
41 |
+
"n_hop_with_weight": {"type": "varchar", "default": ""},
|
42 |
+
"removed_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"}
|
43 |
}
|
graphrag/description_summary.py
DELETED
@@ -1,146 +0,0 @@
|
|
1 |
-
# Copyright (c) 2024 Microsoft Corporation.
|
2 |
-
# Licensed under the MIT License
|
3 |
-
"""
|
4 |
-
Reference:
|
5 |
-
- [graphrag](https://github.com/microsoft/graphrag)
|
6 |
-
"""
|
7 |
-
|
8 |
-
import json
|
9 |
-
from dataclasses import dataclass
|
10 |
-
|
11 |
-
from graphrag.extractor import Extractor
|
12 |
-
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
|
13 |
-
from rag.llm.chat_model import Base as CompletionLLM
|
14 |
-
|
15 |
-
from rag.utils import num_tokens_from_string
|
16 |
-
|
17 |
-
SUMMARIZE_PROMPT = """
|
18 |
-
You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
|
19 |
-
Given one or two entities, and a list of descriptions, all related to the same entity or group of entities.
|
20 |
-
Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions.
|
21 |
-
If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
|
22 |
-
Make sure it is written in third person, and include the entity names so we the have full context.
|
23 |
-
|
24 |
-
#######
|
25 |
-
-Data-
|
26 |
-
Entities: {entity_name}
|
27 |
-
Description List: {description_list}
|
28 |
-
#######
|
29 |
-
Output:
|
30 |
-
"""
|
31 |
-
|
32 |
-
# Max token size for input prompts
|
33 |
-
DEFAULT_MAX_INPUT_TOKENS = 4_000
|
34 |
-
# Max token count for LLM answers
|
35 |
-
DEFAULT_MAX_SUMMARY_LENGTH = 128
|
36 |
-
|
37 |
-
|
38 |
-
@dataclass
|
39 |
-
class SummarizationResult:
|
40 |
-
"""Unipartite graph extraction result class definition."""
|
41 |
-
|
42 |
-
items: str | tuple[str, str]
|
43 |
-
description: str
|
44 |
-
|
45 |
-
|
46 |
-
class SummarizeExtractor(Extractor):
|
47 |
-
"""Unipartite graph extractor class definition."""
|
48 |
-
|
49 |
-
_entity_name_key: str
|
50 |
-
_input_descriptions_key: str
|
51 |
-
_summarization_prompt: str
|
52 |
-
_on_error: ErrorHandlerFn
|
53 |
-
_max_summary_length: int
|
54 |
-
_max_input_tokens: int
|
55 |
-
|
56 |
-
def __init__(
|
57 |
-
self,
|
58 |
-
llm_invoker: CompletionLLM,
|
59 |
-
entity_name_key: str | None = None,
|
60 |
-
input_descriptions_key: str | None = None,
|
61 |
-
summarization_prompt: str | None = None,
|
62 |
-
on_error: ErrorHandlerFn | None = None,
|
63 |
-
max_summary_length: int | None = None,
|
64 |
-
max_input_tokens: int | None = None,
|
65 |
-
):
|
66 |
-
"""Init method definition."""
|
67 |
-
# TODO: streamline construction
|
68 |
-
self._llm = llm_invoker
|
69 |
-
self._entity_name_key = entity_name_key or "entity_name"
|
70 |
-
self._input_descriptions_key = input_descriptions_key or "description_list"
|
71 |
-
|
72 |
-
self._summarization_prompt = summarization_prompt or SUMMARIZE_PROMPT
|
73 |
-
self._on_error = on_error or (lambda _e, _s, _d: None)
|
74 |
-
self._max_summary_length = max_summary_length or DEFAULT_MAX_SUMMARY_LENGTH
|
75 |
-
self._max_input_tokens = max_input_tokens or DEFAULT_MAX_INPUT_TOKENS
|
76 |
-
|
77 |
-
def __call__(
|
78 |
-
self,
|
79 |
-
items: str | tuple[str, str],
|
80 |
-
descriptions: list[str],
|
81 |
-
) -> SummarizationResult:
|
82 |
-
"""Call method definition."""
|
83 |
-
result = ""
|
84 |
-
if len(descriptions) == 0:
|
85 |
-
result = ""
|
86 |
-
if len(descriptions) == 1:
|
87 |
-
result = descriptions[0]
|
88 |
-
else:
|
89 |
-
result = self._summarize_descriptions(items, descriptions)
|
90 |
-
|
91 |
-
return SummarizationResult(
|
92 |
-
items=items,
|
93 |
-
description=result or "",
|
94 |
-
)
|
95 |
-
|
96 |
-
def _summarize_descriptions(
|
97 |
-
self, items: str | tuple[str, str], descriptions: list[str]
|
98 |
-
) -> str:
|
99 |
-
"""Summarize descriptions into a single description."""
|
100 |
-
sorted_items = sorted(items) if isinstance(items, list) else items
|
101 |
-
|
102 |
-
# Safety check, should always be a list
|
103 |
-
if not isinstance(descriptions, list):
|
104 |
-
descriptions = [descriptions]
|
105 |
-
|
106 |
-
# Iterate over descriptions, adding all until the max input tokens is reached
|
107 |
-
usable_tokens = self._max_input_tokens - num_tokens_from_string(
|
108 |
-
self._summarization_prompt
|
109 |
-
)
|
110 |
-
descriptions_collected = []
|
111 |
-
result = ""
|
112 |
-
|
113 |
-
for i, description in enumerate(descriptions):
|
114 |
-
usable_tokens -= num_tokens_from_string(description)
|
115 |
-
descriptions_collected.append(description)
|
116 |
-
|
117 |
-
# If buffer is full, or all descriptions have been added, summarize
|
118 |
-
if (usable_tokens < 0 and len(descriptions_collected) > 1) or (
|
119 |
-
i == len(descriptions) - 1
|
120 |
-
):
|
121 |
-
# Calculate result (final or partial)
|
122 |
-
result = await self._summarize_descriptions_with_llm(
|
123 |
-
sorted_items, descriptions_collected
|
124 |
-
)
|
125 |
-
|
126 |
-
# If we go for another loop, reset values to new
|
127 |
-
if i != len(descriptions) - 1:
|
128 |
-
descriptions_collected = [result]
|
129 |
-
usable_tokens = (
|
130 |
-
self._max_input_tokens
|
131 |
-
- num_tokens_from_string(self._summarization_prompt)
|
132 |
-
- num_tokens_from_string(result)
|
133 |
-
)
|
134 |
-
|
135 |
-
return result
|
136 |
-
|
137 |
-
def _summarize_descriptions_with_llm(
|
138 |
-
self, items: str | tuple[str, str] | list[str], descriptions: list[str]
|
139 |
-
):
|
140 |
-
"""Summarize descriptions using the LLM."""
|
141 |
-
variables = {
|
142 |
-
self._entity_name_key: json.dumps(items),
|
143 |
-
self._input_descriptions_key: json.dumps(sorted(descriptions)),
|
144 |
-
}
|
145 |
-
text = perform_variable_replacements(self._summarization_prompt, variables=variables)
|
146 |
-
return self._chat("", [{"role": "user", "content": text}])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graphrag/entity_resolution.py
CHANGED
@@ -16,18 +16,18 @@
|
|
16 |
import logging
|
17 |
import itertools
|
18 |
import re
|
19 |
-
import
|
20 |
from dataclasses import dataclass
|
21 |
-
from typing import Any
|
22 |
|
23 |
import networkx as nx
|
24 |
|
25 |
-
from graphrag.extractor import Extractor
|
26 |
from rag.nlp import is_english
|
27 |
import editdistance
|
28 |
from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT
|
29 |
from rag.llm.chat_model import Base as CompletionLLM
|
30 |
-
from graphrag.utils import
|
31 |
|
32 |
DEFAULT_RECORD_DELIMITER = "##"
|
33 |
DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
|
@@ -37,8 +37,8 @@ DEFAULT_RESOLUTION_RESULT_DELIMITER = "&&"
|
|
37 |
@dataclass
|
38 |
class EntityResolutionResult:
|
39 |
"""Entity resolution result class definition."""
|
40 |
-
|
41 |
-
|
42 |
|
43 |
|
44 |
class EntityResolution(Extractor):
|
@@ -46,7 +46,6 @@ class EntityResolution(Extractor):
|
|
46 |
|
47 |
_resolution_prompt: str
|
48 |
_output_formatter_prompt: str
|
49 |
-
_on_error: ErrorHandlerFn
|
50 |
_record_delimiter_key: str
|
51 |
_entity_index_delimiter_key: str
|
52 |
_resolution_result_delimiter_key: str
|
@@ -54,21 +53,19 @@ class EntityResolution(Extractor):
|
|
54 |
def __init__(
|
55 |
self,
|
56 |
llm_invoker: CompletionLLM,
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
resolution_result_delimiter_key: str | None = None,
|
62 |
-
input_text_key: str | None = None
|
63 |
):
|
|
|
64 |
"""Init method definition."""
|
65 |
self._llm = llm_invoker
|
66 |
-
self._resolution_prompt =
|
67 |
-
self.
|
68 |
-
self.
|
69 |
-
self.
|
70 |
-
self.
|
71 |
-
self._input_text_key = input_text_key or "input_text"
|
72 |
|
73 |
def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None) -> EntityResolutionResult:
|
74 |
"""Call method definition."""
|
@@ -87,11 +84,11 @@ class EntityResolution(Extractor):
|
|
87 |
}
|
88 |
|
89 |
nodes = graph.nodes
|
90 |
-
entity_types = list(set(graph.nodes[node]
|
91 |
node_clusters = {entity_type: [] for entity_type in entity_types}
|
92 |
|
93 |
for node in nodes:
|
94 |
-
node_clusters[graph.nodes[node]
|
95 |
|
96 |
candidate_resolution = {entity_type: [] for entity_type in entity_types}
|
97 |
for k, v in node_clusters.items():
|
@@ -128,44 +125,51 @@ class EntityResolution(Extractor):
|
|
128 |
DEFAULT_RESOLUTION_RESULT_DELIMITER))
|
129 |
for result_i in result:
|
130 |
resolution_result.add(candidate_resolution_i[1][result_i[0] - 1])
|
131 |
-
except Exception
|
132 |
logging.exception("error entity resolution")
|
133 |
-
self._on_error(e, traceback.format_exc(), None)
|
134 |
|
135 |
connect_graph = nx.Graph()
|
|
|
136 |
connect_graph.add_edges_from(resolution_result)
|
137 |
for sub_connect_graph in nx.connected_components(connect_graph):
|
138 |
sub_connect_graph = connect_graph.subgraph(sub_connect_graph)
|
139 |
remove_nodes = list(sub_connect_graph.nodes)
|
140 |
keep_node = remove_nodes.pop()
|
|
|
141 |
for remove_node in remove_nodes:
|
|
|
142 |
remove_node_neighbors = graph[remove_node]
|
143 |
-
graph.nodes[keep_node]['description'] += graph.nodes[remove_node]['description']
|
144 |
-
graph.nodes[keep_node]['weight'] += graph.nodes[remove_node]['weight']
|
145 |
remove_node_neighbors = list(remove_node_neighbors)
|
146 |
for remove_node_neighbor in remove_node_neighbors:
|
|
|
|
|
|
|
147 |
if remove_node_neighbor == keep_node:
|
148 |
-
graph.
|
|
|
|
|
|
|
149 |
continue
|
150 |
if graph.has_edge(keep_node, remove_node_neighbor):
|
151 |
-
|
152 |
-
'weight']
|
153 |
-
graph[keep_node][remove_node_neighbor]['description'] += \
|
154 |
-
graph[remove_node][remove_node_neighbor]['description']
|
155 |
-
graph.remove_edge(remove_node, remove_node_neighbor)
|
156 |
else:
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
graph.remove_node(remove_node)
|
163 |
|
164 |
-
for node_degree in graph.degree:
|
165 |
-
graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
|
166 |
-
|
167 |
return EntityResolutionResult(
|
168 |
-
|
|
|
169 |
)
|
170 |
|
171 |
def _process_results(
|
|
|
16 |
import logging
|
17 |
import itertools
|
18 |
import re
|
19 |
+
import time
|
20 |
from dataclasses import dataclass
|
21 |
+
from typing import Any, Callable
|
22 |
|
23 |
import networkx as nx
|
24 |
|
25 |
+
from graphrag.general.extractor import Extractor
|
26 |
from rag.nlp import is_english
|
27 |
import editdistance
|
28 |
from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT
|
29 |
from rag.llm.chat_model import Base as CompletionLLM
|
30 |
+
from graphrag.utils import perform_variable_replacements
|
31 |
|
32 |
DEFAULT_RECORD_DELIMITER = "##"
|
33 |
DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
|
|
|
37 |
@dataclass
|
38 |
class EntityResolutionResult:
|
39 |
"""Entity resolution result class definition."""
|
40 |
+
graph: nx.Graph
|
41 |
+
removed_entities: list
|
42 |
|
43 |
|
44 |
class EntityResolution(Extractor):
|
|
|
46 |
|
47 |
_resolution_prompt: str
|
48 |
_output_formatter_prompt: str
|
|
|
49 |
_record_delimiter_key: str
|
50 |
_entity_index_delimiter_key: str
|
51 |
_resolution_result_delimiter_key: str
|
|
|
53 |
def __init__(
|
54 |
self,
|
55 |
llm_invoker: CompletionLLM,
|
56 |
+
get_entity: Callable | None = None,
|
57 |
+
set_entity: Callable | None = None,
|
58 |
+
get_relation: Callable | None = None,
|
59 |
+
set_relation: Callable | None = None
|
|
|
|
|
60 |
):
|
61 |
+
super().__init__(llm_invoker, get_entity=get_entity, set_entity=set_entity, get_relation=get_relation, set_relation=set_relation)
|
62 |
"""Init method definition."""
|
63 |
self._llm = llm_invoker
|
64 |
+
self._resolution_prompt = ENTITY_RESOLUTION_PROMPT
|
65 |
+
self._record_delimiter_key = "record_delimiter"
|
66 |
+
self._entity_index_dilimiter_key = "entity_index_delimiter"
|
67 |
+
self._resolution_result_delimiter_key = "resolution_result_delimiter"
|
68 |
+
self._input_text_key = "input_text"
|
|
|
69 |
|
70 |
def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None) -> EntityResolutionResult:
|
71 |
"""Call method definition."""
|
|
|
84 |
}
|
85 |
|
86 |
nodes = graph.nodes
|
87 |
+
entity_types = list(set(graph.nodes[node].get('entity_type', '-') for node in nodes))
|
88 |
node_clusters = {entity_type: [] for entity_type in entity_types}
|
89 |
|
90 |
for node in nodes:
|
91 |
+
node_clusters[graph.nodes[node].get('entity_type', '-')].append(node)
|
92 |
|
93 |
candidate_resolution = {entity_type: [] for entity_type in entity_types}
|
94 |
for k, v in node_clusters.items():
|
|
|
125 |
DEFAULT_RESOLUTION_RESULT_DELIMITER))
|
126 |
for result_i in result:
|
127 |
resolution_result.add(candidate_resolution_i[1][result_i[0] - 1])
|
128 |
+
except Exception:
|
129 |
logging.exception("error entity resolution")
|
|
|
130 |
|
131 |
connect_graph = nx.Graph()
|
132 |
+
removed_entities = []
|
133 |
connect_graph.add_edges_from(resolution_result)
|
134 |
for sub_connect_graph in nx.connected_components(connect_graph):
|
135 |
sub_connect_graph = connect_graph.subgraph(sub_connect_graph)
|
136 |
remove_nodes = list(sub_connect_graph.nodes)
|
137 |
keep_node = remove_nodes.pop()
|
138 |
+
self._merge_nodes(keep_node, self._get_entity_(remove_nodes))
|
139 |
for remove_node in remove_nodes:
|
140 |
+
removed_entities.append(remove_node)
|
141 |
remove_node_neighbors = graph[remove_node]
|
|
|
|
|
142 |
remove_node_neighbors = list(remove_node_neighbors)
|
143 |
for remove_node_neighbor in remove_node_neighbors:
|
144 |
+
rel = self._get_relation_(remove_node, remove_node_neighbor)
|
145 |
+
if graph.has_edge(remove_node, remove_node_neighbor):
|
146 |
+
graph.remove_edge(remove_node, remove_node_neighbor)
|
147 |
if remove_node_neighbor == keep_node:
|
148 |
+
if graph.has_edge(keep_node, remove_node):
|
149 |
+
graph.remove_edge(keep_node, remove_node)
|
150 |
+
continue
|
151 |
+
if not rel:
|
152 |
continue
|
153 |
if graph.has_edge(keep_node, remove_node_neighbor):
|
154 |
+
self._merge_edges(keep_node, remove_node_neighbor, [rel])
|
|
|
|
|
|
|
|
|
155 |
else:
|
156 |
+
pair = sorted([keep_node, remove_node_neighbor])
|
157 |
+
graph.add_edge(pair[0], pair[1], weight=rel['weight'])
|
158 |
+
self._set_relation_(pair[0], pair[1],
|
159 |
+
dict(
|
160 |
+
src_id=pair[0],
|
161 |
+
tgt_id=pair[1],
|
162 |
+
weight=rel['weight'],
|
163 |
+
description=rel['description'],
|
164 |
+
keywords=[],
|
165 |
+
source_id=rel.get("source_id", ""),
|
166 |
+
metadata={"created_at": time.time()}
|
167 |
+
))
|
168 |
graph.remove_node(remove_node)
|
169 |
|
|
|
|
|
|
|
170 |
return EntityResolutionResult(
|
171 |
+
graph=graph,
|
172 |
+
removed_entities=removed_entities
|
173 |
)
|
174 |
|
175 |
def _process_results(
|
graphrag/extractor.py
DELETED
@@ -1,34 +0,0 @@
|
|
1 |
-
#
|
2 |
-
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
3 |
-
#
|
4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
-
# you may not use this file except in compliance with the License.
|
6 |
-
# You may obtain a copy of the License at
|
7 |
-
#
|
8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
-
#
|
10 |
-
# Unless required by applicable law or agreed to in writing, software
|
11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
-
# See the License for the specific language governing permissions and
|
14 |
-
# limitations under the License.
|
15 |
-
#
|
16 |
-
from graphrag.utils import get_llm_cache, set_llm_cache
|
17 |
-
from rag.llm.chat_model import Base as CompletionLLM
|
18 |
-
|
19 |
-
|
20 |
-
class Extractor:
|
21 |
-
_llm: CompletionLLM
|
22 |
-
|
23 |
-
def __init__(self, llm_invoker: CompletionLLM):
|
24 |
-
self._llm = llm_invoker
|
25 |
-
|
26 |
-
def _chat(self, system, history, gen_conf):
|
27 |
-
response = get_llm_cache(self._llm.llm_name, system, history, gen_conf)
|
28 |
-
if response:
|
29 |
-
return response
|
30 |
-
response = self._llm.chat(system, history, gen_conf)
|
31 |
-
if response.find("**ERROR**") >= 0:
|
32 |
-
raise Exception(response)
|
33 |
-
set_llm_cache(self._llm.llm_name, system, response, history, gen_conf)
|
34 |
-
return response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graphrag/general/__init__.py
ADDED
File without changes
|
graphrag/{claim_extractor.py → general/claim_extractor.py}
RENAMED
@@ -15,8 +15,8 @@ from typing import Any
|
|
15 |
|
16 |
import tiktoken
|
17 |
|
18 |
-
from graphrag.claim_prompt import CLAIM_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
|
19 |
-
from graphrag.extractor import Extractor
|
20 |
from rag.llm.chat_model import Base as CompletionLLM
|
21 |
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
|
22 |
|
|
|
15 |
|
16 |
import tiktoken
|
17 |
|
18 |
+
from graphrag.general.claim_prompt import CLAIM_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
|
19 |
+
from graphrag.general.extractor import Extractor
|
20 |
from rag.llm.chat_model import Base as CompletionLLM
|
21 |
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
|
22 |
|
graphrag/{claim_prompt.py → general/claim_prompt.py}
RENAMED
File without changes
|
graphrag/{community_report_prompt.py → general/community_report_prompt.py}
RENAMED
File without changes
|
graphrag/{community_reports_extractor.py → general/community_reports_extractor.py}
RENAMED
@@ -13,10 +13,10 @@ from typing import Callable
|
|
13 |
from dataclasses import dataclass
|
14 |
import networkx as nx
|
15 |
import pandas as pd
|
16 |
-
from graphrag import leiden
|
17 |
-
from graphrag.community_report_prompt import COMMUNITY_REPORT_PROMPT
|
18 |
-
from graphrag.extractor import Extractor
|
19 |
-
from graphrag.leiden import add_community_info2graph
|
20 |
from rag.llm.chat_model import Base as CompletionLLM
|
21 |
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types
|
22 |
from rag.utils import num_tokens_from_string
|
@@ -40,32 +40,43 @@ class CommunityReportsExtractor(Extractor):
|
|
40 |
_max_report_length: int
|
41 |
|
42 |
def __init__(
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
|
|
|
|
48 |
):
|
|
|
49 |
"""Init method definition."""
|
50 |
self._llm = llm_invoker
|
51 |
-
self._extraction_prompt =
|
52 |
-
self._on_error = on_error or (lambda _e, _s, _d: None)
|
53 |
self._max_report_length = max_report_length or 1500
|
54 |
|
55 |
def __call__(self, graph: nx.Graph, callback: Callable | None = None):
|
|
|
|
|
|
|
56 |
communities: dict[str, dict[str, list]] = leiden.run(graph, {})
|
57 |
total = sum([len(comm.items()) for _, comm in communities.items()])
|
58 |
-
relations_df = pd.DataFrame([{"source":s, "target": t, **attr} for s, t, attr in graph.edges(data=True)])
|
59 |
res_str = []
|
60 |
res_dict = []
|
61 |
over, token_count = 0, 0
|
62 |
st = timer()
|
63 |
for level, comm in communities.items():
|
|
|
64 |
for cm_id, ents in comm.items():
|
65 |
weight = ents["weight"]
|
66 |
ents = ents["nodes"]
|
67 |
-
ent_df = pd.DataFrame([{"entity": n, **graph.nodes[n]} for n in ents])
|
68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
|
70 |
prompt_variables = {
|
71 |
"entity_df": ent_df.to_csv(index_label="id"),
|
|
|
13 |
from dataclasses import dataclass
|
14 |
import networkx as nx
|
15 |
import pandas as pd
|
16 |
+
from graphrag.general import leiden
|
17 |
+
from graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT
|
18 |
+
from graphrag.general.extractor import Extractor
|
19 |
+
from graphrag.general.leiden import add_community_info2graph
|
20 |
from rag.llm.chat_model import Base as CompletionLLM
|
21 |
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types
|
22 |
from rag.utils import num_tokens_from_string
|
|
|
40 |
_max_report_length: int
|
41 |
|
42 |
def __init__(
|
43 |
+
self,
|
44 |
+
llm_invoker: CompletionLLM,
|
45 |
+
get_entity: Callable | None = None,
|
46 |
+
set_entity: Callable | None = None,
|
47 |
+
get_relation: Callable | None = None,
|
48 |
+
set_relation: Callable | None = None,
|
49 |
+
max_report_length: int | None = None,
|
50 |
):
|
51 |
+
super().__init__(llm_invoker, get_entity=get_entity, set_entity=set_entity, get_relation=get_relation, set_relation=set_relation)
|
52 |
"""Init method definition."""
|
53 |
self._llm = llm_invoker
|
54 |
+
self._extraction_prompt = COMMUNITY_REPORT_PROMPT
|
|
|
55 |
self._max_report_length = max_report_length or 1500
|
56 |
|
57 |
def __call__(self, graph: nx.Graph, callback: Callable | None = None):
|
58 |
+
for node_degree in graph.degree:
|
59 |
+
graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
|
60 |
+
|
61 |
communities: dict[str, dict[str, list]] = leiden.run(graph, {})
|
62 |
total = sum([len(comm.items()) for _, comm in communities.items()])
|
|
|
63 |
res_str = []
|
64 |
res_dict = []
|
65 |
over, token_count = 0, 0
|
66 |
st = timer()
|
67 |
for level, comm in communities.items():
|
68 |
+
logging.info(f"Level {level}: Community: {len(comm.keys())}")
|
69 |
for cm_id, ents in comm.items():
|
70 |
weight = ents["weight"]
|
71 |
ents = ents["nodes"]
|
72 |
+
ent_df = pd.DataFrame(self._get_entity_(ents)).dropna()#[{"entity": n, **graph.nodes[n]} for n in ents])
|
73 |
+
ent_df["entity"] = ent_df["entity_name"]
|
74 |
+
del ent_df["entity_name"]
|
75 |
+
rela_df = pd.DataFrame(self._get_relation_(list(ent_df["entity"]), list(ent_df["entity"]), 10000))
|
76 |
+
rela_df["source"] = rela_df["src_id"]
|
77 |
+
rela_df["target"] = rela_df["tgt_id"]
|
78 |
+
del rela_df["src_id"]
|
79 |
+
del rela_df["tgt_id"]
|
80 |
|
81 |
prompt_variables = {
|
82 |
"entity_df": ent_df.to_csv(index_label="id"),
|
graphrag/{entity_embedding.py → general/entity_embedding.py}
RENAMED
@@ -9,7 +9,7 @@ from typing import Any
|
|
9 |
import numpy as np
|
10 |
import networkx as nx
|
11 |
from dataclasses import dataclass
|
12 |
-
from graphrag.leiden import stable_largest_connected_component
|
13 |
import graspologic as gc
|
14 |
|
15 |
|
|
|
9 |
import numpy as np
|
10 |
import networkx as nx
|
11 |
from dataclasses import dataclass
|
12 |
+
from graphrag.general.leiden import stable_largest_connected_component
|
13 |
import graspologic as gc
|
14 |
|
15 |
|
graphrag/general/extractor.py
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
#
|
16 |
+
import logging
|
17 |
+
import os
|
18 |
+
from collections import defaultdict, Counter
|
19 |
+
from concurrent.futures import ThreadPoolExecutor
|
20 |
+
from copy import deepcopy
|
21 |
+
from typing import Callable
|
22 |
+
|
23 |
+
from graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT
|
24 |
+
from graphrag.utils import get_llm_cache, set_llm_cache, handle_single_entity_extraction, \
|
25 |
+
handle_single_relationship_extraction, split_string_by_multi_markers, flat_uniq_list
|
26 |
+
from rag.llm.chat_model import Base as CompletionLLM
|
27 |
+
from rag.utils import truncate
|
28 |
+
|
29 |
+
GRAPH_FIELD_SEP = "<SEP>"
|
30 |
+
DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event", "category"]
|
31 |
+
ENTITY_EXTRACTION_MAX_GLEANINGS = 2
|
32 |
+
|
33 |
+
|
34 |
+
class Extractor:
|
35 |
+
_llm: CompletionLLM
|
36 |
+
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
llm_invoker: CompletionLLM,
|
40 |
+
language: str | None = "English",
|
41 |
+
entity_types: list[str] | None = None,
|
42 |
+
get_entity: Callable | None = None,
|
43 |
+
set_entity: Callable | None = None,
|
44 |
+
get_relation: Callable | None = None,
|
45 |
+
set_relation: Callable | None = None,
|
46 |
+
):
|
47 |
+
self._llm = llm_invoker
|
48 |
+
self._language = language
|
49 |
+
self._entity_types = entity_types or DEFAULT_ENTITY_TYPES
|
50 |
+
self._get_entity_ = get_entity
|
51 |
+
self._set_entity_ = set_entity
|
52 |
+
self._get_relation_ = get_relation
|
53 |
+
self._set_relation_ = set_relation
|
54 |
+
|
55 |
+
def _chat(self, system, history, gen_conf):
|
56 |
+
hist = deepcopy(history)
|
57 |
+
conf = deepcopy(gen_conf)
|
58 |
+
response = get_llm_cache(self._llm.llm_name, system, hist, conf)
|
59 |
+
if response:
|
60 |
+
return response
|
61 |
+
response = self._llm.chat(system, hist, conf)
|
62 |
+
if response.find("**ERROR**") >= 0:
|
63 |
+
raise Exception(response)
|
64 |
+
set_llm_cache(self._llm.llm_name, system, response, history, gen_conf)
|
65 |
+
return response
|
66 |
+
|
67 |
+
def _entities_and_relations(self, chunk_key: str, records: list, tuple_delimiter: str):
|
68 |
+
maybe_nodes = defaultdict(list)
|
69 |
+
maybe_edges = defaultdict(list)
|
70 |
+
ent_types = [t.lower() for t in self._entity_types]
|
71 |
+
for record in records:
|
72 |
+
record_attributes = split_string_by_multi_markers(
|
73 |
+
record, [tuple_delimiter]
|
74 |
+
)
|
75 |
+
|
76 |
+
if_entities = handle_single_entity_extraction(
|
77 |
+
record_attributes, chunk_key
|
78 |
+
)
|
79 |
+
if if_entities is not None and if_entities.get("entity_type", "unknown").lower() in ent_types:
|
80 |
+
maybe_nodes[if_entities["entity_name"]].append(if_entities)
|
81 |
+
continue
|
82 |
+
|
83 |
+
if_relation = handle_single_relationship_extraction(
|
84 |
+
record_attributes, chunk_key
|
85 |
+
)
|
86 |
+
if if_relation is not None:
|
87 |
+
maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
|
88 |
+
if_relation
|
89 |
+
)
|
90 |
+
return dict(maybe_nodes), dict(maybe_edges)
|
91 |
+
|
92 |
+
def __call__(
|
93 |
+
self, chunks: list[tuple[str, str]],
|
94 |
+
callback: Callable | None = None
|
95 |
+
):
|
96 |
+
|
97 |
+
results = []
|
98 |
+
max_workers = int(os.environ.get('GRAPH_EXTRACTOR_MAX_WORKERS', 50))
|
99 |
+
with ThreadPoolExecutor(max_workers=max_workers) as exe:
|
100 |
+
threads = []
|
101 |
+
for i, (cid, ck) in enumerate(chunks):
|
102 |
+
threads.append(
|
103 |
+
exe.submit(self._process_single_content, (cid, ck)))
|
104 |
+
|
105 |
+
for i, _ in enumerate(threads):
|
106 |
+
n, r, tc = _.result()
|
107 |
+
if not isinstance(n, Exception):
|
108 |
+
results.append((n, r))
|
109 |
+
if callback:
|
110 |
+
callback(0.5 + 0.1 * i / len(threads), f"Entities extraction progress ... {i + 1}/{len(threads)} ({tc} tokens)")
|
111 |
+
elif callback:
|
112 |
+
callback(msg="Knowledge graph extraction error:{}".format(str(n)))
|
113 |
+
|
114 |
+
maybe_nodes = defaultdict(list)
|
115 |
+
maybe_edges = defaultdict(list)
|
116 |
+
for m_nodes, m_edges in results:
|
117 |
+
for k, v in m_nodes.items():
|
118 |
+
maybe_nodes[k].extend(v)
|
119 |
+
for k, v in m_edges.items():
|
120 |
+
maybe_edges[tuple(sorted(k))].extend(v)
|
121 |
+
logging.info("Inserting entities into storage...")
|
122 |
+
all_entities_data = []
|
123 |
+
for en_nm, ents in maybe_nodes.items():
|
124 |
+
all_entities_data.append(self._merge_nodes(en_nm, ents))
|
125 |
+
|
126 |
+
logging.info("Inserting relationships into storage...")
|
127 |
+
all_relationships_data = []
|
128 |
+
for (src,tgt), rels in maybe_edges.items():
|
129 |
+
all_relationships_data.append(self._merge_edges(src, tgt, rels))
|
130 |
+
|
131 |
+
if not len(all_entities_data) and not len(all_relationships_data):
|
132 |
+
logging.warning(
|
133 |
+
"Didn't extract any entities and relationships, maybe your LLM is not working"
|
134 |
+
)
|
135 |
+
|
136 |
+
if not len(all_entities_data):
|
137 |
+
logging.warning("Didn't extract any entities")
|
138 |
+
if not len(all_relationships_data):
|
139 |
+
logging.warning("Didn't extract any relationships")
|
140 |
+
|
141 |
+
return all_entities_data, all_relationships_data
|
142 |
+
|
143 |
+
def _merge_nodes(self, entity_name: str, entities: list[dict]):
|
144 |
+
if not entities:
|
145 |
+
return
|
146 |
+
already_entity_types = []
|
147 |
+
already_source_ids = []
|
148 |
+
already_description = []
|
149 |
+
|
150 |
+
already_node = self._get_entity_(entity_name)
|
151 |
+
if already_node:
|
152 |
+
already_entity_types.append(already_node["entity_type"])
|
153 |
+
already_source_ids.extend(already_node["source_id"])
|
154 |
+
already_description.append(already_node["description"])
|
155 |
+
|
156 |
+
entity_type = sorted(
|
157 |
+
Counter(
|
158 |
+
[dp["entity_type"] for dp in entities] + already_entity_types
|
159 |
+
).items(),
|
160 |
+
key=lambda x: x[1],
|
161 |
+
reverse=True,
|
162 |
+
)[0][0]
|
163 |
+
description = GRAPH_FIELD_SEP.join(
|
164 |
+
sorted(set([dp["description"] for dp in entities] + already_description))
|
165 |
+
)
|
166 |
+
already_source_ids = flat_uniq_list(entities, "source_id")
|
167 |
+
description = self._handle_entity_relation_summary(
|
168 |
+
entity_name, description
|
169 |
+
)
|
170 |
+
node_data = dict(
|
171 |
+
entity_type=entity_type,
|
172 |
+
description=description,
|
173 |
+
source_id=already_source_ids,
|
174 |
+
)
|
175 |
+
node_data["entity_name"] = entity_name
|
176 |
+
self._set_entity_(entity_name, node_data)
|
177 |
+
return node_data
|
178 |
+
|
179 |
+
def _merge_edges(
|
180 |
+
self,
|
181 |
+
src_id: str,
|
182 |
+
tgt_id: str,
|
183 |
+
edges_data: list[dict]
|
184 |
+
):
|
185 |
+
if not edges_data:
|
186 |
+
return
|
187 |
+
already_weights = []
|
188 |
+
already_source_ids = []
|
189 |
+
already_description = []
|
190 |
+
already_keywords = []
|
191 |
+
|
192 |
+
relation = self._get_relation_(src_id, tgt_id)
|
193 |
+
if relation:
|
194 |
+
already_weights = [relation["weight"]]
|
195 |
+
already_source_ids = relation["source_id"]
|
196 |
+
already_description = [relation["description"]]
|
197 |
+
already_keywords = relation["keywords"]
|
198 |
+
|
199 |
+
weight = sum([dp["weight"] for dp in edges_data] + already_weights)
|
200 |
+
description = GRAPH_FIELD_SEP.join(
|
201 |
+
sorted(set([dp["description"] for dp in edges_data] + already_description))
|
202 |
+
)
|
203 |
+
keywords = flat_uniq_list(edges_data, "keywords") + already_keywords
|
204 |
+
source_id = flat_uniq_list(edges_data, "source_id") + already_source_ids
|
205 |
+
|
206 |
+
for need_insert_id in [src_id, tgt_id]:
|
207 |
+
if self._get_entity_(need_insert_id):
|
208 |
+
continue
|
209 |
+
self._set_entity_(need_insert_id, {
|
210 |
+
"source_id": source_id,
|
211 |
+
"description": description,
|
212 |
+
"entity_type": 'UNKNOWN'
|
213 |
+
})
|
214 |
+
description = self._handle_entity_relation_summary(
|
215 |
+
f"({src_id}, {tgt_id})", description
|
216 |
+
)
|
217 |
+
edge_data = dict(
|
218 |
+
src_id=src_id,
|
219 |
+
tgt_id=tgt_id,
|
220 |
+
description=description,
|
221 |
+
keywords=keywords,
|
222 |
+
weight=weight,
|
223 |
+
source_id=source_id
|
224 |
+
)
|
225 |
+
self._set_relation_(src_id, tgt_id, edge_data)
|
226 |
+
|
227 |
+
return edge_data
|
228 |
+
|
229 |
+
def _handle_entity_relation_summary(
|
230 |
+
self,
|
231 |
+
entity_or_relation_name: str,
|
232 |
+
description: str
|
233 |
+
) -> str:
|
234 |
+
summary_max_tokens = 512
|
235 |
+
use_description = truncate(description, summary_max_tokens)
|
236 |
+
prompt_template = SUMMARIZE_DESCRIPTIONS_PROMPT
|
237 |
+
context_base = dict(
|
238 |
+
entity_name=entity_or_relation_name,
|
239 |
+
description_list=use_description.split(GRAPH_FIELD_SEP),
|
240 |
+
language=self._language,
|
241 |
+
)
|
242 |
+
use_prompt = prompt_template.format(**context_base)
|
243 |
+
logging.info(f"Trigger summary: {entity_or_relation_name}")
|
244 |
+
summary = self._chat(use_prompt, [{"role": "assistant", "content": "Output: "}], {"temperature": 0.8})
|
245 |
+
return summary
|
graphrag/general/graph_extractor.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License
|
3 |
+
"""
|
4 |
+
Reference:
|
5 |
+
- [graphrag](https://github.com/microsoft/graphrag)
|
6 |
+
"""
|
7 |
+
|
8 |
+
import logging
|
9 |
+
import re
|
10 |
+
from typing import Any, Callable
|
11 |
+
from dataclasses import dataclass
|
12 |
+
import tiktoken
|
13 |
+
|
14 |
+
from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS, DEFAULT_ENTITY_TYPES
|
15 |
+
from graphrag.general.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
|
16 |
+
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
|
17 |
+
from rag.llm.chat_model import Base as CompletionLLM
|
18 |
+
import networkx as nx
|
19 |
+
from rag.utils import num_tokens_from_string
|
20 |
+
|
21 |
+
DEFAULT_TUPLE_DELIMITER = "<|>"
|
22 |
+
DEFAULT_RECORD_DELIMITER = "##"
|
23 |
+
DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
|
24 |
+
|
25 |
+
|
26 |
+
@dataclass
|
27 |
+
class GraphExtractionResult:
|
28 |
+
"""Unipartite graph extraction result class definition."""
|
29 |
+
|
30 |
+
output: nx.Graph
|
31 |
+
source_docs: dict[Any, Any]
|
32 |
+
|
33 |
+
|
34 |
+
class GraphExtractor(Extractor):
|
35 |
+
"""Unipartite graph extractor class definition."""
|
36 |
+
|
37 |
+
_join_descriptions: bool
|
38 |
+
_tuple_delimiter_key: str
|
39 |
+
_record_delimiter_key: str
|
40 |
+
_entity_types_key: str
|
41 |
+
_input_text_key: str
|
42 |
+
_completion_delimiter_key: str
|
43 |
+
_entity_name_key: str
|
44 |
+
_input_descriptions_key: str
|
45 |
+
_extraction_prompt: str
|
46 |
+
_summarization_prompt: str
|
47 |
+
_loop_args: dict[str, Any]
|
48 |
+
_max_gleanings: int
|
49 |
+
_on_error: ErrorHandlerFn
|
50 |
+
|
51 |
+
def __init__(
|
52 |
+
self,
|
53 |
+
llm_invoker: CompletionLLM,
|
54 |
+
language: str | None = "English",
|
55 |
+
entity_types: list[str] | None = None,
|
56 |
+
get_entity: Callable | None = None,
|
57 |
+
set_entity: Callable | None = None,
|
58 |
+
get_relation: Callable | None = None,
|
59 |
+
set_relation: Callable | None = None,
|
60 |
+
tuple_delimiter_key: str | None = None,
|
61 |
+
record_delimiter_key: str | None = None,
|
62 |
+
input_text_key: str | None = None,
|
63 |
+
entity_types_key: str | None = None,
|
64 |
+
completion_delimiter_key: str | None = None,
|
65 |
+
join_descriptions=True,
|
66 |
+
max_gleanings: int | None = None,
|
67 |
+
on_error: ErrorHandlerFn | None = None,
|
68 |
+
):
|
69 |
+
super().__init__(llm_invoker, language, entity_types, get_entity, set_entity, get_relation, set_relation)
|
70 |
+
"""Init method definition."""
|
71 |
+
# TODO: streamline construction
|
72 |
+
self._llm = llm_invoker
|
73 |
+
self._join_descriptions = join_descriptions
|
74 |
+
self._input_text_key = input_text_key or "input_text"
|
75 |
+
self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
|
76 |
+
self._record_delimiter_key = record_delimiter_key or "record_delimiter"
|
77 |
+
self._completion_delimiter_key = (
|
78 |
+
completion_delimiter_key or "completion_delimiter"
|
79 |
+
)
|
80 |
+
self._entity_types_key = entity_types_key or "entity_types"
|
81 |
+
self._extraction_prompt = GRAPH_EXTRACTION_PROMPT
|
82 |
+
self._max_gleanings = (
|
83 |
+
max_gleanings
|
84 |
+
if max_gleanings is not None
|
85 |
+
else ENTITY_EXTRACTION_MAX_GLEANINGS
|
86 |
+
)
|
87 |
+
self._on_error = on_error or (lambda _e, _s, _d: None)
|
88 |
+
self.prompt_token_count = num_tokens_from_string(self._extraction_prompt)
|
89 |
+
|
90 |
+
# Construct the looping arguments
|
91 |
+
encoding = tiktoken.get_encoding("cl100k_base")
|
92 |
+
yes = encoding.encode("YES")
|
93 |
+
no = encoding.encode("NO")
|
94 |
+
self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}
|
95 |
+
|
96 |
+
# Wire defaults into the prompt variables
|
97 |
+
self._prompt_variables = {
|
98 |
+
"entity_types": entity_types,
|
99 |
+
self._tuple_delimiter_key: DEFAULT_TUPLE_DELIMITER,
|
100 |
+
self._record_delimiter_key: DEFAULT_RECORD_DELIMITER,
|
101 |
+
self._completion_delimiter_key: DEFAULT_COMPLETION_DELIMITER,
|
102 |
+
self._entity_types_key: ",".join(DEFAULT_ENTITY_TYPES),
|
103 |
+
}
|
104 |
+
|
105 |
+
def _process_single_content(self,
|
106 |
+
chunk_key_dp: tuple[str, str]
|
107 |
+
):
|
108 |
+
token_count = 0
|
109 |
+
|
110 |
+
chunk_key = chunk_key_dp[0]
|
111 |
+
content = chunk_key_dp[1]
|
112 |
+
variables = {
|
113 |
+
**self._prompt_variables,
|
114 |
+
self._input_text_key: content,
|
115 |
+
}
|
116 |
+
try:
|
117 |
+
gen_conf = {"temperature": 0.3}
|
118 |
+
hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
|
119 |
+
response = self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf)
|
120 |
+
token_count += num_tokens_from_string(hint_prompt + response)
|
121 |
+
|
122 |
+
results = response or ""
|
123 |
+
history = [{"role": "system", "content": hint_prompt}, {"role": "assistant", "content": response}]
|
124 |
+
|
125 |
+
# Repeat to ensure we maximize entity count
|
126 |
+
for i in range(self._max_gleanings):
|
127 |
+
text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
|
128 |
+
history.append({"role": "user", "content": text})
|
129 |
+
response = self._chat("", history, gen_conf)
|
130 |
+
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
|
131 |
+
results += response or ""
|
132 |
+
|
133 |
+
# if this is the final glean, don't bother updating the continuation flag
|
134 |
+
if i >= self._max_gleanings - 1:
|
135 |
+
break
|
136 |
+
history.append({"role": "assistant", "content": response})
|
137 |
+
history.append({"role": "user", "content": LOOP_PROMPT})
|
138 |
+
continuation = self._chat("", history, self._loop_args)
|
139 |
+
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
|
140 |
+
if continuation != "YES":
|
141 |
+
break
|
142 |
+
|
143 |
+
record_delimiter = variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER)
|
144 |
+
tuple_delimiter = variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER)
|
145 |
+
records = [re.sub(r"^\(|\)$", "", r.strip()) for r in results.split(record_delimiter)]
|
146 |
+
records = [r for r in records if r.strip()]
|
147 |
+
maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, tuple_delimiter)
|
148 |
+
return maybe_nodes, maybe_edges, token_count
|
149 |
+
except Exception as e:
|
150 |
+
logging.exception("error extracting graph")
|
151 |
+
return e, None, None
|
152 |
+
|
153 |
+
|
154 |
+
|
graphrag/{graph_prompt.py → general/graph_prompt.py}
RENAMED
@@ -106,4 +106,19 @@ Text: {input_text}
|
|
106 |
Output:"""
|
107 |
|
108 |
CONTINUE_PROMPT = "MANY entities were missed in the last extraction. Add them below using the same format:\n"
|
109 |
-
LOOP_PROMPT = "It appears some entities may have still been missed. Answer YES | NO if there are still entities that need to be added.\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
Output:"""
|
107 |
|
108 |
CONTINUE_PROMPT = "MANY entities were missed in the last extraction. Add them below using the same format:\n"
|
109 |
+
LOOP_PROMPT = "It appears some entities may have still been missed. Answer YES | NO if there are still entities that need to be added.\n"
|
110 |
+
|
111 |
+
SUMMARIZE_DESCRIPTIONS_PROMPT = """
|
112 |
+
You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
|
113 |
+
Given one or two entities, and a list of descriptions, all related to the same entity or group of entities.
|
114 |
+
Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions.
|
115 |
+
If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
|
116 |
+
Make sure it is written in third person, and include the entity names so we the have full context.
|
117 |
+
Use {language} as output language.
|
118 |
+
|
119 |
+
#######
|
120 |
+
-Data-
|
121 |
+
Entities: {entity_name}
|
122 |
+
Description List: {description_list}
|
123 |
+
#######
|
124 |
+
"""
|
graphrag/general/index.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
#
|
16 |
+
import json
|
17 |
+
import logging
|
18 |
+
from functools import reduce, partial
|
19 |
+
import networkx as nx
|
20 |
+
|
21 |
+
from api import settings
|
22 |
+
from graphrag.general.community_reports_extractor import CommunityReportsExtractor
|
23 |
+
from graphrag.entity_resolution import EntityResolution
|
24 |
+
from graphrag.general.extractor import Extractor
|
25 |
+
from graphrag.general.graph_extractor import DEFAULT_ENTITY_TYPES
|
26 |
+
from graphrag.utils import graph_merge, set_entity, get_relation, set_relation, get_entity, get_graph, set_graph, \
|
27 |
+
chunk_id, update_nodes_pagerank_nhop_neighbour
|
28 |
+
from rag.nlp import rag_tokenizer, search
|
29 |
+
from rag.utils.redis_conn import RedisDistributedLock
|
30 |
+
|
31 |
+
|
32 |
+
class Dealer:
|
33 |
+
def __init__(self,
|
34 |
+
extractor: Extractor,
|
35 |
+
tenant_id: str,
|
36 |
+
kb_id: str,
|
37 |
+
llm_bdl,
|
38 |
+
chunks: list[tuple[str, str]],
|
39 |
+
language,
|
40 |
+
entity_types=DEFAULT_ENTITY_TYPES,
|
41 |
+
embed_bdl=None,
|
42 |
+
callback=None
|
43 |
+
):
|
44 |
+
docids = list(set([docid for docid,_ in chunks]))
|
45 |
+
self.llm_bdl = llm_bdl
|
46 |
+
self.embed_bdl = embed_bdl
|
47 |
+
ext = extractor(self.llm_bdl, language=language,
|
48 |
+
entity_types=entity_types,
|
49 |
+
get_entity=partial(get_entity, tenant_id, kb_id),
|
50 |
+
set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl),
|
51 |
+
get_relation=partial(get_relation, tenant_id, kb_id),
|
52 |
+
set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl)
|
53 |
+
)
|
54 |
+
ents, rels = ext(chunks, callback)
|
55 |
+
self.graph = nx.Graph()
|
56 |
+
for en in ents:
|
57 |
+
self.graph.add_node(en["entity_name"], entity_type=en["entity_type"])#, description=en["description"])
|
58 |
+
|
59 |
+
for rel in rels:
|
60 |
+
self.graph.add_edge(
|
61 |
+
rel["src_id"],
|
62 |
+
rel["tgt_id"],
|
63 |
+
weight=rel["weight"],
|
64 |
+
#description=rel["description"]
|
65 |
+
)
|
66 |
+
|
67 |
+
with RedisDistributedLock(kb_id, 60*60):
|
68 |
+
old_graph, old_doc_ids = get_graph(tenant_id, kb_id)
|
69 |
+
if old_graph is not None:
|
70 |
+
logging.info("Merge with an exiting graph...................")
|
71 |
+
self.graph = reduce(graph_merge, [old_graph, self.graph])
|
72 |
+
update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, self.graph, 2)
|
73 |
+
if old_doc_ids:
|
74 |
+
docids.extend(old_doc_ids)
|
75 |
+
docids = list(set(docids))
|
76 |
+
set_graph(tenant_id, kb_id, self.graph, docids)
|
77 |
+
|
78 |
+
|
79 |
+
class WithResolution(Dealer):
|
80 |
+
def __init__(self,
|
81 |
+
tenant_id: str,
|
82 |
+
kb_id: str,
|
83 |
+
llm_bdl,
|
84 |
+
embed_bdl=None,
|
85 |
+
callback=None
|
86 |
+
):
|
87 |
+
self.llm_bdl = llm_bdl
|
88 |
+
self.embed_bdl = embed_bdl
|
89 |
+
|
90 |
+
with RedisDistributedLock(kb_id, 60*60):
|
91 |
+
self.graph, doc_ids = get_graph(tenant_id, kb_id)
|
92 |
+
if not self.graph:
|
93 |
+
logging.error(f"Faild to fetch the graph. tenant_id:{kb_id}, kb_id:{kb_id}")
|
94 |
+
if callback:
|
95 |
+
callback(-1, msg="Faild to fetch the graph.")
|
96 |
+
return
|
97 |
+
|
98 |
+
if callback:
|
99 |
+
callback(msg="Fetch the existing graph.")
|
100 |
+
er = EntityResolution(self.llm_bdl,
|
101 |
+
get_entity=partial(get_entity, tenant_id, kb_id),
|
102 |
+
set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl),
|
103 |
+
get_relation=partial(get_relation, tenant_id, kb_id),
|
104 |
+
set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl))
|
105 |
+
reso = er(self.graph)
|
106 |
+
self.graph = reso.graph
|
107 |
+
logging.info("Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
|
108 |
+
if callback:
|
109 |
+
callback(msg="Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
|
110 |
+
update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, self.graph, 2)
|
111 |
+
set_graph(tenant_id, kb_id, self.graph, doc_ids)
|
112 |
+
|
113 |
+
settings.docStoreConn.delete({
|
114 |
+
"knowledge_graph_kwd": "relation",
|
115 |
+
"kb_id": kb_id,
|
116 |
+
"from_entity_kwd": reso.removed_entities
|
117 |
+
}, search.index_name(tenant_id), kb_id)
|
118 |
+
settings.docStoreConn.delete({
|
119 |
+
"knowledge_graph_kwd": "relation",
|
120 |
+
"kb_id": kb_id,
|
121 |
+
"to_entity_kwd": reso.removed_entities
|
122 |
+
}, search.index_name(tenant_id), kb_id)
|
123 |
+
settings.docStoreConn.delete({
|
124 |
+
"knowledge_graph_kwd": "entity",
|
125 |
+
"kb_id": kb_id,
|
126 |
+
"entity_kwd": reso.removed_entities
|
127 |
+
}, search.index_name(tenant_id), kb_id)
|
128 |
+
|
129 |
+
|
130 |
+
class WithCommunity(Dealer):
|
131 |
+
def __init__(self,
|
132 |
+
tenant_id: str,
|
133 |
+
kb_id: str,
|
134 |
+
llm_bdl,
|
135 |
+
embed_bdl=None,
|
136 |
+
callback=None
|
137 |
+
):
|
138 |
+
|
139 |
+
self.community_structure = None
|
140 |
+
self.community_reports = None
|
141 |
+
self.llm_bdl = llm_bdl
|
142 |
+
self.embed_bdl = embed_bdl
|
143 |
+
|
144 |
+
with RedisDistributedLock(kb_id, 60*60):
|
145 |
+
self.graph, doc_ids = get_graph(tenant_id, kb_id)
|
146 |
+
if not self.graph:
|
147 |
+
logging.error(f"Faild to fetch the graph. tenant_id:{kb_id}, kb_id:{kb_id}")
|
148 |
+
if callback:
|
149 |
+
callback(-1, msg="Faild to fetch the graph.")
|
150 |
+
return
|
151 |
+
if callback:
|
152 |
+
callback(msg="Fetch the existing graph.")
|
153 |
+
|
154 |
+
cr = CommunityReportsExtractor(self.llm_bdl,
|
155 |
+
get_entity=partial(get_entity, tenant_id, kb_id),
|
156 |
+
set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl),
|
157 |
+
get_relation=partial(get_relation, tenant_id, kb_id),
|
158 |
+
set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl))
|
159 |
+
cr = cr(self.graph, callback=callback)
|
160 |
+
self.community_structure = cr.structured_output
|
161 |
+
self.community_reports = cr.output
|
162 |
+
set_graph(tenant_id, kb_id, self.graph, doc_ids)
|
163 |
+
|
164 |
+
if callback:
|
165 |
+
callback(msg="Graph community extraction is done. Indexing {} reports.".format(len(cr.structured_output)))
|
166 |
+
|
167 |
+
settings.docStoreConn.delete({
|
168 |
+
"knowledge_graph_kwd": "community_report",
|
169 |
+
"kb_id": kb_id
|
170 |
+
}, search.index_name(tenant_id), kb_id)
|
171 |
+
|
172 |
+
for stru, rep in zip(self.community_structure, self.community_reports):
|
173 |
+
obj = {
|
174 |
+
"report": rep,
|
175 |
+
"evidences": "\n".join([f["explanation"] for f in stru["findings"]])
|
176 |
+
}
|
177 |
+
chunk = {
|
178 |
+
"docnm_kwd": stru["title"],
|
179 |
+
"title_tks": rag_tokenizer.tokenize(stru["title"]),
|
180 |
+
"content_with_weight": json.dumps(obj, ensure_ascii=False),
|
181 |
+
"content_ltks": rag_tokenizer.tokenize(obj["report"] +" "+ obj["evidences"]),
|
182 |
+
"knowledge_graph_kwd": "community_report",
|
183 |
+
"weight_flt": stru["weight"],
|
184 |
+
"entities_kwd": stru["entities"],
|
185 |
+
"important_kwd": stru["entities"],
|
186 |
+
"kb_id": kb_id,
|
187 |
+
"source_id": doc_ids,
|
188 |
+
"available_int": 0
|
189 |
+
}
|
190 |
+
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
|
191 |
+
#try:
|
192 |
+
# ebd, _ = self.embed_bdl.encode([", ".join(community["entities"])])
|
193 |
+
# chunk["q_%d_vec" % len(ebd[0])] = ebd[0]
|
194 |
+
#except Exception as e:
|
195 |
+
# logging.exception(f"Fail to embed entity relation: {e}")
|
196 |
+
settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id))
|
197 |
+
|
graphrag/{leiden.py → general/leiden.py}
RENAMED
@@ -10,7 +10,6 @@ import html
|
|
10 |
from typing import Any, cast
|
11 |
from graspologic.partition import hierarchical_leiden
|
12 |
from graspologic.utils import largest_connected_component
|
13 |
-
|
14 |
import networkx as nx
|
15 |
from networkx import is_empty
|
16 |
|
@@ -130,6 +129,8 @@ def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]:
|
|
130 |
if not weights:
|
131 |
continue
|
132 |
max_weight = max(weights)
|
|
|
|
|
133 |
for _, comm in result.items():
|
134 |
comm["weight"] /= max_weight
|
135 |
|
|
|
10 |
from typing import Any, cast
|
11 |
from graspologic.partition import hierarchical_leiden
|
12 |
from graspologic.utils import largest_connected_component
|
|
|
13 |
import networkx as nx
|
14 |
from networkx import is_empty
|
15 |
|
|
|
129 |
if not weights:
|
130 |
continue
|
131 |
max_weight = max(weights)
|
132 |
+
if max_weight == 0:
|
133 |
+
continue
|
134 |
for _, comm in result.items():
|
135 |
comm["weight"] /= max_weight
|
136 |
|
graphrag/{mind_map_extractor.py → general/mind_map_extractor.py}
RENAMED
@@ -23,8 +23,8 @@ from typing import Any
|
|
23 |
from concurrent.futures import ThreadPoolExecutor
|
24 |
from dataclasses import dataclass
|
25 |
|
26 |
-
from graphrag.extractor import Extractor
|
27 |
-
from graphrag.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
|
28 |
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
|
29 |
from rag.llm.chat_model import Base as CompletionLLM
|
30 |
import markdown_to_json
|
|
|
23 |
from concurrent.futures import ThreadPoolExecutor
|
24 |
from dataclasses import dataclass
|
25 |
|
26 |
+
from graphrag.general.extractor import Extractor
|
27 |
+
from graphrag.general.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
|
28 |
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
|
29 |
from rag.llm.chat_model import Base as CompletionLLM
|
30 |
import markdown_to_json
|
graphrag/{mind_map_prompt.py → general/mind_map_prompt.py}
RENAMED
File without changes
|
graphrag/general/smoke.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
#
|
16 |
+
|
17 |
+
import argparse
|
18 |
+
import json
|
19 |
+
|
20 |
+
import networkx as nx
|
21 |
+
|
22 |
+
from api import settings
|
23 |
+
from api.db import LLMType
|
24 |
+
from api.db.services.document_service import DocumentService
|
25 |
+
from api.db.services.knowledgebase_service import KnowledgebaseService
|
26 |
+
from api.db.services.llm_service import LLMBundle
|
27 |
+
from api.db.services.user_service import TenantService
|
28 |
+
from graphrag.general.index import WithCommunity, Dealer, WithResolution
|
29 |
+
from graphrag.light.graph_extractor import GraphExtractor
|
30 |
+
from rag.utils.redis_conn import RedisDistributedLock
|
31 |
+
|
32 |
+
settings.init_settings()
|
33 |
+
|
34 |
+
if __name__ == "__main__":
|
35 |
+
parser = argparse.ArgumentParser()
|
36 |
+
parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
|
37 |
+
parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
|
38 |
+
args = parser.parse_args()
|
39 |
+
e, doc = DocumentService.get_by_id(args.doc_id)
|
40 |
+
if not e:
|
41 |
+
raise LookupError("Document not found.")
|
42 |
+
kb_id = doc.kb_id
|
43 |
+
|
44 |
+
chunks = [d["content_with_weight"] for d in
|
45 |
+
settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, [kb_id], max_count=6,
|
46 |
+
fields=["content_with_weight"])]
|
47 |
+
chunks = [("x", c) for c in chunks]
|
48 |
+
|
49 |
+
RedisDistributedLock.clean_lock(kb_id)
|
50 |
+
|
51 |
+
_, tenant = TenantService.get_by_id(args.tenant_id)
|
52 |
+
llm_bdl = LLMBundle(args.tenant_id, LLMType.CHAT, tenant.llm_id)
|
53 |
+
_, kb = KnowledgebaseService.get_by_id(kb_id)
|
54 |
+
embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)
|
55 |
+
|
56 |
+
dealer = Dealer(GraphExtractor, args.tenant_id, kb_id, llm_bdl, chunks, "English", embed_bdl=embed_bdl)
|
57 |
+
print(json.dumps(nx.node_link_data(dealer.graph), ensure_ascii=False, indent=2))
|
58 |
+
|
59 |
+
dealer = WithResolution(args.tenant_id, kb_id, llm_bdl, embed_bdl)
|
60 |
+
dealer = WithCommunity(args.tenant_id, kb_id, llm_bdl, embed_bdl)
|
61 |
+
|
62 |
+
print("------------------ COMMUNITY REPORT ----------------------\n", dealer.community_reports)
|
63 |
+
print(json.dumps(dealer.community_structure, ensure_ascii=False, indent=2))
|
graphrag/graph_extractor.py
DELETED
@@ -1,322 +0,0 @@
|
|
1 |
-
# Copyright (c) 2024 Microsoft Corporation.
|
2 |
-
# Licensed under the MIT License
|
3 |
-
"""
|
4 |
-
Reference:
|
5 |
-
- [graphrag](https://github.com/microsoft/graphrag)
|
6 |
-
"""
|
7 |
-
|
8 |
-
import logging
|
9 |
-
import numbers
|
10 |
-
import re
|
11 |
-
import traceback
|
12 |
-
from typing import Any, Callable, Mapping
|
13 |
-
from dataclasses import dataclass
|
14 |
-
import tiktoken
|
15 |
-
|
16 |
-
from graphrag.extractor import Extractor
|
17 |
-
from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
|
18 |
-
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str
|
19 |
-
from rag.llm.chat_model import Base as CompletionLLM
|
20 |
-
import networkx as nx
|
21 |
-
from rag.utils import num_tokens_from_string
|
22 |
-
from timeit import default_timer as timer
|
23 |
-
|
24 |
-
DEFAULT_TUPLE_DELIMITER = "<|>"
|
25 |
-
DEFAULT_RECORD_DELIMITER = "##"
|
26 |
-
DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
|
27 |
-
DEFAULT_ENTITY_TYPES = ["organization", "person", "location", "event", "time"]
|
28 |
-
ENTITY_EXTRACTION_MAX_GLEANINGS = 1
|
29 |
-
|
30 |
-
|
31 |
-
@dataclass
|
32 |
-
class GraphExtractionResult:
|
33 |
-
"""Unipartite graph extraction result class definition."""
|
34 |
-
|
35 |
-
output: nx.Graph
|
36 |
-
source_docs: dict[Any, Any]
|
37 |
-
|
38 |
-
|
39 |
-
class GraphExtractor(Extractor):
|
40 |
-
"""Unipartite graph extractor class definition."""
|
41 |
-
|
42 |
-
_join_descriptions: bool
|
43 |
-
_tuple_delimiter_key: str
|
44 |
-
_record_delimiter_key: str
|
45 |
-
_entity_types_key: str
|
46 |
-
_input_text_key: str
|
47 |
-
_completion_delimiter_key: str
|
48 |
-
_entity_name_key: str
|
49 |
-
_input_descriptions_key: str
|
50 |
-
_extraction_prompt: str
|
51 |
-
_summarization_prompt: str
|
52 |
-
_loop_args: dict[str, Any]
|
53 |
-
_max_gleanings: int
|
54 |
-
_on_error: ErrorHandlerFn
|
55 |
-
|
56 |
-
def __init__(
|
57 |
-
self,
|
58 |
-
llm_invoker: CompletionLLM,
|
59 |
-
prompt: str | None = None,
|
60 |
-
tuple_delimiter_key: str | None = None,
|
61 |
-
record_delimiter_key: str | None = None,
|
62 |
-
input_text_key: str | None = None,
|
63 |
-
entity_types_key: str | None = None,
|
64 |
-
completion_delimiter_key: str | None = None,
|
65 |
-
join_descriptions=True,
|
66 |
-
encoding_model: str | None = None,
|
67 |
-
max_gleanings: int | None = None,
|
68 |
-
on_error: ErrorHandlerFn | None = None,
|
69 |
-
):
|
70 |
-
"""Init method definition."""
|
71 |
-
# TODO: streamline construction
|
72 |
-
self._llm = llm_invoker
|
73 |
-
self._join_descriptions = join_descriptions
|
74 |
-
self._input_text_key = input_text_key or "input_text"
|
75 |
-
self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
|
76 |
-
self._record_delimiter_key = record_delimiter_key or "record_delimiter"
|
77 |
-
self._completion_delimiter_key = (
|
78 |
-
completion_delimiter_key or "completion_delimiter"
|
79 |
-
)
|
80 |
-
self._entity_types_key = entity_types_key or "entity_types"
|
81 |
-
self._extraction_prompt = prompt or GRAPH_EXTRACTION_PROMPT
|
82 |
-
self._max_gleanings = (
|
83 |
-
max_gleanings
|
84 |
-
if max_gleanings is not None
|
85 |
-
else ENTITY_EXTRACTION_MAX_GLEANINGS
|
86 |
-
)
|
87 |
-
self._on_error = on_error or (lambda _e, _s, _d: None)
|
88 |
-
self.prompt_token_count = num_tokens_from_string(self._extraction_prompt)
|
89 |
-
|
90 |
-
# Construct the looping arguments
|
91 |
-
encoding = tiktoken.get_encoding(encoding_model or "cl100k_base")
|
92 |
-
yes = encoding.encode("YES")
|
93 |
-
no = encoding.encode("NO")
|
94 |
-
self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}
|
95 |
-
|
96 |
-
def __call__(
|
97 |
-
self, texts: list[str],
|
98 |
-
prompt_variables: dict[str, Any] | None = None,
|
99 |
-
callback: Callable | None = None
|
100 |
-
) -> GraphExtractionResult:
|
101 |
-
"""Call method definition."""
|
102 |
-
if prompt_variables is None:
|
103 |
-
prompt_variables = {}
|
104 |
-
all_records: dict[int, str] = {}
|
105 |
-
source_doc_map: dict[int, str] = {}
|
106 |
-
|
107 |
-
# Wire defaults into the prompt variables
|
108 |
-
prompt_variables = {
|
109 |
-
**prompt_variables,
|
110 |
-
self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key)
|
111 |
-
or DEFAULT_TUPLE_DELIMITER,
|
112 |
-
self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
|
113 |
-
or DEFAULT_RECORD_DELIMITER,
|
114 |
-
self._completion_delimiter_key: prompt_variables.get(
|
115 |
-
self._completion_delimiter_key
|
116 |
-
)
|
117 |
-
or DEFAULT_COMPLETION_DELIMITER,
|
118 |
-
self._entity_types_key: ",".join(
|
119 |
-
prompt_variables.get(self._entity_types_key) or DEFAULT_ENTITY_TYPES
|
120 |
-
),
|
121 |
-
}
|
122 |
-
|
123 |
-
st = timer()
|
124 |
-
total = len(texts)
|
125 |
-
total_token_count = 0
|
126 |
-
for doc_index, text in enumerate(texts):
|
127 |
-
try:
|
128 |
-
# Invoke the entity extraction
|
129 |
-
result, token_count = self._process_document(text, prompt_variables)
|
130 |
-
source_doc_map[doc_index] = text
|
131 |
-
all_records[doc_index] = result
|
132 |
-
total_token_count += token_count
|
133 |
-
if callback:
|
134 |
-
callback(msg=f"{doc_index+1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}")
|
135 |
-
except Exception as e:
|
136 |
-
if callback:
|
137 |
-
callback(msg="Knowledge graph extraction error:{}".format(str(e)))
|
138 |
-
logging.exception("error extracting graph")
|
139 |
-
self._on_error(
|
140 |
-
e,
|
141 |
-
traceback.format_exc(),
|
142 |
-
{
|
143 |
-
"doc_index": doc_index,
|
144 |
-
"text": text,
|
145 |
-
},
|
146 |
-
)
|
147 |
-
|
148 |
-
output = self._process_results(
|
149 |
-
all_records,
|
150 |
-
prompt_variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER),
|
151 |
-
prompt_variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER),
|
152 |
-
)
|
153 |
-
|
154 |
-
return GraphExtractionResult(
|
155 |
-
output=output,
|
156 |
-
source_docs=source_doc_map,
|
157 |
-
)
|
158 |
-
|
159 |
-
def _process_document(
|
160 |
-
self, text: str, prompt_variables: dict[str, str]
|
161 |
-
) -> str:
|
162 |
-
variables = {
|
163 |
-
**prompt_variables,
|
164 |
-
self._input_text_key: text,
|
165 |
-
}
|
166 |
-
token_count = 0
|
167 |
-
text = perform_variable_replacements(self._extraction_prompt, variables=variables)
|
168 |
-
gen_conf = {"temperature": 0.3}
|
169 |
-
response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
|
170 |
-
token_count = num_tokens_from_string(text + response)
|
171 |
-
|
172 |
-
results = response or ""
|
173 |
-
history = [{"role": "system", "content": text}, {"role": "assistant", "content": response}]
|
174 |
-
|
175 |
-
# Repeat to ensure we maximize entity count
|
176 |
-
for i in range(self._max_gleanings):
|
177 |
-
text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
|
178 |
-
history.append({"role": "user", "content": text})
|
179 |
-
response = self._chat("", history, gen_conf)
|
180 |
-
results += response or ""
|
181 |
-
|
182 |
-
# if this is the final glean, don't bother updating the continuation flag
|
183 |
-
if i >= self._max_gleanings - 1:
|
184 |
-
break
|
185 |
-
history.append({"role": "assistant", "content": response})
|
186 |
-
history.append({"role": "user", "content": LOOP_PROMPT})
|
187 |
-
continuation = self._chat("", history, self._loop_args)
|
188 |
-
if continuation != "YES":
|
189 |
-
break
|
190 |
-
|
191 |
-
return results, token_count
|
192 |
-
|
193 |
-
def _process_results(
|
194 |
-
self,
|
195 |
-
results: dict[int, str],
|
196 |
-
tuple_delimiter: str,
|
197 |
-
record_delimiter: str,
|
198 |
-
) -> nx.Graph:
|
199 |
-
"""Parse the result string to create an undirected unipartite graph.
|
200 |
-
|
201 |
-
Args:
|
202 |
-
- results - dict of results from the extraction chain
|
203 |
-
- tuple_delimiter - delimiter between tuples in an output record, default is '<|>'
|
204 |
-
- record_delimiter - delimiter between records, default is '##'
|
205 |
-
Returns:
|
206 |
-
- output - unipartite graph in graphML format
|
207 |
-
"""
|
208 |
-
graph = nx.Graph()
|
209 |
-
for source_doc_id, extracted_data in results.items():
|
210 |
-
records = [r.strip() for r in extracted_data.split(record_delimiter)]
|
211 |
-
|
212 |
-
for record in records:
|
213 |
-
record = re.sub(r"^\(|\)$", "", record.strip())
|
214 |
-
record_attributes = record.split(tuple_delimiter)
|
215 |
-
|
216 |
-
if record_attributes[0] == '"entity"' and len(record_attributes) >= 4:
|
217 |
-
# add this record as a node in the G
|
218 |
-
entity_name = clean_str(record_attributes[1].upper())
|
219 |
-
entity_type = clean_str(record_attributes[2].upper())
|
220 |
-
entity_description = clean_str(record_attributes[3])
|
221 |
-
|
222 |
-
if entity_name in graph.nodes():
|
223 |
-
node = graph.nodes[entity_name]
|
224 |
-
if self._join_descriptions:
|
225 |
-
node["description"] = "\n".join(
|
226 |
-
list({
|
227 |
-
*_unpack_descriptions(node),
|
228 |
-
entity_description,
|
229 |
-
})
|
230 |
-
)
|
231 |
-
else:
|
232 |
-
if len(entity_description) > len(node["description"]):
|
233 |
-
node["description"] = entity_description
|
234 |
-
node["source_id"] = ", ".join(
|
235 |
-
list({
|
236 |
-
*_unpack_source_ids(node),
|
237 |
-
str(source_doc_id),
|
238 |
-
})
|
239 |
-
)
|
240 |
-
node["entity_type"] = (
|
241 |
-
entity_type if entity_type != "" else node["entity_type"]
|
242 |
-
)
|
243 |
-
else:
|
244 |
-
graph.add_node(
|
245 |
-
entity_name,
|
246 |
-
entity_type=entity_type,
|
247 |
-
description=entity_description,
|
248 |
-
source_id=str(source_doc_id),
|
249 |
-
weight=1
|
250 |
-
)
|
251 |
-
|
252 |
-
if (
|
253 |
-
record_attributes[0] == '"relationship"'
|
254 |
-
and len(record_attributes) >= 5
|
255 |
-
):
|
256 |
-
# add this record as edge
|
257 |
-
source = clean_str(record_attributes[1].upper())
|
258 |
-
target = clean_str(record_attributes[2].upper())
|
259 |
-
edge_description = clean_str(record_attributes[3])
|
260 |
-
edge_source_id = clean_str(str(source_doc_id))
|
261 |
-
weight = (
|
262 |
-
float(record_attributes[-1])
|
263 |
-
if isinstance(record_attributes[-1], numbers.Number)
|
264 |
-
else 1.0
|
265 |
-
)
|
266 |
-
if source not in graph.nodes():
|
267 |
-
graph.add_node(
|
268 |
-
source,
|
269 |
-
entity_type="",
|
270 |
-
description="",
|
271 |
-
source_id=edge_source_id,
|
272 |
-
weight=1
|
273 |
-
)
|
274 |
-
if target not in graph.nodes():
|
275 |
-
graph.add_node(
|
276 |
-
target,
|
277 |
-
entity_type="",
|
278 |
-
description="",
|
279 |
-
source_id=edge_source_id,
|
280 |
-
weight=1
|
281 |
-
)
|
282 |
-
if graph.has_edge(source, target):
|
283 |
-
edge_data = graph.get_edge_data(source, target)
|
284 |
-
if edge_data is not None:
|
285 |
-
weight += edge_data["weight"]
|
286 |
-
if self._join_descriptions:
|
287 |
-
edge_description = "\n".join(
|
288 |
-
list({
|
289 |
-
*_unpack_descriptions(edge_data),
|
290 |
-
edge_description,
|
291 |
-
})
|
292 |
-
)
|
293 |
-
edge_source_id = ", ".join(
|
294 |
-
list({
|
295 |
-
*_unpack_source_ids(edge_data),
|
296 |
-
str(source_doc_id),
|
297 |
-
})
|
298 |
-
)
|
299 |
-
graph.add_edge(
|
300 |
-
source,
|
301 |
-
target,
|
302 |
-
weight=weight,
|
303 |
-
description=edge_description,
|
304 |
-
source_id=edge_source_id,
|
305 |
-
)
|
306 |
-
|
307 |
-
for node_degree in graph.degree:
|
308 |
-
graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
|
309 |
-
return graph
|
310 |
-
|
311 |
-
|
312 |
-
def _unpack_descriptions(data: Mapping) -> list[str]:
|
313 |
-
value = data.get("description", None)
|
314 |
-
return [] if value is None else value.split("\n")
|
315 |
-
|
316 |
-
|
317 |
-
def _unpack_source_ids(data: Mapping) -> list[str]:
|
318 |
-
value = data.get("source_id", None)
|
319 |
-
return [] if value is None else value.split(", ")
|
320 |
-
|
321 |
-
|
322 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graphrag/index.py
DELETED
@@ -1,153 +0,0 @@
|
|
1 |
-
#
|
2 |
-
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
3 |
-
#
|
4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
-
# you may not use this file except in compliance with the License.
|
6 |
-
# You may obtain a copy of the License at
|
7 |
-
#
|
8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
-
#
|
10 |
-
# Unless required by applicable law or agreed to in writing, software
|
11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
-
# See the License for the specific language governing permissions and
|
14 |
-
# limitations under the License.
|
15 |
-
#
|
16 |
-
import logging
|
17 |
-
import os
|
18 |
-
from concurrent.futures import ThreadPoolExecutor
|
19 |
-
import json
|
20 |
-
from functools import reduce
|
21 |
-
import networkx as nx
|
22 |
-
from api.db import LLMType
|
23 |
-
from api.db.services.llm_service import LLMBundle
|
24 |
-
from api.db.services.user_service import TenantService
|
25 |
-
from graphrag.community_reports_extractor import CommunityReportsExtractor
|
26 |
-
from graphrag.entity_resolution import EntityResolution
|
27 |
-
from graphrag.graph_extractor import GraphExtractor, DEFAULT_ENTITY_TYPES
|
28 |
-
from graphrag.mind_map_extractor import MindMapExtractor
|
29 |
-
from rag.nlp import rag_tokenizer
|
30 |
-
from rag.utils import num_tokens_from_string
|
31 |
-
|
32 |
-
|
33 |
-
def graph_merge(g1, g2):
|
34 |
-
g = g2.copy()
|
35 |
-
for n, attr in g1.nodes(data=True):
|
36 |
-
if n not in g2.nodes():
|
37 |
-
g.add_node(n, **attr)
|
38 |
-
continue
|
39 |
-
|
40 |
-
g.nodes[n]["weight"] += 1
|
41 |
-
if g.nodes[n]["description"].lower().find(attr["description"][:32].lower()) < 0:
|
42 |
-
g.nodes[n]["description"] += "\n" + attr["description"]
|
43 |
-
|
44 |
-
for source, target, attr in g1.edges(data=True):
|
45 |
-
if g.has_edge(source, target):
|
46 |
-
g[source][target].update({"weight": attr["weight"]+1})
|
47 |
-
continue
|
48 |
-
g.add_edge(source, target, **attr)
|
49 |
-
|
50 |
-
for node_degree in g.degree:
|
51 |
-
g.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
|
52 |
-
return g
|
53 |
-
|
54 |
-
|
55 |
-
def build_knowledge_graph_chunks(tenant_id: str, chunks: list[str], callback, entity_types=DEFAULT_ENTITY_TYPES):
|
56 |
-
_, tenant = TenantService.get_by_id(tenant_id)
|
57 |
-
llm_bdl = LLMBundle(tenant_id, LLMType.CHAT, tenant.llm_id)
|
58 |
-
ext = GraphExtractor(llm_bdl)
|
59 |
-
left_token_count = llm_bdl.max_length - ext.prompt_token_count - 1024
|
60 |
-
left_token_count = max(llm_bdl.max_length * 0.6, left_token_count)
|
61 |
-
|
62 |
-
assert left_token_count > 0, f"The LLM context length({llm_bdl.max_length}) is smaller than prompt({ext.prompt_token_count})"
|
63 |
-
|
64 |
-
BATCH_SIZE=4
|
65 |
-
texts, graphs = [], []
|
66 |
-
cnt = 0
|
67 |
-
max_workers = int(os.environ.get('GRAPH_EXTRACTOR_MAX_WORKERS', 50))
|
68 |
-
with ThreadPoolExecutor(max_workers=max_workers) as exe:
|
69 |
-
threads = []
|
70 |
-
for i in range(len(chunks)):
|
71 |
-
tkn_cnt = num_tokens_from_string(chunks[i])
|
72 |
-
if cnt+tkn_cnt >= left_token_count and texts:
|
73 |
-
for b in range(0, len(texts), BATCH_SIZE):
|
74 |
-
threads.append(exe.submit(ext, ["\n".join(texts[b:b+BATCH_SIZE])], {"entity_types": entity_types}, callback))
|
75 |
-
texts = []
|
76 |
-
cnt = 0
|
77 |
-
texts.append(chunks[i])
|
78 |
-
cnt += tkn_cnt
|
79 |
-
if texts:
|
80 |
-
for b in range(0, len(texts), BATCH_SIZE):
|
81 |
-
threads.append(exe.submit(ext, ["\n".join(texts[b:b+BATCH_SIZE])], {"entity_types": entity_types}, callback))
|
82 |
-
|
83 |
-
callback(0.5, "Extracting entities.")
|
84 |
-
graphs = []
|
85 |
-
for i, _ in enumerate(threads):
|
86 |
-
graphs.append(_.result().output)
|
87 |
-
callback(0.5 + 0.1*i/len(threads), f"Entities extraction progress ... {i+1}/{len(threads)}")
|
88 |
-
|
89 |
-
graph = reduce(graph_merge, graphs) if graphs else nx.Graph()
|
90 |
-
er = EntityResolution(llm_bdl)
|
91 |
-
graph = er(graph).output
|
92 |
-
|
93 |
-
_chunks = chunks
|
94 |
-
chunks = []
|
95 |
-
for n, attr in graph.nodes(data=True):
|
96 |
-
if attr.get("rank", 0) == 0:
|
97 |
-
logging.debug(f"Ignore entity: {n}")
|
98 |
-
continue
|
99 |
-
chunk = {
|
100 |
-
"name_kwd": n,
|
101 |
-
"important_kwd": [n],
|
102 |
-
"title_tks": rag_tokenizer.tokenize(n),
|
103 |
-
"content_with_weight": json.dumps({"name": n, **attr}, ensure_ascii=False),
|
104 |
-
"content_ltks": rag_tokenizer.tokenize(attr["description"]),
|
105 |
-
"knowledge_graph_kwd": "entity",
|
106 |
-
"rank_int": attr["rank"],
|
107 |
-
"weight_int": attr["weight"]
|
108 |
-
}
|
109 |
-
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
|
110 |
-
chunks.append(chunk)
|
111 |
-
|
112 |
-
callback(0.6, "Extracting community reports.")
|
113 |
-
cr = CommunityReportsExtractor(llm_bdl)
|
114 |
-
cr = cr(graph, callback=callback)
|
115 |
-
for community, desc in zip(cr.structured_output, cr.output):
|
116 |
-
chunk = {
|
117 |
-
"title_tks": rag_tokenizer.tokenize(community["title"]),
|
118 |
-
"content_with_weight": desc,
|
119 |
-
"content_ltks": rag_tokenizer.tokenize(desc),
|
120 |
-
"knowledge_graph_kwd": "community_report",
|
121 |
-
"weight_flt": community["weight"],
|
122 |
-
"entities_kwd": community["entities"],
|
123 |
-
"important_kwd": community["entities"]
|
124 |
-
}
|
125 |
-
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
|
126 |
-
chunks.append(chunk)
|
127 |
-
|
128 |
-
chunks.append(
|
129 |
-
{
|
130 |
-
"content_with_weight": json.dumps(nx.node_link_data(graph), ensure_ascii=False, indent=2),
|
131 |
-
"knowledge_graph_kwd": "graph"
|
132 |
-
})
|
133 |
-
|
134 |
-
callback(0.75, "Extracting mind graph.")
|
135 |
-
mindmap = MindMapExtractor(llm_bdl)
|
136 |
-
mg = mindmap(_chunks).output
|
137 |
-
if not len(mg.keys()):
|
138 |
-
return chunks
|
139 |
-
|
140 |
-
logging.debug(json.dumps(mg, ensure_ascii=False, indent=2))
|
141 |
-
chunks.append(
|
142 |
-
{
|
143 |
-
"content_with_weight": json.dumps(mg, ensure_ascii=False, indent=2),
|
144 |
-
"knowledge_graph_kwd": "mind_map"
|
145 |
-
})
|
146 |
-
|
147 |
-
return chunks
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graphrag/light/__init__.py
ADDED
File without changes
|
graphrag/light/graph_extractor.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License
|
3 |
+
"""
|
4 |
+
Reference:
|
5 |
+
- [graphrag](https://github.com/microsoft/graphrag)
|
6 |
+
"""
|
7 |
+
import logging
|
8 |
+
import re
|
9 |
+
from typing import Any, Callable
|
10 |
+
from dataclasses import dataclass
|
11 |
+
from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS
|
12 |
+
from graphrag.light.graph_prompt import PROMPTS
|
13 |
+
from graphrag.utils import pack_user_ass_to_openai_messages, split_string_by_multi_markers
|
14 |
+
from rag.llm.chat_model import Base as CompletionLLM
|
15 |
+
import networkx as nx
|
16 |
+
from rag.utils import num_tokens_from_string
|
17 |
+
|
18 |
+
|
19 |
+
@dataclass
|
20 |
+
class GraphExtractionResult:
|
21 |
+
"""Unipartite graph extraction result class definition."""
|
22 |
+
|
23 |
+
output: nx.Graph
|
24 |
+
source_docs: dict[Any, Any]
|
25 |
+
|
26 |
+
|
27 |
+
class GraphExtractor(Extractor):
|
28 |
+
|
29 |
+
_max_gleanings: int
|
30 |
+
|
31 |
+
def __init__(
|
32 |
+
self,
|
33 |
+
llm_invoker: CompletionLLM,
|
34 |
+
language: str | None = "English",
|
35 |
+
entity_types: list[str] | None = None,
|
36 |
+
get_entity: Callable | None = None,
|
37 |
+
set_entity: Callable | None = None,
|
38 |
+
get_relation: Callable | None = None,
|
39 |
+
set_relation: Callable | None = None,
|
40 |
+
example_number: int = 2,
|
41 |
+
max_gleanings: int | None = None,
|
42 |
+
):
|
43 |
+
super().__init__(llm_invoker, language, entity_types, get_entity, set_entity, get_relation, set_relation)
|
44 |
+
"""Init method definition."""
|
45 |
+
self._max_gleanings = (
|
46 |
+
max_gleanings
|
47 |
+
if max_gleanings is not None
|
48 |
+
else ENTITY_EXTRACTION_MAX_GLEANINGS
|
49 |
+
)
|
50 |
+
self._example_number = example_number
|
51 |
+
examples = "\n".join(
|
52 |
+
PROMPTS["entity_extraction_examples"][: int(self._example_number)]
|
53 |
+
)
|
54 |
+
|
55 |
+
example_context_base = dict(
|
56 |
+
tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
|
57 |
+
record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
|
58 |
+
completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
|
59 |
+
entity_types=",".join(self._entity_types),
|
60 |
+
language=self._language,
|
61 |
+
)
|
62 |
+
# add example's format
|
63 |
+
examples = examples.format(**example_context_base)
|
64 |
+
|
65 |
+
self._entity_extract_prompt = PROMPTS["entity_extraction"]
|
66 |
+
self._context_base = dict(
|
67 |
+
tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
|
68 |
+
record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
|
69 |
+
completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
|
70 |
+
entity_types=",".join(self._entity_types),
|
71 |
+
examples=examples,
|
72 |
+
language=self._language,
|
73 |
+
)
|
74 |
+
|
75 |
+
self._continue_prompt = PROMPTS["entiti_continue_extraction"]
|
76 |
+
self._if_loop_prompt = PROMPTS["entiti_if_loop_extraction"]
|
77 |
+
|
78 |
+
self._left_token_count = llm_invoker.max_length - num_tokens_from_string(
|
79 |
+
self._entity_extract_prompt.format(
|
80 |
+
**self._context_base, input_text="{input_text}"
|
81 |
+
).format(**self._context_base, input_text="")
|
82 |
+
)
|
83 |
+
self._left_token_count = max(llm_invoker.max_length * 0.6, self._left_token_count)
|
84 |
+
|
85 |
+
def _process_single_content(self, chunk_key_dp: tuple[str, str]):
|
86 |
+
token_count = 0
|
87 |
+
chunk_key = chunk_key_dp[0]
|
88 |
+
content = chunk_key_dp[1]
|
89 |
+
hint_prompt = self._entity_extract_prompt.format(
|
90 |
+
**self._context_base, input_text="{input_text}"
|
91 |
+
).format(**self._context_base, input_text=content)
|
92 |
+
|
93 |
+
try:
|
94 |
+
gen_conf = {"temperature": 0.3}
|
95 |
+
final_result = self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf)
|
96 |
+
token_count += num_tokens_from_string(hint_prompt + final_result)
|
97 |
+
history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
|
98 |
+
for now_glean_index in range(self._max_gleanings):
|
99 |
+
glean_result = self._chat(self._continue_prompt, history, gen_conf)
|
100 |
+
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + glean_result + self._continue_prompt)
|
101 |
+
history += pack_user_ass_to_openai_messages(self._continue_prompt, glean_result)
|
102 |
+
final_result += glean_result
|
103 |
+
if now_glean_index == self._max_gleanings - 1:
|
104 |
+
break
|
105 |
+
|
106 |
+
if_loop_result = self._chat(self._if_loop_prompt, history, gen_conf)
|
107 |
+
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt)
|
108 |
+
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
|
109 |
+
if if_loop_result != "yes":
|
110 |
+
break
|
111 |
+
|
112 |
+
records = split_string_by_multi_markers(
|
113 |
+
final_result,
|
114 |
+
[self._context_base["record_delimiter"], self._context_base["completion_delimiter"]],
|
115 |
+
)
|
116 |
+
rcds = []
|
117 |
+
for record in records:
|
118 |
+
record = re.search(r"\((.*)\)", record)
|
119 |
+
if record is None:
|
120 |
+
continue
|
121 |
+
rcds.append(record.group(1))
|
122 |
+
records = rcds
|
123 |
+
maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, self._context_base["tuple_delimiter"])
|
124 |
+
return maybe_nodes, maybe_edges, token_count
|
125 |
+
except Exception as e:
|
126 |
+
logging.exception("error extracting graph")
|
127 |
+
return e, None, None
|
graphrag/light/graph_prompt.py
ADDED
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Licensed under the MIT License
|
2 |
+
"""
|
3 |
+
Reference:
|
4 |
+
- [LightRag](https://github.com/HKUDS/LightRAG)
|
5 |
+
"""
|
6 |
+
|
7 |
+
|
8 |
+
PROMPTS = {}
|
9 |
+
|
10 |
+
PROMPTS["DEFAULT_LANGUAGE"] = "English"
|
11 |
+
PROMPTS["DEFAULT_TUPLE_DELIMITER"] = "<|>"
|
12 |
+
PROMPTS["DEFAULT_RECORD_DELIMITER"] = "##"
|
13 |
+
PROMPTS["DEFAULT_COMPLETION_DELIMITER"] = "<|COMPLETE|>"
|
14 |
+
PROMPTS["process_tickers"] = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
|
15 |
+
|
16 |
+
PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event", "category"]
|
17 |
+
|
18 |
+
PROMPTS["entity_extraction"] = """-Goal-
|
19 |
+
Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
|
20 |
+
Use {language} as output language.
|
21 |
+
|
22 |
+
-Steps-
|
23 |
+
1. Identify all entities. For each identified entity, extract the following information:
|
24 |
+
- entity_name: Name of the entity, use same language as input text. If English, capitalized the name.
|
25 |
+
- entity_type: One of the following types: [{entity_types}]
|
26 |
+
- entity_description: Comprehensive description of the entity's attributes and activities
|
27 |
+
Format each entity as ("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_description>)
|
28 |
+
|
29 |
+
2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
|
30 |
+
For each pair of related entities, extract the following information:
|
31 |
+
- source_entity: name of the source entity, as identified in step 1
|
32 |
+
- target_entity: name of the target entity, as identified in step 1
|
33 |
+
- relationship_description: explanation as to why you think the source entity and the target entity are related to each other
|
34 |
+
- relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity
|
35 |
+
- relationship_keywords: one or more high-level key words that summarize the overarching nature of the relationship, focusing on concepts or themes rather than specific details
|
36 |
+
Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_description>{tuple_delimiter}<relationship_keywords>{tuple_delimiter}<relationship_strength>)
|
37 |
+
|
38 |
+
3. Identify high-level key words that summarize the main concepts, themes, or topics of the entire text. These should capture the overarching ideas present in the document.
|
39 |
+
Format the content-level key words as ("content_keywords"{tuple_delimiter}<high_level_keywords>)
|
40 |
+
|
41 |
+
4. Return output in {language} as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
|
42 |
+
|
43 |
+
5. When finished, output {completion_delimiter}
|
44 |
+
|
45 |
+
######################
|
46 |
+
-Examples-
|
47 |
+
######################
|
48 |
+
{examples}
|
49 |
+
|
50 |
+
#############################
|
51 |
+
-Real Data-
|
52 |
+
######################
|
53 |
+
Entity_types: {entity_types}
|
54 |
+
Text: {input_text}
|
55 |
+
######################
|
56 |
+
"""
|
57 |
+
|
58 |
+
PROMPTS["entity_extraction_examples"] = [
|
59 |
+
"""Example 1:
|
60 |
+
|
61 |
+
Entity_types: [person, technology, mission, organization, location]
|
62 |
+
Text:
|
63 |
+
while Alex clenched his jaw, the buzz of frustration dull against the backdrop of Taylor's authoritarian certainty. It was this competitive undercurrent that kept him alert, the sense that his and Jordan's shared commitment to discovery was an unspoken rebellion against Cruz's narrowing vision of control and order.
|
64 |
+
|
65 |
+
Then Taylor did something unexpected. They paused beside Jordan and, for a moment, observed the device with something akin to reverence. “If this tech can be understood..." Taylor said, their voice quieter, "It could change the game for us. For all of us.”
|
66 |
+
|
67 |
+
The underlying dismissal earlier seemed to falter, replaced by a glimpse of reluctant respect for the gravity of what lay in their hands. Jordan looked up, and for a fleeting heartbeat, their eyes locked with Taylor's, a wordless clash of wills softening into an uneasy truce.
|
68 |
+
|
69 |
+
It was a small transformation, barely perceptible, but one that Alex noted with an inward nod. They had all been brought here by different paths
|
70 |
+
################
|
71 |
+
Output:
|
72 |
+
("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is a character who experiences frustration and is observant of the dynamics among other characters."){record_delimiter}
|
73 |
+
("entity"{tuple_delimiter}"Taylor"{tuple_delimiter}"person"{tuple_delimiter}"Taylor is portrayed with authoritarian certainty and shows a moment of reverence towards a device, indicating a change in perspective."){record_delimiter}
|
74 |
+
("entity"{tuple_delimiter}"Jordan"{tuple_delimiter}"person"{tuple_delimiter}"Jordan shares a commitment to discovery and has a significant interaction with Taylor regarding a device."){record_delimiter}
|
75 |
+
("entity"{tuple_delimiter}"Cruz"{tuple_delimiter}"person"{tuple_delimiter}"Cruz is associated with a vision of control and order, influencing the dynamics among other characters."){record_delimiter}
|
76 |
+
("entity"{tuple_delimiter}"The Device"{tuple_delimiter}"technology"{tuple_delimiter}"The Device is central to the story, with potential game-changing implications, and is revered by Taylor."){record_delimiter}
|
77 |
+
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Taylor"{tuple_delimiter}"Alex is affected by Taylor's authoritarian certainty and observes changes in Taylor's attitude towards the device."{tuple_delimiter}"power dynamics, perspective shift"{tuple_delimiter}7){record_delimiter}
|
78 |
+
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Jordan"{tuple_delimiter}"Alex and Jordan share a commitment to discovery, which contrasts with Cruz's vision."{tuple_delimiter}"shared goals, rebellion"{tuple_delimiter}6){record_delimiter}
|
79 |
+
("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"Jordan"{tuple_delimiter}"Taylor and Jordan interact directly regarding the device, leading to a moment of mutual respect and an uneasy truce."{tuple_delimiter}"conflict resolution, mutual respect"{tuple_delimiter}8){record_delimiter}
|
80 |
+
("relationship"{tuple_delimiter}"Jordan"{tuple_delimiter}"Cruz"{tuple_delimiter}"Jordan's commitment to discovery is in rebellion against Cruz's vision of control and order."{tuple_delimiter}"ideological conflict, rebellion"{tuple_delimiter}5){record_delimiter}
|
81 |
+
("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"The Device"{tuple_delimiter}"Taylor shows reverence towards the device, indicating its importance and potential impact."{tuple_delimiter}"reverence, technological significance"{tuple_delimiter}9){record_delimiter}
|
82 |
+
("content_keywords"{tuple_delimiter}"power dynamics, ideological conflict, discovery, rebellion"){completion_delimiter}
|
83 |
+
#############################""",
|
84 |
+
"""Example 2:
|
85 |
+
|
86 |
+
Entity_types: [person, technology, mission, organization, location]
|
87 |
+
Text:
|
88 |
+
They were no longer mere operatives; they had become guardians of a threshold, keepers of a message from a realm beyond stars and stripes. This elevation in their mission could not be shackled by regulations and established protocols—it demanded a new perspective, a new resolve.
|
89 |
+
|
90 |
+
Tension threaded through the dialogue of beeps and static as communications with Washington buzzed in the background. The team stood, a portentous air enveloping them. It was clear that the decisions they made in the ensuing hours could redefine humanity's place in the cosmos or condemn them to ignorance and potential peril.
|
91 |
+
|
92 |
+
Their connection to the stars solidified, the group moved to address the crystallizing warning, shifting from passive recipients to active participants. Mercer's latter instincts gained precedence— the team's mandate had evolved, no longer solely to observe and report but to interact and prepare. A metamorphosis had begun, and Operation: Dulce hummed with the newfound frequency of their daring, a tone set not by the earthly
|
93 |
+
#############
|
94 |
+
Output:
|
95 |
+
("entity"{tuple_delimiter}"Washington"{tuple_delimiter}"location"{tuple_delimiter}"Washington is a location where communications are being received, indicating its importance in the decision-making process."){record_delimiter}
|
96 |
+
("entity"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"mission"{tuple_delimiter}"Operation: Dulce is described as a mission that has evolved to interact and prepare, indicating a significant shift in objectives and activities."){record_delimiter}
|
97 |
+
("entity"{tuple_delimiter}"The team"{tuple_delimiter}"organization"{tuple_delimiter}"The team is portrayed as a group of individuals who have transitioned from passive observers to active participants in a mission, showing a dynamic change in their role."){record_delimiter}
|
98 |
+
("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Washington"{tuple_delimiter}"The team receives communications from Washington, which influences their decision-making process."{tuple_delimiter}"decision-making, external influence"{tuple_delimiter}7){record_delimiter}
|
99 |
+
("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"The team is directly involved in Operation: Dulce, executing its evolved objectives and activities."{tuple_delimiter}"mission evolution, active participation"{tuple_delimiter}9){completion_delimiter}
|
100 |
+
("content_keywords"{tuple_delimiter}"mission evolution, decision-making, active participation, cosmic significance"){completion_delimiter}
|
101 |
+
#############################""",
|
102 |
+
"""Example 3:
|
103 |
+
|
104 |
+
Entity_types: [person, role, technology, organization, event, location, concept]
|
105 |
+
Text:
|
106 |
+
their voice slicing through the buzz of activity. "Control may be an illusion when facing an intelligence that literally writes its own rules," they stated stoically, casting a watchful eye over the flurry of data.
|
107 |
+
|
108 |
+
"It's like it's learning to communicate," offered Sam Rivera from a nearby interface, their youthful energy boding a mix of awe and anxiety. "This gives talking to strangers' a whole new meaning."
|
109 |
+
|
110 |
+
Alex surveyed his team—each face a study in concentration, determination, and not a small measure of trepidation. "This might well be our first contact," he acknowledged, "And we need to be ready for whatever answers back."
|
111 |
+
|
112 |
+
Together, they stood on the edge of the unknown, forging humanity's response to a message from the heavens. The ensuing silence was palpable—a collective introspection about their role in this grand cosmic play, one that could rewrite human history.
|
113 |
+
|
114 |
+
The encrypted dialogue continued to unfold, its intricate patterns showing an almost uncanny anticipation
|
115 |
+
#############
|
116 |
+
Output:
|
117 |
+
("entity"{tuple_delimiter}"Sam Rivera"{tuple_delimiter}"person"{tuple_delimiter}"Sam Rivera is a member of a team working on communicating with an unknown intelligence, showing a mix of awe and anxiety."){record_delimiter}
|
118 |
+
("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is the leader of a team attempting first contact with an unknown intelligence, acknowledging the significance of their task."){record_delimiter}
|
119 |
+
("entity"{tuple_delimiter}"Control"{tuple_delimiter}"concept"{tuple_delimiter}"Control refers to the ability to manage or govern, which is challenged by an intelligence that writes its own rules."){record_delimiter}
|
120 |
+
("entity"{tuple_delimiter}"Intelligence"{tuple_delimiter}"concept"{tuple_delimiter}"Intelligence here refers to an unknown entity capable of writing its own rules and learning to communicate."){record_delimiter}
|
121 |
+
("entity"{tuple_delimiter}"First Contact"{tuple_delimiter}"event"{tuple_delimiter}"First Contact is the potential initial communication between humanity and an unknown intelligence."){record_delimiter}
|
122 |
+
("entity"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"event"{tuple_delimiter}"Humanity's Response is the collective action taken by Alex's team in response to a message from an unknown intelligence."){record_delimiter}
|
123 |
+
("relationship"{tuple_delimiter}"Sam Rivera"{tuple_delimiter}"Intelligence"{tuple_delimiter}"Sam Rivera is directly involved in the process of learning to communicate with the unknown intelligence."{tuple_delimiter}"communication, learning process"{tuple_delimiter}9){record_delimiter}
|
124 |
+
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"First Contact"{tuple_delimiter}"Alex leads the team that might be making the First Contact with the unknown intelligence."{tuple_delimiter}"leadership, exploration"{tuple_delimiter}10){record_delimiter}
|
125 |
+
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"Alex and his team are the key figures in Humanity's Response to the unknown intelligence."{tuple_delimiter}"collective action, cosmic significance"{tuple_delimiter}8){record_delimiter}
|
126 |
+
("relationship"{tuple_delimiter}"Control"{tuple_delimiter}"Intelligence"{tuple_delimiter}"The concept of Control is challenged by the Intelligence that writes its own rules."{tuple_delimiter}"power dynamics, autonomy"{tuple_delimiter}7){record_delimiter}
|
127 |
+
("content_keywords"{tuple_delimiter}"first contact, control, communication, cosmic significance"){completion_delimiter}
|
128 |
+
#############################""",
|
129 |
+
]
|
130 |
+
|
131 |
+
PROMPTS[
|
132 |
+
"entiti_continue_extraction"
|
133 |
+
] = """MANY entities were missed in the last extraction. Add them below using the same format:
|
134 |
+
"""
|
135 |
+
|
136 |
+
PROMPTS[
|
137 |
+
"entiti_if_loop_extraction"
|
138 |
+
] = """It appears some entities may have still been missed. Answer YES | NO if there are still entities that need to be added.
|
139 |
+
"""
|
140 |
+
|
141 |
+
PROMPTS["fail_response"] = "Sorry, I'm not able to provide an answer to that question."
|
142 |
+
|
143 |
+
PROMPTS["rag_response"] = """---Role---
|
144 |
+
|
145 |
+
You are a helpful assistant responding to questions about data in the tables provided.
|
146 |
+
|
147 |
+
|
148 |
+
---Goal---
|
149 |
+
|
150 |
+
Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.
|
151 |
+
If you don't know the answer, just say so. Do not make anything up.
|
152 |
+
Do not include information where the supporting evidence for it is not provided.
|
153 |
+
|
154 |
+
When handling relationships with timestamps:
|
155 |
+
1. Each relationship has a "created_at" timestamp indicating when we acquired this knowledge
|
156 |
+
2. When encountering conflicting relationships, consider both the semantic content and the timestamp
|
157 |
+
3. Don't automatically prefer the most recently created relationships - use judgment based on the context
|
158 |
+
4. For time-specific queries, prioritize temporal information in the content before considering creation timestamps
|
159 |
+
|
160 |
+
---Target response length and format---
|
161 |
+
|
162 |
+
{response_type}
|
163 |
+
|
164 |
+
---Data tables---
|
165 |
+
|
166 |
+
{context_data}
|
167 |
+
|
168 |
+
Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown."""
|
169 |
+
|
170 |
+
PROMPTS["naive_rag_response"] = """---Role---
|
171 |
+
|
172 |
+
You are a helpful assistant responding to questions about documents provided.
|
173 |
+
|
174 |
+
|
175 |
+
---Goal---
|
176 |
+
|
177 |
+
Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.
|
178 |
+
If you don't know the answer, just say so. Do not make anything up.
|
179 |
+
Do not include information where the supporting evidence for it is not provided.
|
180 |
+
|
181 |
+
When handling content with timestamps:
|
182 |
+
1. Each piece of content has a "created_at" timestamp indicating when we acquired this knowledge
|
183 |
+
2. When encountering conflicting information, consider both the content and the timestamp
|
184 |
+
3. Don't automatically prefer the most recent content - use judgment based on the context
|
185 |
+
4. For time-specific queries, prioritize temporal information in the content before considering creation timestamps
|
186 |
+
|
187 |
+
---Target response length and format---
|
188 |
+
|
189 |
+
{response_type}
|
190 |
+
|
191 |
+
---Documents---
|
192 |
+
|
193 |
+
{content_data}
|
194 |
+
|
195 |
+
Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown.
|
196 |
+
"""
|
197 |
+
|
198 |
+
PROMPTS[
|
199 |
+
"similarity_check"
|
200 |
+
] = """Please analyze the similarity between these two questions:
|
201 |
+
|
202 |
+
Question 1: {original_prompt}
|
203 |
+
Question 2: {cached_prompt}
|
204 |
+
|
205 |
+
Please evaluate the following two points and provide a similarity score between 0 and 1 directly:
|
206 |
+
1. Whether these two questions are semantically similar
|
207 |
+
2. Whether the answer to Question 2 can be used to answer Question 1
|
208 |
+
Similarity score criteria:
|
209 |
+
0: Completely unrelated or answer cannot be reused, including but not limited to:
|
210 |
+
- The questions have different topics
|
211 |
+
- The locations mentioned in the questions are different
|
212 |
+
- The times mentioned in the questions are different
|
213 |
+
- The specific individuals mentioned in the questions are different
|
214 |
+
- The specific events mentioned in the questions are different
|
215 |
+
- The background information in the questions is different
|
216 |
+
- The key conditions in the questions are different
|
217 |
+
1: Identical and answer can be directly reused
|
218 |
+
0.5: Partially related and answer needs modification to be used
|
219 |
+
Return only a number between 0-1, without any additional content.
|
220 |
+
"""
|
221 |
+
|
222 |
+
PROMPTS["mix_rag_response"] = """---Role---
|
223 |
+
|
224 |
+
You are a professional assistant responsible for answering questions based on knowledge graph and textual information. Please respond in the same language as the user's question.
|
225 |
+
|
226 |
+
---Goal---
|
227 |
+
|
228 |
+
Generate a concise response that summarizes relevant points from the provided information. If you don't know the answer, just say so. Do not make anything up or include information where the supporting evidence is not provided.
|
229 |
+
|
230 |
+
When handling information with timestamps:
|
231 |
+
1. Each piece of information (both relationships and content) has a "created_at" timestamp indicating when we acquired this knowledge
|
232 |
+
2. When encountering conflicting information, consider both the content/relationship and the timestamp
|
233 |
+
3. Don't automatically prefer the most recent information - use judgment based on the context
|
234 |
+
4. For time-specific queries, prioritize temporal information in the content before considering creation timestamps
|
235 |
+
|
236 |
+
---Data Sources---
|
237 |
+
|
238 |
+
1. Knowledge Graph Data:
|
239 |
+
{kg_context}
|
240 |
+
|
241 |
+
2. Vector Data:
|
242 |
+
{vector_context}
|
243 |
+
|
244 |
+
---Response Requirements---
|
245 |
+
|
246 |
+
- Target format and length: {response_type}
|
247 |
+
- Use markdown formatting with appropriate section headings
|
248 |
+
- Aim to keep content around 3 paragraphs for conciseness
|
249 |
+
- Each paragraph should be under a relevant section heading
|
250 |
+
- Each section should focus on one main point or aspect of the answer
|
251 |
+
- Use clear and descriptive section titles that reflect the content
|
252 |
+
- List up to 5 most important reference sources at the end under "References", clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (VD)
|
253 |
+
Format: [KG/VD] Source content
|
254 |
+
|
255 |
+
Add sections and commentary to the response as appropriate for the length and format. If the provided information is insufficient to answer the question, clearly state that you don't know or cannot provide an answer in the same language as the user's question."""
|
graphrag/{smoke.py → light/smoke.py}
RENAMED
@@ -16,11 +16,19 @@
|
|
16 |
|
17 |
import argparse
|
18 |
import json
|
19 |
-
from
|
20 |
-
|
21 |
-
|
22 |
-
from
|
23 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
if __name__ == "__main__":
|
26 |
parser = argparse.ArgumentParser()
|
@@ -28,28 +36,23 @@ if __name__ == "__main__":
|
|
28 |
parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
|
29 |
args = parser.parse_args()
|
30 |
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
|
37 |
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
|
43 |
-
|
44 |
-
graph = er(graph.output)
|
45 |
|
46 |
-
|
47 |
-
|
|
|
|
|
48 |
|
49 |
-
|
50 |
-
print(json.dumps(comm, ensure_ascii=False, indent=2))
|
51 |
|
52 |
-
|
53 |
-
cr = cr(graph.output)
|
54 |
-
print("------------------ COMMUNITY REPORT ----------------------\n", cr.output)
|
55 |
-
print(json.dumps(cr.structured_output, ensure_ascii=False, indent=2))
|
|
|
16 |
|
17 |
import argparse
|
18 |
import json
|
19 |
+
from api import settings
|
20 |
+
import networkx as nx
|
21 |
+
|
22 |
+
from api.db import LLMType
|
23 |
+
from api.db.services.document_service import DocumentService
|
24 |
+
from api.db.services.knowledgebase_service import KnowledgebaseService
|
25 |
+
from api.db.services.llm_service import LLMBundle
|
26 |
+
from api.db.services.user_service import TenantService
|
27 |
+
from graphrag.general.index import Dealer
|
28 |
+
from graphrag.light.graph_extractor import GraphExtractor
|
29 |
+
from rag.utils.redis_conn import RedisDistributedLock
|
30 |
+
|
31 |
+
settings.init_settings()
|
32 |
|
33 |
if __name__ == "__main__":
|
34 |
parser = argparse.ArgumentParser()
|
|
|
36 |
parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
|
37 |
args = parser.parse_args()
|
38 |
|
39 |
+
e, doc = DocumentService.get_by_id(args.doc_id)
|
40 |
+
if not e:
|
41 |
+
raise LookupError("Document not found.")
|
42 |
+
kb_id = doc.kb_id
|
|
|
|
|
43 |
|
44 |
+
chunks = [d["content_with_weight"] for d in
|
45 |
+
settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, [kb_id], max_count=6,
|
46 |
+
fields=["content_with_weight"])]
|
47 |
+
chunks = [("x", c) for c in chunks]
|
48 |
|
49 |
+
RedisDistributedLock.clean_lock(kb_id)
|
|
|
50 |
|
51 |
+
_, tenant = TenantService.get_by_id(args.tenant_id)
|
52 |
+
llm_bdl = LLMBundle(args.tenant_id, LLMType.CHAT, tenant.llm_id)
|
53 |
+
_, kb = KnowledgebaseService.get_by_id(kb_id)
|
54 |
+
embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)
|
55 |
|
56 |
+
dealer = Dealer(GraphExtractor, args.tenant_id, kb_id, llm_bdl, chunks, "English", embed_bdl=embed_bdl)
|
|
|
57 |
|
58 |
+
print(json.dumps(nx.node_link_data(dealer.graph), ensure_ascii=False, indent=2))
|
|
|
|
|
|
graphrag/query_analyze_prompt.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Licensed under the MIT License
|
2 |
+
"""
|
3 |
+
Reference:
|
4 |
+
- [LightRag](https://github.com/HKUDS/LightRAG)
|
5 |
+
- [MiniRAG](https://github.com/HKUDS/MiniRAG)
|
6 |
+
"""
|
7 |
+
PROMPTS = {}
|
8 |
+
|
9 |
+
PROMPTS["minirag_query2kwd"] = """---Role---
|
10 |
+
|
11 |
+
You are a helpful assistant tasked with identifying both answer-type and low-level keywords in the user's query.
|
12 |
+
|
13 |
+
---Goal---
|
14 |
+
|
15 |
+
Given the query, list both answer-type and low-level keywords.
|
16 |
+
answer_type_keywords focus on the type of the answer to the certain query, while low-level keywords focus on specific entities, details, or concrete terms.
|
17 |
+
The answer_type_keywords must be selected from Answer type pool.
|
18 |
+
This pool is in the form of a dictionary, where the key represents the Type you should choose from and the value represents the example samples.
|
19 |
+
|
20 |
+
---Instructions---
|
21 |
+
|
22 |
+
- Output the keywords in JSON format.
|
23 |
+
- The JSON should have three keys:
|
24 |
+
- "answer_type_keywords" for the types of the answer. In this list, the types with the highest likelihood should be placed at the forefront. No more than 3.
|
25 |
+
- "entities_from_query" for specific entities or details. It must be extracted from the query.
|
26 |
+
######################
|
27 |
+
-Examples-
|
28 |
+
######################
|
29 |
+
Example 1:
|
30 |
+
|
31 |
+
Query: "How does international trade influence global economic stability?"
|
32 |
+
Answer type pool: {{
|
33 |
+
'PERSONAL LIFE': ['FAMILY TIME', 'HOME MAINTENANCE'],
|
34 |
+
'STRATEGY': ['MARKETING PLAN', 'BUSINESS EXPANSION'],
|
35 |
+
'SERVICE FACILITATION': ['ONLINE SUPPORT', 'CUSTOMER SERVICE TRAINING'],
|
36 |
+
'PERSON': ['JANE DOE', 'JOHN SMITH'],
|
37 |
+
'FOOD': ['PASTA', 'SUSHI'],
|
38 |
+
'EMOTION': ['HAPPINESS', 'ANGER'],
|
39 |
+
'PERSONAL EXPERIENCE': ['TRAVEL ABROAD', 'STUDYING ABROAD'],
|
40 |
+
'INTERACTION': ['TEAM MEETING', 'NETWORKING EVENT'],
|
41 |
+
'BEVERAGE': ['COFFEE', 'TEA'],
|
42 |
+
'PLAN': ['ANNUAL BUDGET', 'PROJECT TIMELINE'],
|
43 |
+
'GEO': ['NEW YORK CITY', 'SOUTH AFRICA'],
|
44 |
+
'GEAR': ['CAMPING TENT', 'CYCLING HELMET'],
|
45 |
+
'EMOJI': ['🎉', '🚀'],
|
46 |
+
'BEHAVIOR': ['POSITIVE FEEDBACK', 'NEGATIVE CRITICISM'],
|
47 |
+
'TONE': ['FORMAL', 'INFORMAL'],
|
48 |
+
'LOCATION': ['DOWNTOWN', 'SUBURBS']
|
49 |
+
}}
|
50 |
+
################
|
51 |
+
Output:
|
52 |
+
{{
|
53 |
+
"answer_type_keywords": ["STRATEGY","PERSONAL LIFE"],
|
54 |
+
"entities_from_query": ["Trade agreements", "Tariffs", "Currency exchange", "Imports", "Exports"]
|
55 |
+
}}
|
56 |
+
#############################
|
57 |
+
Example 2:
|
58 |
+
|
59 |
+
Query: "When was SpaceX's first rocket launch?"
|
60 |
+
Answer type pool: {{
|
61 |
+
'DATE AND TIME': ['2023-10-10 10:00', 'THIS AFTERNOON'],
|
62 |
+
'ORGANIZATION': ['GLOBAL INITIATIVES CORPORATION', 'LOCAL COMMUNITY CENTER'],
|
63 |
+
'PERSONAL LIFE': ['DAILY EXERCISE ROUTINE', 'FAMILY VACATION PLANNING'],
|
64 |
+
'STRATEGY': ['NEW PRODUCT LAUNCH', 'YEAR-END SALES BOOST'],
|
65 |
+
'SERVICE FACILITATION': ['REMOTE IT SUPPORT', 'ON-SITE TRAINING SESSIONS'],
|
66 |
+
'PERSON': ['ALEXANDER HAMILTON', 'MARIA CURIE'],
|
67 |
+
'FOOD': ['GRILLED SALMON', 'VEGETARIAN BURRITO'],
|
68 |
+
'EMOTION': ['EXCITEMENT', 'DISAPPOINTMENT'],
|
69 |
+
'PERSONAL EXPERIENCE': ['BIRTHDAY CELEBRATION', 'FIRST MARATHON'],
|
70 |
+
'INTERACTION': ['OFFICE WATER COOLER CHAT', 'ONLINE FORUM DEBATE'],
|
71 |
+
'BEVERAGE': ['ICED COFFEE', 'GREEN SMOOTHIE'],
|
72 |
+
'PLAN': ['WEEKLY MEETING SCHEDULE', 'MONTHLY BUDGET OVERVIEW'],
|
73 |
+
'GEO': ['MOUNT EVEREST BASE CAMP', 'THE GREAT BARRIER REEF'],
|
74 |
+
'GEAR': ['PROFESSIONAL CAMERA EQUIPMENT', 'OUTDOOR HIKING GEAR'],
|
75 |
+
'EMOJI': ['📅', '⏰'],
|
76 |
+
'BEHAVIOR': ['PUNCTUALITY', 'HONESTY'],
|
77 |
+
'TONE': ['CONFIDENTIAL', 'SATIRICAL'],
|
78 |
+
'LOCATION': ['CENTRAL PARK', 'DOWNTOWN LIBRARY']
|
79 |
+
}}
|
80 |
+
|
81 |
+
################
|
82 |
+
Output:
|
83 |
+
{{
|
84 |
+
"answer_type_keywords": ["DATE AND TIME", "ORGANIZATION", "PLAN"],
|
85 |
+
"entities_from_query": ["SpaceX", "Rocket launch", "Aerospace", "Power Recovery"]
|
86 |
+
|
87 |
+
}}
|
88 |
+
#############################
|
89 |
+
Example 3:
|
90 |
+
|
91 |
+
Query: "What is the role of education in reducing poverty?"
|
92 |
+
Answer type pool: {{
|
93 |
+
'PERSONAL LIFE': ['MANAGING WORK-LIFE BALANCE', 'HOME IMPROVEMENT PROJECTS'],
|
94 |
+
'STRATEGY': ['MARKETING STRATEGIES FOR Q4', 'EXPANDING INTO NEW MARKETS'],
|
95 |
+
'SERVICE FACILITATION': ['CUSTOMER SATISFACTION SURVEYS', 'STAFF RETENTION PROGRAMS'],
|
96 |
+
'PERSON': ['ALBERT EINSTEIN', 'MARIA CALLAS'],
|
97 |
+
'FOOD': ['PAN-FRIED STEAK', 'POACHED EGGS'],
|
98 |
+
'EMOTION': ['OVERWHELM', 'CONTENTMENT'],
|
99 |
+
'PERSONAL EXPERIENCE': ['LIVING ABROAD', 'STARTING A NEW JOB'],
|
100 |
+
'INTERACTION': ['SOCIAL MEDIA ENGAGEMENT', 'PUBLIC SPEAKING'],
|
101 |
+
'BEVERAGE': ['CAPPUCCINO', 'MATCHA LATTE'],
|
102 |
+
'PLAN': ['ANNUAL FITNESS GOALS', 'QUARTERLY BUSINESS REVIEW'],
|
103 |
+
'GEO': ['THE AMAZON RAINFOREST', 'THE GRAND CANYON'],
|
104 |
+
'GEAR': ['SURFING ESSENTIALS', 'CYCLING ACCESSORIES'],
|
105 |
+
'EMOJI': ['💻', '📱'],
|
106 |
+
'BEHAVIOR': ['TEAMWORK', 'LEADERSHIP'],
|
107 |
+
'TONE': ['FORMAL MEETING', 'CASUAL CONVERSATION'],
|
108 |
+
'LOCATION': ['URBAN CITY CENTER', 'RURAL COUNTRYSIDE']
|
109 |
+
}}
|
110 |
+
|
111 |
+
################
|
112 |
+
Output:
|
113 |
+
{{
|
114 |
+
"answer_type_keywords": ["STRATEGY", "PERSON"],
|
115 |
+
"entities_from_query": ["School access", "Literacy rates", "Job training", "Income inequality"]
|
116 |
+
}}
|
117 |
+
#############################
|
118 |
+
Example 4:
|
119 |
+
|
120 |
+
Query: "Where is the capital of the United States?"
|
121 |
+
Answer type pool: {{
|
122 |
+
'ORGANIZATION': ['GREENPEACE', 'RED CROSS'],
|
123 |
+
'PERSONAL LIFE': ['DAILY WORKOUT', 'HOME COOKING'],
|
124 |
+
'STRATEGY': ['FINANCIAL INVESTMENT', 'BUSINESS EXPANSION'],
|
125 |
+
'SERVICE FACILITATION': ['ONLINE SUPPORT', 'CUSTOMER SERVICE TRAINING'],
|
126 |
+
'PERSON': ['ALBERTA SMITH', 'BENJAMIN JONES'],
|
127 |
+
'FOOD': ['PASTA CARBONARA', 'SUSHI PLATTER'],
|
128 |
+
'EMOTION': ['HAPPINESS', 'SADNESS'],
|
129 |
+
'PERSONAL EXPERIENCE': ['TRAVEL ADVENTURE', 'BOOK CLUB'],
|
130 |
+
'INTERACTION': ['TEAM BUILDING', 'NETWORKING MEETUP'],
|
131 |
+
'BEVERAGE': ['LATTE', 'GREEN TEA'],
|
132 |
+
'PLAN': ['WEIGHT LOSS', 'CAREER DEVELOPMENT'],
|
133 |
+
'GEO': ['PARIS', 'NEW YORK'],
|
134 |
+
'GEAR': ['CAMERA', 'HEADPHONES'],
|
135 |
+
'EMOJI': ['🏢', '🌍'],
|
136 |
+
'BEHAVIOR': ['POSITIVE THINKING', 'STRESS MANAGEMENT'],
|
137 |
+
'TONE': ['FRIENDLY', 'PROFESSIONAL'],
|
138 |
+
'LOCATION': ['DOWNTOWN', 'SUBURBS']
|
139 |
+
}}
|
140 |
+
################
|
141 |
+
Output:
|
142 |
+
{{
|
143 |
+
"answer_type_keywords": ["LOCATION"],
|
144 |
+
"entities_from_query": ["capital of the United States", "Washington", "New York"]
|
145 |
+
}}
|
146 |
+
#############################
|
147 |
+
|
148 |
+
-Real Data-
|
149 |
+
######################
|
150 |
+
Query: {query}
|
151 |
+
Answer type pool:{TYPE_POOL}
|
152 |
+
######################
|
153 |
+
Output:
|
154 |
+
|
155 |
+
"""
|
156 |
+
|
157 |
+
PROMPTS["keywords_extraction"] = """---Role---
|
158 |
+
|
159 |
+
You are a helpful assistant tasked with identifying both high-level and low-level keywords in the user's query.
|
160 |
+
|
161 |
+
---Goal---
|
162 |
+
|
163 |
+
Given the query, list both high-level and low-level keywords. High-level keywords focus on overarching concepts or themes, while low-level keywords focus on specific entities, details, or concrete terms.
|
164 |
+
|
165 |
+
---Instructions---
|
166 |
+
|
167 |
+
- Output the keywords in JSON format.
|
168 |
+
- The JSON should have two keys:
|
169 |
+
- "high_level_keywords" for overarching concepts or themes.
|
170 |
+
- "low_level_keywords" for specific entities or details.
|
171 |
+
|
172 |
+
######################
|
173 |
+
-Examples-
|
174 |
+
######################
|
175 |
+
{examples}
|
176 |
+
|
177 |
+
#############################
|
178 |
+
-Real Data-
|
179 |
+
######################
|
180 |
+
Query: {query}
|
181 |
+
######################
|
182 |
+
The `Output` should be human text, not unicode characters. Keep the same language as `Query`.
|
183 |
+
Output:
|
184 |
+
|
185 |
+
"""
|
186 |
+
|
187 |
+
PROMPTS["keywords_extraction_examples"] = [
|
188 |
+
"""Example 1:
|
189 |
+
|
190 |
+
Query: "How does international trade influence global economic stability?"
|
191 |
+
################
|
192 |
+
Output:
|
193 |
+
{
|
194 |
+
"high_level_keywords": ["International trade", "Global economic stability", "Economic impact"],
|
195 |
+
"low_level_keywords": ["Trade agreements", "Tariffs", "Currency exchange", "Imports", "Exports"]
|
196 |
+
}
|
197 |
+
#############################""",
|
198 |
+
"""Example 2:
|
199 |
+
|
200 |
+
Query: "What are the environmental consequences of deforestation on biodiversity?"
|
201 |
+
################
|
202 |
+
Output:
|
203 |
+
{
|
204 |
+
"high_level_keywords": ["Environmental consequences", "Deforestation", "Biodiversity loss"],
|
205 |
+
"low_level_keywords": ["Species extinction", "Habitat destruction", "Carbon emissions", "Rainforest", "Ecosystem"]
|
206 |
+
}
|
207 |
+
#############################""",
|
208 |
+
"""Example 3:
|
209 |
+
|
210 |
+
Query: "What is the role of education in reducing poverty?"
|
211 |
+
################
|
212 |
+
Output:
|
213 |
+
{
|
214 |
+
"high_level_keywords": ["Education", "Poverty reduction", "Socioeconomic development"],
|
215 |
+
"low_level_keywords": ["School access", "Literacy rates", "Job training", "Income inequality"]
|
216 |
+
}
|
217 |
+
#############################""",
|
218 |
+
]
|
graphrag/search.py
CHANGED
@@ -14,90 +14,313 @@
|
|
14 |
# limitations under the License.
|
15 |
#
|
16 |
import json
|
|
|
|
|
17 |
from copy import deepcopy
|
18 |
-
|
19 |
import pandas as pd
|
20 |
-
from rag.utils.doc_store_conn import OrderByExpr, FusionExpr
|
21 |
|
22 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
|
25 |
class KGSearch(Dealer):
|
26 |
-
def
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
break
|
47 |
-
first_source["content_with_weight"] = content_with_weight
|
48 |
-
first_id = next(iter(sres))
|
49 |
-
return {first_id: first_source}
|
50 |
-
|
51 |
-
qst = req.get("question", "")
|
52 |
-
matchText, keywords = self.qryr.question(qst, min_match=0.05)
|
53 |
-
condition = self.get_filters(req)
|
54 |
-
|
55 |
-
## Entity retrieval
|
56 |
-
condition.update({"knowledge_graph_kwd": ["entity"]})
|
57 |
-
assert emb_mdl, "No embedding model selected"
|
58 |
-
matchDense = self.get_vector(qst, emb_mdl, 1024, req.get("similarity", 0.1))
|
59 |
-
q_vec = matchDense.embedding_data
|
60 |
-
src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
|
61 |
-
"doc_id", f"q_{len(q_vec)}_vec", "position_int", "name_kwd",
|
62 |
-
"available_int", "content_with_weight",
|
63 |
-
"weight_int", "weight_flt"
|
64 |
-
])
|
65 |
-
|
66 |
-
fusionExpr = FusionExpr("weighted_sum", 32, {"weights": "0.5, 0.5"})
|
67 |
-
|
68 |
-
ent_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 32, idxnm, kb_ids)
|
69 |
-
ent_res_fields = self.dataStore.getFields(ent_res, src)
|
70 |
-
entities = [d["name_kwd"] for d in ent_res_fields.values() if d.get("name_kwd")]
|
71 |
-
ent_ids = self.dataStore.getChunkIds(ent_res)
|
72 |
-
ent_content = merge_into_first(ent_res_fields, "-Entities-")
|
73 |
-
if ent_content:
|
74 |
-
ent_ids = list(ent_content.keys())
|
75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
## Community retrieval
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
if
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
# limitations under the License.
|
15 |
#
|
16 |
import json
|
17 |
+
import logging
|
18 |
+
from collections import defaultdict
|
19 |
from copy import deepcopy
|
20 |
+
import json_repair
|
21 |
import pandas as pd
|
|
|
22 |
|
23 |
+
from api.utils import get_uuid
|
24 |
+
from graphrag.query_analyze_prompt import PROMPTS
|
25 |
+
from graphrag.utils import get_entity_type2sampels, get_llm_cache, set_llm_cache, get_relation
|
26 |
+
from rag.utils import num_tokens_from_string
|
27 |
+
from rag.utils.doc_store_conn import OrderByExpr
|
28 |
+
|
29 |
+
from rag.nlp.search import Dealer, index_name
|
30 |
|
31 |
|
32 |
class KGSearch(Dealer):
|
33 |
+
def _chat(self, llm_bdl, system, history, gen_conf):
|
34 |
+
response = get_llm_cache(llm_bdl.llm_name, system, history, gen_conf)
|
35 |
+
if response:
|
36 |
+
return response
|
37 |
+
response = llm_bdl.chat(system, history, gen_conf)
|
38 |
+
if response.find("**ERROR**") >= 0:
|
39 |
+
raise Exception(response)
|
40 |
+
set_llm_cache(llm_bdl.llm_name, system, response, history, gen_conf)
|
41 |
+
return response
|
42 |
+
|
43 |
+
def query_rewrite(self, llm, question, idxnms, kb_ids):
|
44 |
+
ty2ents = get_entity_type2sampels(idxnms, kb_ids)
|
45 |
+
hint_prompt = PROMPTS["minirag_query2kwd"].format(query=question,
|
46 |
+
TYPE_POOL=json.dumps(ty2ents, ensure_ascii=False, indent=2))
|
47 |
+
result = self._chat(llm, hint_prompt, [{"role": "user", "content": "Output:"}], {"temperature": .5})
|
48 |
+
try:
|
49 |
+
keywords_data = json_repair.loads(result)
|
50 |
+
type_keywords = keywords_data.get("answer_type_keywords", [])
|
51 |
+
entities_from_query = keywords_data.get("entities_from_query", [])[:5]
|
52 |
+
return type_keywords, entities_from_query
|
53 |
+
except json_repair.JSONDecodeError:
|
54 |
+
try:
|
55 |
+
result = result.replace(hint_prompt[:-1], '').replace('user', '').replace('model', '').strip()
|
56 |
+
result = '{' + result.split('{')[1].split('}')[0] + '}'
|
57 |
+
keywords_data = json_repair.loads(result)
|
58 |
+
type_keywords = keywords_data.get("answer_type_keywords", [])
|
59 |
+
entities_from_query = keywords_data.get("entities_from_query", [])[:5]
|
60 |
+
return type_keywords, entities_from_query
|
61 |
+
# Handle parsing error
|
62 |
+
except Exception as e:
|
63 |
+
logging.exception(f"JSON parsing error: {result} -> {e}")
|
64 |
+
raise e
|
65 |
+
|
66 |
+
def _ent_info_from_(self, es_res, sim_thr=0.3):
|
67 |
+
res = {}
|
68 |
+
es_res = self.dataStore.getFields(es_res, ["content_with_weight", "_score", "entity_kwd", "rank_flt",
|
69 |
+
"n_hop_with_weight"])
|
70 |
+
for _, ent in es_res.items():
|
71 |
+
if float(ent.get("_score", 0)) < sim_thr:
|
72 |
+
continue
|
73 |
+
if isinstance(ent["entity_kwd"], list):
|
74 |
+
ent["entity_kwd"] = ent["entity_kwd"][0]
|
75 |
+
res[ent["entity_kwd"]] = {
|
76 |
+
"sim": float(ent.get("_score", 0)),
|
77 |
+
"pagerank": float(ent.get("rank_flt", 0)),
|
78 |
+
"n_hop_ents": json.loads(ent.get("n_hop_with_weight", "[]")),
|
79 |
+
"description": ent.get("content_with_weight", "{}")
|
80 |
+
}
|
81 |
+
return res
|
82 |
+
|
83 |
+
def _relation_info_from_(self, es_res, sim_thr=0.3):
|
84 |
+
res = {}
|
85 |
+
es_res = self.dataStore.getFields(es_res, ["content_with_weight", "_score", "from_entity_kwd", "to_entity_kwd",
|
86 |
+
"weight_int"])
|
87 |
+
for _, ent in es_res.items():
|
88 |
+
if float(ent["_score"]) < sim_thr:
|
89 |
+
continue
|
90 |
+
f, t = sorted([ent["from_entity_kwd"], ent["to_entity_kwd"]])
|
91 |
+
if isinstance(f, list):
|
92 |
+
f = f[0]
|
93 |
+
if isinstance(t, list):
|
94 |
+
t = t[0]
|
95 |
+
res[(f, t)] = {
|
96 |
+
"sim": float(ent["_score"]),
|
97 |
+
"pagerank": float(ent.get("weight_int", 0)),
|
98 |
+
"description": ent["content_with_weight"]
|
99 |
+
}
|
100 |
+
return res
|
101 |
+
|
102 |
+
def get_relevant_ents_by_keywords(self, keywords, filters, idxnms, kb_ids, emb_mdl, sim_thr=0.3, N=56):
|
103 |
+
if not keywords:
|
104 |
+
return {}
|
105 |
+
filters = deepcopy(filters)
|
106 |
+
filters["knowledge_graph_kwd"] = "entity"
|
107 |
+
matchDense = self.get_vector(", ".join(keywords), emb_mdl, 1024, sim_thr)
|
108 |
+
es_res = self.dataStore.search(["content_with_weight", "entity_kwd", "rank_flt"], [], filters, [matchDense],
|
109 |
+
OrderByExpr(), 0, N,
|
110 |
+
idxnms, kb_ids)
|
111 |
+
return self._ent_info_from_(es_res, sim_thr)
|
112 |
+
|
113 |
+
def get_relevant_relations_by_txt(self, txt, filters, idxnms, kb_ids, emb_mdl, sim_thr=0.3, N=56):
|
114 |
+
if not txt:
|
115 |
+
return {}
|
116 |
+
filters = deepcopy(filters)
|
117 |
+
filters["knowledge_graph_kwd"] = "relation"
|
118 |
+
matchDense = self.get_vector(txt, emb_mdl, 1024, sim_thr)
|
119 |
+
es_res = self.dataStore.search(
|
120 |
+
["content_with_weight", "_score", "from_entity_kwd", "to_entity_kwd", "weight_int"],
|
121 |
+
[], filters, [matchDense], OrderByExpr(), 0, N, idxnms, kb_ids)
|
122 |
+
return self._relation_info_from_(es_res, sim_thr)
|
123 |
+
|
124 |
+
def get_relevant_ents_by_types(self, types, filters, idxnms, kb_ids, N=56):
|
125 |
+
if not types:
|
126 |
+
return {}
|
127 |
+
filters = deepcopy(filters)
|
128 |
+
filters["knowledge_graph_kwd"] = "entity"
|
129 |
+
filters["entity_type_kwd"] = types
|
130 |
+
ordr = OrderByExpr()
|
131 |
+
ordr.desc("rank_flt")
|
132 |
+
es_res = self.dataStore.search(["entity_kwd", "rank_flt"], [], filters, [], ordr, 0, N,
|
133 |
+
idxnms, kb_ids)
|
134 |
+
return self._ent_info_from_(es_res, 0)
|
135 |
+
|
136 |
+
def retrieval(self, question: str,
|
137 |
+
tenant_ids: str | list[str],
|
138 |
+
kb_ids: list[str],
|
139 |
+
emb_mdl,
|
140 |
+
llm,
|
141 |
+
max_token: int = 8196,
|
142 |
+
ent_topn: int = 6,
|
143 |
+
rel_topn: int = 6,
|
144 |
+
comm_topn: int = 1,
|
145 |
+
ent_sim_threshold: float = 0.3,
|
146 |
+
rel_sim_threshold: float = 0.3,
|
147 |
+
):
|
148 |
+
qst = question
|
149 |
+
filters = self.get_filters({"kb_ids": kb_ids})
|
150 |
+
if isinstance(tenant_ids, str):
|
151 |
+
tenant_ids = tenant_ids.split(",")
|
152 |
+
idxnms = [index_name(tid) for tid in tenant_ids]
|
153 |
+
ty_kwds = []
|
154 |
+
ents = []
|
155 |
+
try:
|
156 |
+
ty_kwds, ents = self.query_rewrite(llm, qst, [index_name(tid) for tid in tenant_ids], kb_ids)
|
157 |
+
logging.info(f"Q: {qst}, Types: {ty_kwds}, Entities: {ents}")
|
158 |
+
except Exception as e:
|
159 |
+
logging.exception(e)
|
160 |
+
ents = [qst]
|
161 |
+
pass
|
162 |
+
|
163 |
+
ents_from_query = self.get_relevant_ents_by_keywords(ents, filters, idxnms, kb_ids, emb_mdl, ent_sim_threshold)
|
164 |
+
ents_from_types = self.get_relevant_ents_by_types(ty_kwds, filters, idxnms, kb_ids, 10000)
|
165 |
+
rels_from_txt = self.get_relevant_relations_by_txt(qst, filters, idxnms, kb_ids, emb_mdl, rel_sim_threshold)
|
166 |
+
nhop_pathes = defaultdict(dict)
|
167 |
+
for _, ent in ents_from_query.items():
|
168 |
+
nhops = ent.get("n_hop_ents", [])
|
169 |
+
for nbr in nhops:
|
170 |
+
path = nbr["path"]
|
171 |
+
wts = nbr["weights"]
|
172 |
+
for i in range(len(path) - 1):
|
173 |
+
f, t = path[i], path[i + 1]
|
174 |
+
if (f, t) in nhop_pathes:
|
175 |
+
nhop_pathes[(f, t)]["sim"] += ent["sim"] / (2 + i)
|
176 |
+
else:
|
177 |
+
nhop_pathes[(f, t)]["sim"] = ent["sim"] / (2 + i)
|
178 |
+
nhop_pathes[(f, t)]["pagerank"] = wts[i]
|
179 |
+
|
180 |
+
logging.info("Retrieved entities: {}".format(list(ents_from_query.keys())))
|
181 |
+
logging.info("Retrieved relations: {}".format(list(rels_from_txt.keys())))
|
182 |
+
logging.info("Retrieved entities from types({}): {}".format(ty_kwds, list(ents_from_types.keys())))
|
183 |
+
logging.info("Retrieved N-hops: {}".format(list(nhop_pathes.keys())))
|
184 |
+
|
185 |
+
# P(E|Q) => P(E) * P(Q|E) => pagerank * sim
|
186 |
+
for ent in ents_from_types.keys():
|
187 |
+
if ent not in ents_from_query:
|
188 |
+
continue
|
189 |
+
ents_from_query[ent]["sim"] *= 2
|
190 |
+
|
191 |
+
for (f, t) in rels_from_txt.keys():
|
192 |
+
pair = tuple(sorted([f, t]))
|
193 |
+
s = 0
|
194 |
+
if pair in nhop_pathes:
|
195 |
+
s += nhop_pathes[pair]["sim"]
|
196 |
+
del nhop_pathes[pair]
|
197 |
+
if f in ents_from_types:
|
198 |
+
s += 1
|
199 |
+
if t in ents_from_types:
|
200 |
+
s += 1
|
201 |
+
rels_from_txt[(f, t)]["sim"] *= s + 1
|
202 |
+
|
203 |
+
# This is for the relations from n-hop but not by query search
|
204 |
+
for (f, t) in nhop_pathes.keys():
|
205 |
+
s = 0
|
206 |
+
if f in ents_from_types:
|
207 |
+
s += 1
|
208 |
+
if t in ents_from_types:
|
209 |
+
s += 1
|
210 |
+
rels_from_txt[(f, t)] = {
|
211 |
+
"sim": nhop_pathes[(f, t)]["sim"] * (s + 1),
|
212 |
+
"pagerank": nhop_pathes[(f, t)]["pagerank"]
|
213 |
+
}
|
214 |
+
|
215 |
+
ents_from_query = sorted(ents_from_query.items(), key=lambda x: x[1]["sim"] * x[1]["pagerank"], reverse=True)[
|
216 |
+
:ent_topn]
|
217 |
+
rels_from_txt = sorted(rels_from_txt.items(), key=lambda x: x[1]["sim"] * x[1]["pagerank"], reverse=True)[
|
218 |
+
:rel_topn]
|
219 |
+
|
220 |
+
ents = []
|
221 |
+
relas = []
|
222 |
+
for n, ent in ents_from_query:
|
223 |
+
ents.append({
|
224 |
+
"Entity": n,
|
225 |
+
"Score": "%.2f" % (ent["sim"] * ent["pagerank"]),
|
226 |
+
"Description": json.loads(ent["description"]).get("description", "")
|
227 |
+
})
|
228 |
+
max_token -= num_tokens_from_string(str(ents[-1]))
|
229 |
+
if max_token <= 0:
|
230 |
+
ents = ents[:-1]
|
231 |
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
232 |
|
233 |
+
for (f, t), rel in rels_from_txt:
|
234 |
+
if not rel.get("description"):
|
235 |
+
for tid in tenant_ids:
|
236 |
+
rela = get_relation(tid, kb_ids, f, t)
|
237 |
+
if rela:
|
238 |
+
break
|
239 |
+
else:
|
240 |
+
continue
|
241 |
+
rel["description"] = rela["description"]
|
242 |
+
relas.append({
|
243 |
+
"From Entity": f,
|
244 |
+
"To Entity": t,
|
245 |
+
"Score": "%.2f" % (rel["sim"] * rel["pagerank"]),
|
246 |
+
"Description": json.loads(ent["description"]).get("description", "")
|
247 |
+
})
|
248 |
+
max_token -= num_tokens_from_string(str(relas[-1]))
|
249 |
+
if max_token <= 0:
|
250 |
+
relas = relas[:-1]
|
251 |
+
break
|
252 |
+
|
253 |
+
if ents:
|
254 |
+
ents = "\n-Entities-\n{}".format(pd.DataFrame(ents).to_csv())
|
255 |
+
else:
|
256 |
+
ents = ""
|
257 |
+
if relas:
|
258 |
+
relas = "\n-Relations-\n{}".format(pd.DataFrame(relas).to_csv())
|
259 |
+
else:
|
260 |
+
relas = ""
|
261 |
+
|
262 |
+
return {
|
263 |
+
"chunk_id": get_uuid(),
|
264 |
+
"content_ltks": "",
|
265 |
+
"content_with_weight": ents + relas + self._community_retrival_([n for n, _ in ents_from_query], filters, kb_ids, idxnms,
|
266 |
+
comm_topn, max_token),
|
267 |
+
"doc_id": "",
|
268 |
+
"docnm_kwd": "Related content in Knowledge Graph",
|
269 |
+
"kb_id": kb_ids,
|
270 |
+
"important_kwd": [],
|
271 |
+
"image_id": "",
|
272 |
+
"similarity": 1.,
|
273 |
+
"vector_similarity": 1.,
|
274 |
+
"term_similarity": 0,
|
275 |
+
"vector": [],
|
276 |
+
"positions": [],
|
277 |
+
}
|
278 |
+
|
279 |
+
def _community_retrival_(self, entities, condition, kb_ids, idxnms, topn, max_token):
|
280 |
## Community retrieval
|
281 |
+
fields = ["docnm_kwd", "content_with_weight"]
|
282 |
+
odr = OrderByExpr()
|
283 |
+
odr.desc("weight_flt")
|
284 |
+
fltr = deepcopy(condition)
|
285 |
+
fltr["knowledge_graph_kwd"] = "community_report"
|
286 |
+
fltr["entities_kwd"] = entities
|
287 |
+
comm_res = self.dataStore.search(fields, [], fltr, [],
|
288 |
+
OrderByExpr(), 0, topn, idxnms, kb_ids)
|
289 |
+
comm_res_fields = self.dataStore.getFields(comm_res, fields)
|
290 |
+
txts = []
|
291 |
+
for ii, (_, row) in enumerate(comm_res_fields.items()):
|
292 |
+
obj = json.loads(row["content_with_weight"])
|
293 |
+
txts.append("# {}. {}\n## Content\n{}\n## Evidences\n{}\n".format(
|
294 |
+
ii + 1, row["docnm_kwd"], obj["report"], obj["evidences"]))
|
295 |
+
max_token -= num_tokens_from_string(str(txts[-1]))
|
296 |
+
|
297 |
+
if not txts:
|
298 |
+
return ""
|
299 |
+
return "\n-Community Report-\n" + "\n".join(txts)
|
300 |
+
|
301 |
+
|
302 |
+
if __name__ == "__main__":
|
303 |
+
from api import settings
|
304 |
+
import argparse
|
305 |
+
from api.db import LLMType
|
306 |
+
from api.db.services.knowledgebase_service import KnowledgebaseService
|
307 |
+
from api.db.services.llm_service import LLMBundle
|
308 |
+
from api.db.services.user_service import TenantService
|
309 |
+
from rag.nlp import search
|
310 |
+
|
311 |
+
settings.init_settings()
|
312 |
+
parser = argparse.ArgumentParser()
|
313 |
+
parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
|
314 |
+
parser.add_argument('-d', '--kb_id', default=False, help="Knowledge base ID", action='store', required=True)
|
315 |
+
parser.add_argument('-q', '--question', default=False, help="Question", action='store', required=True)
|
316 |
+
args = parser.parse_args()
|
317 |
+
|
318 |
+
kb_id = args.kb_id
|
319 |
+
_, tenant = TenantService.get_by_id(args.tenant_id)
|
320 |
+
llm_bdl = LLMBundle(args.tenant_id, LLMType.CHAT, tenant.llm_id)
|
321 |
+
_, kb = KnowledgebaseService.get_by_id(kb_id)
|
322 |
+
embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)
|
323 |
+
|
324 |
+
kg = KGSearch(settings.docStoreConn)
|
325 |
+
print(kg.retrieval({"question": args.question, "kb_ids": [kb_id]},
|
326 |
+
search.index_name(kb.tenant_id), [kb_id], embed_bdl, llm_bdl))
|
graphrag/utils.py
CHANGED
@@ -3,16 +3,26 @@
|
|
3 |
"""
|
4 |
Reference:
|
5 |
- [graphrag](https://github.com/microsoft/graphrag)
|
|
|
6 |
"""
|
7 |
|
8 |
import html
|
9 |
import json
|
|
|
10 |
import re
|
|
|
|
|
|
|
|
|
11 |
from typing import Any, Callable
|
12 |
|
|
|
13 |
import numpy as np
|
14 |
import xxhash
|
|
|
15 |
|
|
|
|
|
16 |
from rag.utils.redis_conn import REDIS_CONN
|
17 |
|
18 |
ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
|
@@ -131,3 +141,379 @@ def set_tags_to_cache(kb_ids, tags):
|
|
131 |
|
132 |
k = hasher.hexdigest()
|
133 |
REDIS_CONN.set(k, json.dumps(tags).encode("utf-8"), 600)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
"""
|
4 |
Reference:
|
5 |
- [graphrag](https://github.com/microsoft/graphrag)
|
6 |
+
- [LightRag](https://github.com/HKUDS/LightRAG)
|
7 |
"""
|
8 |
|
9 |
import html
|
10 |
import json
|
11 |
+
import logging
|
12 |
import re
|
13 |
+
import time
|
14 |
+
from collections import defaultdict
|
15 |
+
from copy import deepcopy
|
16 |
+
from hashlib import md5
|
17 |
from typing import Any, Callable
|
18 |
|
19 |
+
import networkx as nx
|
20 |
import numpy as np
|
21 |
import xxhash
|
22 |
+
from networkx.readwrite import json_graph
|
23 |
|
24 |
+
from api import settings
|
25 |
+
from rag.nlp import search, rag_tokenizer
|
26 |
from rag.utils.redis_conn import REDIS_CONN
|
27 |
|
28 |
ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
|
|
|
141 |
|
142 |
k = hasher.hexdigest()
|
143 |
REDIS_CONN.set(k, json.dumps(tags).encode("utf-8"), 600)
|
144 |
+
|
145 |
+
|
146 |
+
def graph_merge(g1, g2):
|
147 |
+
g = g2.copy()
|
148 |
+
for n, attr in g1.nodes(data=True):
|
149 |
+
if n not in g2.nodes():
|
150 |
+
g.add_node(n, **attr)
|
151 |
+
continue
|
152 |
+
|
153 |
+
for source, target, attr in g1.edges(data=True):
|
154 |
+
if g.has_edge(source, target):
|
155 |
+
g[source][target].update({"weight": attr.get("weight", 0)+1})
|
156 |
+
continue
|
157 |
+
g.add_edge(source, target)#, **attr)
|
158 |
+
|
159 |
+
for node_degree in g.degree:
|
160 |
+
g.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
|
161 |
+
return g
|
162 |
+
|
163 |
+
|
164 |
+
def compute_args_hash(*args):
|
165 |
+
return md5(str(args).encode()).hexdigest()
|
166 |
+
|
167 |
+
|
168 |
+
def handle_single_entity_extraction(
|
169 |
+
record_attributes: list[str],
|
170 |
+
chunk_key: str,
|
171 |
+
):
|
172 |
+
if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
|
173 |
+
return None
|
174 |
+
# add this record as a node in the G
|
175 |
+
entity_name = clean_str(record_attributes[1].upper())
|
176 |
+
if not entity_name.strip():
|
177 |
+
return None
|
178 |
+
entity_type = clean_str(record_attributes[2].upper())
|
179 |
+
entity_description = clean_str(record_attributes[3])
|
180 |
+
entity_source_id = chunk_key
|
181 |
+
return dict(
|
182 |
+
entity_name=entity_name.upper(),
|
183 |
+
entity_type=entity_type.upper(),
|
184 |
+
description=entity_description,
|
185 |
+
source_id=entity_source_id,
|
186 |
+
)
|
187 |
+
|
188 |
+
|
189 |
+
def handle_single_relationship_extraction(record_attributes: list[str], chunk_key: str):
|
190 |
+
if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
|
191 |
+
return None
|
192 |
+
# add this record as edge
|
193 |
+
source = clean_str(record_attributes[1].upper())
|
194 |
+
target = clean_str(record_attributes[2].upper())
|
195 |
+
edge_description = clean_str(record_attributes[3])
|
196 |
+
|
197 |
+
edge_keywords = clean_str(record_attributes[4])
|
198 |
+
edge_source_id = chunk_key
|
199 |
+
weight = (
|
200 |
+
float(record_attributes[-1]) if is_float_regex(record_attributes[-1]) else 1.0
|
201 |
+
)
|
202 |
+
pair = sorted([source.upper(), target.upper()])
|
203 |
+
return dict(
|
204 |
+
src_id=pair[0],
|
205 |
+
tgt_id=pair[1],
|
206 |
+
weight=weight,
|
207 |
+
description=edge_description,
|
208 |
+
keywords=edge_keywords,
|
209 |
+
source_id=edge_source_id,
|
210 |
+
metadata={"created_at": time.time()},
|
211 |
+
)
|
212 |
+
|
213 |
+
|
214 |
+
def pack_user_ass_to_openai_messages(*args: str):
|
215 |
+
roles = ["user", "assistant"]
|
216 |
+
return [
|
217 |
+
{"role": roles[i % 2], "content": content} for i, content in enumerate(args)
|
218 |
+
]
|
219 |
+
|
220 |
+
|
221 |
+
def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
|
222 |
+
"""Split a string by multiple markers"""
|
223 |
+
if not markers:
|
224 |
+
return [content]
|
225 |
+
results = re.split("|".join(re.escape(marker) for marker in markers), content)
|
226 |
+
return [r.strip() for r in results if r.strip()]
|
227 |
+
|
228 |
+
|
229 |
+
def is_float_regex(value):
|
230 |
+
return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
|
231 |
+
|
232 |
+
|
233 |
+
def chunk_id(chunk):
|
234 |
+
return xxhash.xxh64((chunk["content_with_weight"] + chunk["kb_id"]).encode("utf-8")).hexdigest()
|
235 |
+
|
236 |
+
|
237 |
+
def get_entity(tenant_id, kb_id, ent_name):
|
238 |
+
conds = {
|
239 |
+
"fields": ["content_with_weight"],
|
240 |
+
"entity_kwd": ent_name,
|
241 |
+
"size": 10000,
|
242 |
+
"knowledge_graph_kwd": ["entity"]
|
243 |
+
}
|
244 |
+
res = []
|
245 |
+
es_res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id])
|
246 |
+
for id in es_res.ids:
|
247 |
+
try:
|
248 |
+
if isinstance(ent_name, str):
|
249 |
+
return json.loads(es_res.field[id]["content_with_weight"])
|
250 |
+
res.append(json.loads(es_res.field[id]["content_with_weight"]))
|
251 |
+
except Exception:
|
252 |
+
continue
|
253 |
+
|
254 |
+
return res
|
255 |
+
|
256 |
+
|
257 |
+
def set_entity(tenant_id, kb_id, embd_mdl, ent_name, meta):
|
258 |
+
chunk = {
|
259 |
+
"important_kwd": [ent_name],
|
260 |
+
"title_tks": rag_tokenizer.tokenize(ent_name),
|
261 |
+
"entity_kwd": ent_name,
|
262 |
+
"knowledge_graph_kwd": "entity",
|
263 |
+
"entity_type_kwd": meta["entity_type"],
|
264 |
+
"content_with_weight": json.dumps(meta, ensure_ascii=False),
|
265 |
+
"content_ltks": rag_tokenizer.tokenize(meta["description"]),
|
266 |
+
"source_id": list(set(meta["source_id"])),
|
267 |
+
"kb_id": kb_id,
|
268 |
+
"available_int": 0
|
269 |
+
}
|
270 |
+
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
|
271 |
+
res = settings.retrievaler.search({"entity_kwd": ent_name, "size": 1, "fields": []},
|
272 |
+
search.index_name(tenant_id), [kb_id])
|
273 |
+
if res.ids:
|
274 |
+
settings.docStoreConn.update({"entity_kwd": ent_name}, chunk, search.index_name(tenant_id), kb_id)
|
275 |
+
else:
|
276 |
+
ebd = get_embed_cache(embd_mdl.llm_name, ent_name)
|
277 |
+
if ebd is None:
|
278 |
+
try:
|
279 |
+
ebd, _ = embd_mdl.encode([ent_name])
|
280 |
+
ebd = ebd[0]
|
281 |
+
set_embed_cache(embd_mdl.llm_name, ent_name, ebd)
|
282 |
+
except Exception as e:
|
283 |
+
logging.exception(f"Fail to embed entity: {e}")
|
284 |
+
if ebd is not None:
|
285 |
+
chunk["q_%d_vec" % len(ebd)] = ebd
|
286 |
+
settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id))
|
287 |
+
|
288 |
+
|
289 |
+
def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
|
290 |
+
ents = from_ent_name
|
291 |
+
if isinstance(ents, str):
|
292 |
+
ents = [from_ent_name]
|
293 |
+
if isinstance(to_ent_name, str):
|
294 |
+
to_ent_name = [to_ent_name]
|
295 |
+
ents.extend(to_ent_name)
|
296 |
+
ents = list(set(ents))
|
297 |
+
conds = {
|
298 |
+
"fields": ["content_with_weight"],
|
299 |
+
"size": size,
|
300 |
+
"from_entity_kwd": ents,
|
301 |
+
"to_entity_kwd": ents,
|
302 |
+
"knowledge_graph_kwd": ["relation"]
|
303 |
+
}
|
304 |
+
res = []
|
305 |
+
es_res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id)
|
306 |
+
for id in es_res.ids:
|
307 |
+
try:
|
308 |
+
if size == 1:
|
309 |
+
return json.loads(es_res.field[id]["content_with_weight"])
|
310 |
+
res.append(json.loads(es_res.field[id]["content_with_weight"]))
|
311 |
+
except Exception:
|
312 |
+
continue
|
313 |
+
return res
|
314 |
+
|
315 |
+
|
316 |
+
def set_relation(tenant_id, kb_id, embd_mdl, from_ent_name, to_ent_name, meta):
|
317 |
+
chunk = {
|
318 |
+
"from_entity_kwd": from_ent_name,
|
319 |
+
"to_entity_kwd": to_ent_name,
|
320 |
+
"knowledge_graph_kwd": "relation",
|
321 |
+
"content_with_weight": json.dumps(meta, ensure_ascii=False),
|
322 |
+
"content_ltks": rag_tokenizer.tokenize(meta["description"]),
|
323 |
+
"important_kwd": meta["keywords"],
|
324 |
+
"source_id": list(set(meta["source_id"])),
|
325 |
+
"weight_int": int(meta["weight"]),
|
326 |
+
"kb_id": kb_id,
|
327 |
+
"available_int": 0
|
328 |
+
}
|
329 |
+
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
|
330 |
+
res = settings.retrievaler.search({"from_entity_kwd": to_ent_name, "to_entity_kwd": to_ent_name, "size": 1, "fields": []},
|
331 |
+
search.index_name(tenant_id), [kb_id])
|
332 |
+
|
333 |
+
if res.ids:
|
334 |
+
settings.docStoreConn.update({"from_entity_kwd": from_ent_name, "to_entity_kwd": to_ent_name},
|
335 |
+
chunk,
|
336 |
+
search.index_name(tenant_id), kb_id)
|
337 |
+
else:
|
338 |
+
txt = f"{from_ent_name}->{to_ent_name}"
|
339 |
+
ebd = get_embed_cache(embd_mdl.llm_name, txt)
|
340 |
+
if ebd is None:
|
341 |
+
try:
|
342 |
+
ebd, _ = embd_mdl.encode([txt+f": {meta['description']}"])
|
343 |
+
ebd = ebd[0]
|
344 |
+
set_embed_cache(embd_mdl.llm_name, txt, ebd)
|
345 |
+
except Exception as e:
|
346 |
+
logging.exception(f"Fail to embed entity relation: {e}")
|
347 |
+
if ebd is not None:
|
348 |
+
chunk["q_%d_vec" % len(ebd)] = ebd
|
349 |
+
settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id))
|
350 |
+
|
351 |
+
|
352 |
+
def get_graph(tenant_id, kb_id):
|
353 |
+
conds = {
|
354 |
+
"fields": ["content_with_weight", "source_id"],
|
355 |
+
"removed_kwd": "N",
|
356 |
+
"size": 1,
|
357 |
+
"knowledge_graph_kwd": ["graph"]
|
358 |
+
}
|
359 |
+
res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id])
|
360 |
+
for id in res.ids:
|
361 |
+
try:
|
362 |
+
return json_graph.node_link_graph(json.loads(res.field[id]["content_with_weight"]), edges="edges"), \
|
363 |
+
res.field[id]["source_id"]
|
364 |
+
except Exception:
|
365 |
+
continue
|
366 |
+
return None, None
|
367 |
+
|
368 |
+
|
369 |
+
def set_graph(tenant_id, kb_id, graph, docids):
|
370 |
+
chunk = {
|
371 |
+
"content_with_weight": json.dumps(nx.node_link_data(graph, edges="edges"), ensure_ascii=False,
|
372 |
+
indent=2),
|
373 |
+
"knowledge_graph_kwd": "graph",
|
374 |
+
"kb_id": kb_id,
|
375 |
+
"source_id": list(docids),
|
376 |
+
"available_int": 0,
|
377 |
+
"removed_kwd": "N"
|
378 |
+
}
|
379 |
+
res = settings.retrievaler.search({"knowledge_graph_kwd": "graph", "size": 1, "fields": []}, search.index_name(tenant_id), [kb_id])
|
380 |
+
if res.ids:
|
381 |
+
settings.docStoreConn.update({"knowledge_graph_kwd": "graph"}, chunk,
|
382 |
+
search.index_name(tenant_id), kb_id)
|
383 |
+
else:
|
384 |
+
settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id))
|
385 |
+
|
386 |
+
|
387 |
+
def is_continuous_subsequence(subseq, seq):
|
388 |
+
def find_all_indexes(tup, value):
|
389 |
+
indexes = []
|
390 |
+
start = 0
|
391 |
+
while True:
|
392 |
+
try:
|
393 |
+
index = tup.index(value, start)
|
394 |
+
indexes.append(index)
|
395 |
+
start = index + 1
|
396 |
+
except ValueError:
|
397 |
+
break
|
398 |
+
return indexes
|
399 |
+
|
400 |
+
index_list = find_all_indexes(seq,subseq[0])
|
401 |
+
for idx in index_list:
|
402 |
+
if idx!=len(seq)-1:
|
403 |
+
if seq[idx+1]==subseq[-1]:
|
404 |
+
return True
|
405 |
+
return False
|
406 |
+
|
407 |
+
|
408 |
+
def merge_tuples(list1, list2):
|
409 |
+
result = []
|
410 |
+
for tup in list1:
|
411 |
+
last_element = tup[-1]
|
412 |
+
if last_element in tup[:-1]:
|
413 |
+
result.append(tup)
|
414 |
+
else:
|
415 |
+
matching_tuples = [t for t in list2 if t[0] == last_element]
|
416 |
+
already_match_flag = 0
|
417 |
+
for match in matching_tuples:
|
418 |
+
matchh = (match[1], match[0])
|
419 |
+
if is_continuous_subsequence(match, tup) or is_continuous_subsequence(matchh, tup):
|
420 |
+
continue
|
421 |
+
already_match_flag = 1
|
422 |
+
merged_tuple = tup + match[1:]
|
423 |
+
result.append(merged_tuple)
|
424 |
+
if not already_match_flag:
|
425 |
+
result.append(tup)
|
426 |
+
return result
|
427 |
+
|
428 |
+
|
429 |
+
def update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, n_hop):
|
430 |
+
def n_neighbor(id):
|
431 |
+
nonlocal graph, n_hop
|
432 |
+
count = 0
|
433 |
+
source_edge = list(graph.edges(id))
|
434 |
+
if not source_edge:
|
435 |
+
return []
|
436 |
+
count = count + 1
|
437 |
+
while count < n_hop:
|
438 |
+
count = count + 1
|
439 |
+
sc_edge = deepcopy(source_edge)
|
440 |
+
source_edge = []
|
441 |
+
for pair in sc_edge:
|
442 |
+
append_edge = list(graph.edges(pair[-1]))
|
443 |
+
for tuples in merge_tuples([pair], append_edge):
|
444 |
+
source_edge.append(tuples)
|
445 |
+
nbrs = []
|
446 |
+
for path in source_edge:
|
447 |
+
n = {"path": path, "weights": []}
|
448 |
+
wts = nx.get_edge_attributes(graph, 'weight')
|
449 |
+
for i in range(len(path)-1):
|
450 |
+
f, t = path[i], path[i+1]
|
451 |
+
n["weights"].append(wts.get((f, t), 0))
|
452 |
+
nbrs.append(n)
|
453 |
+
return nbrs
|
454 |
+
|
455 |
+
pr = nx.pagerank(graph)
|
456 |
+
for n, p in pr.items():
|
457 |
+
graph.nodes[n]["pagerank"] = p
|
458 |
+
try:
|
459 |
+
settings.docStoreConn.update({"entity_kwd": n, "kb_id": kb_id},
|
460 |
+
{"rank_flt": p,
|
461 |
+
"n_hop_with_weight": json.dumps(n_neighbor(n), ensure_ascii=False)},
|
462 |
+
search.index_name(tenant_id), kb_id)
|
463 |
+
except Exception as e:
|
464 |
+
logging.exception(e)
|
465 |
+
|
466 |
+
ty2ents = defaultdict(list)
|
467 |
+
for p, r in sorted(pr.items(), key=lambda x: x[1], reverse=True):
|
468 |
+
ty = graph.nodes[p].get("entity_type")
|
469 |
+
if not ty or len(ty2ents[ty]) > 12:
|
470 |
+
continue
|
471 |
+
ty2ents[ty].append(p)
|
472 |
+
|
473 |
+
chunk = {
|
474 |
+
"content_with_weight": json.dumps(ty2ents, ensure_ascii=False),
|
475 |
+
"kb_id": kb_id,
|
476 |
+
"knowledge_graph_kwd": "ty2ents",
|
477 |
+
"available_int": 0
|
478 |
+
}
|
479 |
+
res = settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "size": 1, "fields": []},
|
480 |
+
search.index_name(tenant_id), [kb_id])
|
481 |
+
if res.ids:
|
482 |
+
settings.docStoreConn.update({"knowledge_graph_kwd": "ty2ents"},
|
483 |
+
chunk,
|
484 |
+
search.index_name(tenant_id), kb_id)
|
485 |
+
else:
|
486 |
+
settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id))
|
487 |
+
|
488 |
+
|
489 |
+
def get_entity_type2sampels(idxnms, kb_ids: list):
|
490 |
+
es_res = settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids,
|
491 |
+
"size": 10000,
|
492 |
+
"fields": ["content_with_weight"]},
|
493 |
+
idxnms, kb_ids)
|
494 |
+
|
495 |
+
res = defaultdict(list)
|
496 |
+
for id in es_res.ids:
|
497 |
+
smp = es_res.field[id].get("content_with_weight")
|
498 |
+
if not smp:
|
499 |
+
continue
|
500 |
+
try:
|
501 |
+
smp = json.loads(smp)
|
502 |
+
except Exception as e:
|
503 |
+
logging.exception(e)
|
504 |
+
|
505 |
+
for ty, ents in smp.items():
|
506 |
+
res[ty].extend(ents)
|
507 |
+
return res
|
508 |
+
|
509 |
+
|
510 |
+
def flat_uniq_list(arr, key):
|
511 |
+
res = []
|
512 |
+
for a in arr:
|
513 |
+
a = a[key]
|
514 |
+
if isinstance(a, list):
|
515 |
+
res.extend(a)
|
516 |
+
else:
|
517 |
+
res.append(a)
|
518 |
+
return list(set(res))
|
519 |
+
|
pyproject.toml
CHANGED
@@ -51,6 +51,7 @@ dependencies = [
|
|
51 |
"infinity-sdk==0.6.0-dev2",
|
52 |
"infinity-emb>=0.0.66,<0.0.67",
|
53 |
"itsdangerous==2.1.2",
|
|
|
54 |
"markdown==3.6",
|
55 |
"markdown-to-json==2.1.1",
|
56 |
"minio==7.2.4",
|
@@ -130,4 +131,4 @@ full = [
|
|
130 |
"flagembedding==1.2.10",
|
131 |
"torch==2.3.0",
|
132 |
"transformers==4.38.1"
|
133 |
-
]
|
|
|
51 |
"infinity-sdk==0.6.0-dev2",
|
52 |
"infinity-emb>=0.0.66,<0.0.67",
|
53 |
"itsdangerous==2.1.2",
|
54 |
+
"json-repair==0.35.0",
|
55 |
"markdown==3.6",
|
56 |
"markdown-to-json==2.1.1",
|
57 |
"minio==7.2.4",
|
|
|
131 |
"flagembedding==1.2.10",
|
132 |
"torch==2.3.0",
|
133 |
"transformers==4.38.1"
|
134 |
+
]
|
rag/app/book.py
CHANGED
@@ -88,9 +88,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
|
88 |
callback(0.8, "Finish parsing.")
|
89 |
|
90 |
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
|
91 |
-
pdf_parser = Pdf()
|
92 |
-
|
93 |
-
|
94 |
sections, tbls = pdf_parser(filename if not binary else binary,
|
95 |
from_page=from_page, to_page=to_page, callback=callback)
|
96 |
|
|
|
88 |
callback(0.8, "Finish parsing.")
|
89 |
|
90 |
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
|
91 |
+
pdf_parser = Pdf()
|
92 |
+
if kwargs.get("layout_recognize", "DeepDOC") == "Plain Text":
|
93 |
+
pdf_parser = PlainParser()
|
94 |
sections, tbls = pdf_parser(filename if not binary else binary,
|
95 |
from_page=from_page, to_page=to_page, callback=callback)
|
96 |
|
rag/app/email.py
CHANGED
@@ -40,7 +40,7 @@ def chunk(
|
|
40 |
eng = lang.lower() == "english" # is_english(cks)
|
41 |
parser_config = kwargs.get(
|
42 |
"parser_config",
|
43 |
-
{"chunk_token_num": 128, "delimiter": "\n!?。;!?", "layout_recognize":
|
44 |
)
|
45 |
doc = {
|
46 |
"docnm_kwd": filename,
|
|
|
40 |
eng = lang.lower() == "english" # is_english(cks)
|
41 |
parser_config = kwargs.get(
|
42 |
"parser_config",
|
43 |
+
{"chunk_token_num": 128, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"},
|
44 |
)
|
45 |
doc = {
|
46 |
"docnm_kwd": filename,
|
rag/app/knowledge_graph.py
DELETED
@@ -1,48 +0,0 @@
|
|
1 |
-
#
|
2 |
-
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
3 |
-
#
|
4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
-
# you may not use this file except in compliance with the License.
|
6 |
-
# You may obtain a copy of the License at
|
7 |
-
#
|
8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
-
#
|
10 |
-
# Unless required by applicable law or agreed to in writing, software
|
11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
-
# See the License for the specific language governing permissions and
|
14 |
-
# limitations under the License.
|
15 |
-
#
|
16 |
-
|
17 |
-
import re
|
18 |
-
|
19 |
-
from graphrag.index import build_knowledge_graph_chunks
|
20 |
-
from rag.app import naive
|
21 |
-
from rag.nlp import rag_tokenizer, tokenize_chunks
|
22 |
-
|
23 |
-
|
24 |
-
def chunk(filename, binary, tenant_id, from_page=0, to_page=100000,
|
25 |
-
lang="Chinese", callback=None, **kwargs):
|
26 |
-
parser_config = kwargs.get(
|
27 |
-
"parser_config", {
|
28 |
-
"chunk_token_num": 512, "delimiter": "\n!?;。;!?", "layout_recognize": True})
|
29 |
-
eng = lang.lower() == "english"
|
30 |
-
|
31 |
-
parser_config["layout_recognize"] = True
|
32 |
-
sections = naive.chunk(filename, binary, from_page=from_page, to_page=to_page, section_only=True,
|
33 |
-
parser_config=parser_config, callback=callback)
|
34 |
-
chunks = build_knowledge_graph_chunks(tenant_id, sections, callback,
|
35 |
-
parser_config.get("entity_types", ["organization", "person", "location", "event", "time"])
|
36 |
-
)
|
37 |
-
for c in chunks:
|
38 |
-
c["docnm_kwd"] = filename
|
39 |
-
|
40 |
-
doc = {
|
41 |
-
"docnm_kwd": filename,
|
42 |
-
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename)),
|
43 |
-
"knowledge_graph_kwd": "text"
|
44 |
-
}
|
45 |
-
doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
|
46 |
-
chunks.extend(tokenize_chunks(sections, doc, eng))
|
47 |
-
|
48 |
-
return chunks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
rag/app/laws.py
CHANGED
@@ -162,9 +162,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
|
162 |
return tokenize_chunks(chunks, doc, eng, None)
|
163 |
|
164 |
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
|
165 |
-
pdf_parser = Pdf()
|
166 |
-
|
167 |
-
|
168 |
for txt, poss in pdf_parser(filename if not binary else binary,
|
169 |
from_page=from_page, to_page=to_page, callback=callback)[0]:
|
170 |
sections.append(txt + poss)
|
|
|
162 |
return tokenize_chunks(chunks, doc, eng, None)
|
163 |
|
164 |
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
|
165 |
+
pdf_parser = Pdf()
|
166 |
+
if kwargs.get("layout_recognize", "DeepDOC") == "Plain Text":
|
167 |
+
pdf_parser = PlainParser()
|
168 |
for txt, poss in pdf_parser(filename if not binary else binary,
|
169 |
from_page=from_page, to_page=to_page, callback=callback)[0]:
|
170 |
sections.append(txt + poss)
|
rag/app/manual.py
CHANGED
@@ -184,9 +184,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
|
184 |
# is it English
|
185 |
eng = lang.lower() == "english" # pdf_parser.is_english
|
186 |
if re.search(r"\.pdf$", filename, re.IGNORECASE):
|
187 |
-
pdf_parser = Pdf()
|
188 |
-
|
189 |
-
|
190 |
sections, tbls = pdf_parser(filename if not binary else binary,
|
191 |
from_page=from_page, to_page=to_page, callback=callback)
|
192 |
if sections and len(sections[0]) < 3:
|
|
|
184 |
# is it English
|
185 |
eng = lang.lower() == "english" # pdf_parser.is_english
|
186 |
if re.search(r"\.pdf$", filename, re.IGNORECASE):
|
187 |
+
pdf_parser = Pdf()
|
188 |
+
if kwargs.get("layout_recognize", "DeepDOC") == "Plain Text":
|
189 |
+
pdf_parser = PlainParser()
|
190 |
sections, tbls = pdf_parser(filename if not binary else binary,
|
191 |
from_page=from_page, to_page=to_page, callback=callback)
|
192 |
if sections and len(sections[0]) < 3:
|
rag/app/naive.py
CHANGED
@@ -202,7 +202,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
|
202 |
is_english = lang.lower() == "english" # is_english(cks)
|
203 |
parser_config = kwargs.get(
|
204 |
"parser_config", {
|
205 |
-
"chunk_token_num": 128, "delimiter": "\n!?。;!?", "layout_recognize":
|
206 |
doc = {
|
207 |
"docnm_kwd": filename,
|
208 |
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
|
@@ -231,8 +231,11 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
|
231 |
return res
|
232 |
|
233 |
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
|
234 |
-
pdf_parser = Pdf()
|
235 |
-
|
|
|
|
|
|
|
236 |
res = tokenize_table(tables, doc, is_english)
|
237 |
|
238 |
elif re.search(r"\.xlsx?$", filename, re.IGNORECASE):
|
|
|
202 |
is_english = lang.lower() == "english" # is_english(cks)
|
203 |
parser_config = kwargs.get(
|
204 |
"parser_config", {
|
205 |
+
"chunk_token_num": 128, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
|
206 |
doc = {
|
207 |
"docnm_kwd": filename,
|
208 |
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
|
|
|
231 |
return res
|
232 |
|
233 |
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
|
234 |
+
pdf_parser = Pdf()
|
235 |
+
if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
|
236 |
+
pdf_parser = PlainParser()
|
237 |
+
sections, tables = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page,
|
238 |
+
callback=callback)
|
239 |
res = tokenize_table(tables, doc, is_english)
|
240 |
|
241 |
elif re.search(r"\.xlsx?$", filename, re.IGNORECASE):
|
rag/app/one.py
CHANGED
@@ -84,9 +84,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
|
84 |
callback(0.8, "Finish parsing.")
|
85 |
|
86 |
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
|
87 |
-
pdf_parser = Pdf()
|
88 |
-
|
89 |
-
|
90 |
sections, _ = pdf_parser(
|
91 |
filename if not binary else binary, to_page=to_page, callback=callback)
|
92 |
sections = [s for s, _ in sections if s]
|
|
|
84 |
callback(0.8, "Finish parsing.")
|
85 |
|
86 |
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
|
87 |
+
pdf_parser = Pdf()
|
88 |
+
if kwargs.get("layout_recognize", "DeepDOC") == "Plain Text":
|
89 |
+
pdf_parser = PlainParser()
|
90 |
sections, _ = pdf_parser(
|
91 |
filename if not binary else binary, to_page=to_page, callback=callback)
|
92 |
sections = [s for s, _ in sections if s]
|
rag/app/paper.py
CHANGED
@@ -144,7 +144,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
|
144 |
The abstract of the paper will be sliced as an entire chunk, and will not be sliced partly.
|
145 |
"""
|
146 |
if re.search(r"\.pdf$", filename, re.IGNORECASE):
|
147 |
-
if
|
148 |
pdf_parser = PlainParser()
|
149 |
paper = {
|
150 |
"title": filename,
|
|
|
144 |
The abstract of the paper will be sliced as an entire chunk, and will not be sliced partly.
|
145 |
"""
|
146 |
if re.search(r"\.pdf$", filename, re.IGNORECASE):
|
147 |
+
if kwargs.get("parser_config", {}).get("layout_recognize", "DeepDOC") == "Plain Text":
|
148 |
pdf_parser = PlainParser()
|
149 |
paper = {
|
150 |
"title": filename,
|
rag/app/presentation.py
CHANGED
@@ -119,9 +119,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
|
119 |
res.append(d)
|
120 |
return res
|
121 |
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
|
122 |
-
pdf_parser = Pdf()
|
123 |
-
|
124 |
-
|
125 |
for pn, (txt, img) in enumerate(pdf_parser(filename, binary,
|
126 |
from_page=from_page, to_page=to_page, callback=callback)):
|
127 |
d = copy.deepcopy(doc)
|
|
|
119 |
res.append(d)
|
120 |
return res
|
121 |
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
|
122 |
+
pdf_parser = Pdf()
|
123 |
+
if kwargs.get("layout_recognize", "DeepDOC") == "Plain Text":
|
124 |
+
pdf_parser = PlainParser()
|
125 |
for pn, (txt, img) in enumerate(pdf_parser(filename, binary,
|
126 |
from_page=from_page, to_page=to_page, callback=callback)):
|
127 |
d = copy.deepcopy(doc)
|
rag/llm/chat_model.py
CHANGED
@@ -32,6 +32,7 @@ import asyncio
|
|
32 |
LENGTH_NOTIFICATION_CN = "······\n由于长度的原因,回答被截断了,要继续吗?"
|
33 |
LENGTH_NOTIFICATION_EN = "...\nFor the content length reason, it stopped, continue?"
|
34 |
|
|
|
35 |
class Base(ABC):
|
36 |
def __init__(self, key, model_name, base_url):
|
37 |
timeout = int(os.environ.get('LM_TIMEOUT_SECONDS', 600))
|
|
|
32 |
LENGTH_NOTIFICATION_CN = "······\n由于长度的原因,回答被截断了,要继续吗?"
|
33 |
LENGTH_NOTIFICATION_EN = "...\nFor the content length reason, it stopped, continue?"
|
34 |
|
35 |
+
|
36 |
class Base(ABC):
|
37 |
def __init__(self, key, model_name, base_url):
|
38 |
timeout = int(os.environ.get('LM_TIMEOUT_SECONDS', 600))
|
rag/nlp/search.py
CHANGED
@@ -59,7 +59,7 @@ class Dealer:
|
|
59 |
if key in req and req[key] is not None:
|
60 |
condition[field] = req[key]
|
61 |
# TODO(yzc): `available_int` is nullable however infinity doesn't support nullable columns.
|
62 |
-
for key in ["knowledge_graph_kwd", "available_int"]:
|
63 |
if key in req and req[key] is not None:
|
64 |
condition[key] = req[key]
|
65 |
return condition
|
@@ -198,6 +198,11 @@ class Dealer:
|
|
198 |
return answer, set([])
|
199 |
|
200 |
ans_v, _ = embd_mdl.encode(pieces_)
|
|
|
|
|
|
|
|
|
|
|
201 |
assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
|
202 |
len(ans_v[0]), len(chunk_v[0]))
|
203 |
|
@@ -434,6 +439,8 @@ class Dealer:
|
|
434 |
es_res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), p, bs, index_name(tenant_id),
|
435 |
kb_ids)
|
436 |
dict_chunks = self.dataStore.getFields(es_res, fields)
|
|
|
|
|
437 |
if dict_chunks:
|
438 |
res.extend(dict_chunks.values())
|
439 |
if len(dict_chunks.values()) < bs:
|
|
|
59 |
if key in req and req[key] is not None:
|
60 |
condition[field] = req[key]
|
61 |
# TODO(yzc): `available_int` is nullable however infinity doesn't support nullable columns.
|
62 |
+
for key in ["knowledge_graph_kwd", "available_int", "entity_kwd", "from_entity_kwd", "to_entity_kwd", "removed_kwd"]:
|
63 |
if key in req and req[key] is not None:
|
64 |
condition[key] = req[key]
|
65 |
return condition
|
|
|
198 |
return answer, set([])
|
199 |
|
200 |
ans_v, _ = embd_mdl.encode(pieces_)
|
201 |
+
for i in range(len(chunk_v)):
|
202 |
+
if len(ans_v[0]) != len(chunk_v[i]):
|
203 |
+
chunk_v[i] = [0.0]*len(ans_v[0])
|
204 |
+
logging.warning("The dimension of query and chunk do not match: {} vs. {}".format(len(ans_v[0]), len(chunk_v[i])))
|
205 |
+
|
206 |
assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
|
207 |
len(ans_v[0]), len(chunk_v[0]))
|
208 |
|
|
|
439 |
es_res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), p, bs, index_name(tenant_id),
|
440 |
kb_ids)
|
441 |
dict_chunks = self.dataStore.getFields(es_res, fields)
|
442 |
+
for id, doc in dict_chunks.items():
|
443 |
+
doc["id"] = id
|
444 |
if dict_chunks:
|
445 |
res.extend(dict_chunks.values())
|
446 |
if len(dict_chunks.values()) < bs:
|