Kevin Hu committed
Commit 47ec63e · 1 Parent(s): 93f905e

Light GraphRAG (#4585)


### What problem does this PR solve?

#4543

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. api/apps/chunk_app.py +12 -3
  2. api/apps/conversation_app.py +1 -1
  3. api/apps/kb_app.py +35 -1
  4. api/apps/sdk/dify_retrieval.py +13 -3
  5. api/apps/sdk/doc.py +13 -4
  6. api/db/init_data.py +1 -1
  7. api/db/services/dialog_service.py +9 -2
  8. api/db/services/document_service.py +46 -16
  9. api/db/services/file_service.py +1 -1
  10. api/db/services/task_service.py +20 -10
  11. api/utils/api_utils.py +1 -1
  12. conf/infinity_mapping.json +11 -1
  13. graphrag/description_summary.py +0 -146
  14. graphrag/entity_resolution.py +44 -40
  15. graphrag/extractor.py +0 -34
  16. graphrag/general/__init__.py +0 -0
  17. graphrag/{claim_extractor.py → general/claim_extractor.py} +2 -2
  18. graphrag/{claim_prompt.py → general/claim_prompt.py} +0 -0
  19. graphrag/{community_report_prompt.py → general/community_report_prompt.py} +0 -0
  20. graphrag/{community_reports_extractor.py → general/community_reports_extractor.py} +25 -14
  21. graphrag/{entity_embedding.py → general/entity_embedding.py} +1 -1
  22. graphrag/general/extractor.py +245 -0
  23. graphrag/general/graph_extractor.py +154 -0
  24. graphrag/{graph_prompt.py → general/graph_prompt.py} +16 -1
  25. graphrag/general/index.py +197 -0
  26. graphrag/{leiden.py → general/leiden.py} +2 -1
  27. graphrag/{mind_map_extractor.py → general/mind_map_extractor.py} +2 -2
  28. graphrag/{mind_map_prompt.py → general/mind_map_prompt.py} +0 -0
  29. graphrag/general/smoke.py +63 -0
  30. graphrag/graph_extractor.py +0 -322
  31. graphrag/index.py +0 -153
  32. graphrag/light/__init__.py +0 -0
  33. graphrag/light/graph_extractor.py +127 -0
  34. graphrag/light/graph_prompt.py +255 -0
  35. graphrag/{smoke.py → light/smoke.py} +28 -25
  36. graphrag/query_analyze_prompt.py +218 -0
  37. graphrag/search.py +301 -78
  38. graphrag/utils.py +386 -0
  39. pyproject.toml +2 -1
  40. rag/app/book.py +3 -3
  41. rag/app/email.py +1 -1
  42. rag/app/knowledge_graph.py +0 -48
  43. rag/app/laws.py +3 -3
  44. rag/app/manual.py +3 -3
  45. rag/app/naive.py +6 -3
  46. rag/app/one.py +3 -3
  47. rag/app/paper.py +1 -1
  48. rag/app/presentation.py +3 -3
  49. rag/llm/chat_model.py +1 -0
  50. rag/nlp/search.py +8 -1
api/apps/chunk_app.py CHANGED
@@ -155,7 +155,7 @@ def set():
155
  r"[\n\t]",
156
  req["content_with_weight"]) if len(t) > 1]
157
  q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
158
- d = beAdoc(d, arr[0], arr[1], not any(
159
  [rag_tokenizer.is_chinese(t) for t in q + a]))
160
 
161
  v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
@@ -270,6 +270,7 @@ def retrieval_test():
270
  doc_ids = req.get("doc_ids", [])
271
  similarity_threshold = float(req.get("similarity_threshold", 0.0))
272
  vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
 
273
  top = int(req.get("top_k", 1024))
274
  tenant_ids = []
275
 
@@ -301,12 +302,20 @@ def retrieval_test():
301
  question += keyword_extraction(chat_mdl, question)
302
 
303
  labels = label_question(question, [kb])
304
- retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
305
- ranks = retr.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
306
  similarity_threshold, vector_similarity_weight, top,
307
  doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"),
308
  rank_feature=labels
309
  )
 
 
310
  for c in ranks["chunks"]:
311
  c.pop("vector", None)
312
  ranks["labels"] = labels
 
155
  r"[\n\t]",
156
  req["content_with_weight"]) if len(t) > 1]
157
  q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
158
+ d = beAdoc(d, q, a, not any(
159
  [rag_tokenizer.is_chinese(t) for t in q + a]))
160
 
161
  v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
 
270
  doc_ids = req.get("doc_ids", [])
271
  similarity_threshold = float(req.get("similarity_threshold", 0.0))
272
  vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
273
+ use_kg = req.get("use_kg", False)
274
  top = int(req.get("top_k", 1024))
275
  tenant_ids = []
276
 
 
302
  question += keyword_extraction(chat_mdl, question)
303
 
304
  labels = label_question(question, [kb])
305
+ ranks = settings.retrievaler.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
 
306
  similarity_threshold, vector_similarity_weight, top,
307
  doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"),
308
  rank_feature=labels
309
  )
310
+ if use_kg:
311
+ ck = settings.kg_retrievaler.retrieval(question,
312
+ tenant_ids,
313
+ kb_ids,
314
+ embd_mdl,
315
+ LLMBundle(kb.tenant_id, LLMType.CHAT))
316
+ if ck["content_with_weight"]:
317
+ ranks["chunks"].insert(0, ck)
318
+
319
  for c in ranks["chunks"]:
320
  c.pop("vector", None)
321
  ranks["labels"] = labels
api/apps/conversation_app.py CHANGED
@@ -31,7 +31,7 @@ from api.db.services.llm_service import LLMBundle, TenantService
31
  from api import settings
32
  from api.utils.api_utils import get_json_result
33
  from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
34
- from graphrag.mind_map_extractor import MindMapExtractor
35
 
36
 
37
  @manager.route('/set', methods=['POST']) # noqa: F821
 
31
  from api import settings
32
  from api.utils.api_utils import get_json_result
33
  from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
34
+ from graphrag.general.mind_map_extractor import MindMapExtractor
35
 
36
 
37
  @manager.route('/set', methods=['POST']) # noqa: F821
api/apps/kb_app.py CHANGED
@@ -13,6 +13,8 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
 
 
16
  from flask import request
17
  from flask_login import login_required, current_user
18
 
@@ -272,4 +274,36 @@ def rename_tags(kb_id):
272
  {"remove": {"tag_kwd": req["from_tag"].strip()}, "add": {"tag_kwd": req["to_tag"]}},
273
  search.index_name(kb.tenant_id),
274
  kb_id)
275
- return get_json_result(data=True)
 
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
16
+ import json
17
+
18
  from flask import request
19
  from flask_login import login_required, current_user
20
 
 
274
  {"remove": {"tag_kwd": req["from_tag"].strip()}, "add": {"tag_kwd": req["to_tag"]}},
275
  search.index_name(kb.tenant_id),
276
  kb_id)
277
+ return get_json_result(data=True)
278
+
279
+
280
+ @manager.route('/<kb_id>/knowledge_graph', methods=['GET']) # noqa: F821
281
+ @login_required
282
+ def knowledge_graph(kb_id):
283
+ if not KnowledgebaseService.accessible(kb_id, current_user.id):
284
+ return get_json_result(
285
+ data=False,
286
+ message='No authorization.',
287
+ code=settings.RetCode.AUTHENTICATION_ERROR
288
+ )
289
+ e, kb = KnowledgebaseService.get_by_id(kb_id)
290
+ req = {
291
+ "kb_id": [kb_id],
292
+ "knowledge_graph_kwd": ["graph"]
293
+ }
294
+ sres = settings.retrievaler.search(req, search.index_name(kb.tenant_id), [kb_id])
295
+ obj = {"graph": {}, "mind_map": {}}
296
+ for id in sres.ids[:1]:
297
+ ty = sres.field[id]["knowledge_graph_kwd"]
298
+ try:
299
+ content_json = json.loads(sres.field[id]["content_with_weight"])
300
+ except Exception:
301
+ continue
302
+
303
+ obj[ty] = content_json
304
+
305
+ if "nodes" in obj["graph"]:
306
+ obj["graph"]["nodes"] = sorted(obj["graph"]["nodes"], key=lambda x: x.get("pagerank", 0), reverse=True)[:256]
307
+ if "edges" in obj["graph"]:
308
+ obj["graph"]["edges"] = sorted(obj["graph"]["edges"], key=lambda x: x.get("weight", 0), reverse=True)[:128]
309
+ return get_json_result(data=obj)
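The new `GET /<kb_id>/knowledge_graph` route returns the stored graph (nodes capped at 256 by pagerank, edges at 128 by weight) plus the mind map, wrapped in the usual JSON envelope. A hedged client-side example; the base URL, the `/v1/kb` prefix and the session cookie name are assumptions that depend on the deployment:

```python
import requests

BASE = "http://localhost:9380/v1/kb"          # assumed host and route prefix
KB_ID = "your_kb_id"

resp = requests.get(f"{BASE}/{KB_ID}/knowledge_graph",
                    cookies={"session": "your_session_cookie"})  # assumed auth mechanism
data = resp.json().get("data", {})
nodes = data.get("graph", {}).get("nodes", [])   # already sorted by pagerank, max 256
edges = data.get("graph", {}).get("edges", [])   # already sorted by weight, max 128
mind_map = data.get("mind_map", {})
print(len(nodes), len(edges))
```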
api/apps/sdk/dify_retrieval.py CHANGED
@@ -15,7 +15,7 @@
15
  #
16
  from flask import request, jsonify
17
 
18
- from api.db import LLMType, ParserType
19
  from api.db.services.dialog_service import label_question
20
  from api.db.services.knowledgebase_service import KnowledgebaseService
21
  from api.db.services.llm_service import LLMBundle
@@ -30,6 +30,7 @@ def retrieval(tenant_id):
30
  req = request.json
31
  question = req["query"]
32
  kb_id = req["knowledge_id"]
 
33
  retrieval_setting = req.get("retrieval_setting", {})
34
  similarity_threshold = float(retrieval_setting.get("score_threshold", 0.0))
35
  top = int(retrieval_setting.get("top_k", 1024))
@@ -45,8 +46,7 @@ def retrieval(tenant_id):
45
 
46
  embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
47
 
48
- retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
49
- ranks = retr.retrieval(
50
  question,
51
  embd_mdl,
52
  kb.tenant_id,
@@ -58,6 +58,16 @@ def retrieval(tenant_id):
58
  top=top,
59
  rank_feature=label_question(question, [kb])
60
  )
 
 
61
  records = []
62
  for c in ranks["chunks"]:
63
  c.pop("vector", None)
 
15
  #
16
  from flask import request, jsonify
17
 
18
+ from api.db import LLMType
19
  from api.db.services.dialog_service import label_question
20
  from api.db.services.knowledgebase_service import KnowledgebaseService
21
  from api.db.services.llm_service import LLMBundle
 
30
  req = request.json
31
  question = req["query"]
32
  kb_id = req["knowledge_id"]
33
+ use_kg = req.get("use_kg", False)
34
  retrieval_setting = req.get("retrieval_setting", {})
35
  similarity_threshold = float(retrieval_setting.get("score_threshold", 0.0))
36
  top = int(retrieval_setting.get("top_k", 1024))
 
46
 
47
  embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
48
 
49
+ ranks = settings.retrievaler.retrieval(
 
50
  question,
51
  embd_mdl,
52
  kb.tenant_id,
 
58
  top=top,
59
  rank_feature=label_question(question, [kb])
60
  )
61
+
62
+ if use_kg:
63
+ ck = settings.kg_retrievaler.retrieval(question,
64
+ [tenant_id],
65
+ [kb_id],
66
+ embd_mdl,
67
+ LLMBundle(kb.tenant_id, LLMType.CHAT))
68
+ if ck["content_with_weight"]:
69
+ ranks["chunks"].insert(0, ck)
70
+
71
  records = []
72
  for c in ranks["chunks"]:
73
  c.pop("vector", None)
api/apps/sdk/doc.py CHANGED
@@ -1297,15 +1297,15 @@ def retrieval_test(tenant_id):
1297
  kb_ids = req["dataset_ids"]
1298
  if not isinstance(kb_ids, list):
1299
  return get_error_data_result("`dataset_ids` should be a list")
1300
- kbs = KnowledgebaseService.get_by_ids(kb_ids)
1301
  for id in kb_ids:
1302
  if not KnowledgebaseService.accessible(kb_id=id, user_id=tenant_id):
1303
  return get_error_data_result(f"You don't own the dataset {id}.")
 
1304
  embd_nms = list(set([kb.embd_id for kb in kbs]))
1305
  if len(embd_nms) != 1:
1306
  return get_result(
1307
  message='Datasets use different embedding models."',
1308
- code=settings.RetCode.AUTHENTICATION_ERROR,
1309
  )
1310
  if "question" not in req:
1311
  return get_error_data_result("`question` is required.")
@@ -1313,6 +1313,7 @@ def retrieval_test(tenant_id):
1313
  size = int(req.get("page_size", 30))
1314
  question = req["question"]
1315
  doc_ids = req.get("document_ids", [])
 
1316
  if not isinstance(doc_ids, list):
1317
  return get_error_data_result("`documents` should be a list")
1318
  doc_ids_list = KnowledgebaseService.list_documents_by_ids(kb_ids)
@@ -1342,8 +1343,7 @@ def retrieval_test(tenant_id):
1342
  chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
1343
  question += keyword_extraction(chat_mdl, question)
1344
 
1345
- retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
1346
- ranks = retr.retrieval(
1347
  question,
1348
  embd_mdl,
1349
  kb.tenant_id,
@@ -1358,6 +1358,15 @@ def retrieval_test(tenant_id):
1358
  highlight=highlight,
1359
  rank_feature=label_question(question, kbs)
1360
  )
 
 
1361
  for c in ranks["chunks"]:
1362
  c.pop("vector", None)
1363
 
 
1297
  kb_ids = req["dataset_ids"]
1298
  if not isinstance(kb_ids, list):
1299
  return get_error_data_result("`dataset_ids` should be a list")
 
1300
  for id in kb_ids:
1301
  if not KnowledgebaseService.accessible(kb_id=id, user_id=tenant_id):
1302
  return get_error_data_result(f"You don't own the dataset {id}.")
1303
+ kbs = KnowledgebaseService.get_by_ids(kb_ids)
1304
  embd_nms = list(set([kb.embd_id for kb in kbs]))
1305
  if len(embd_nms) != 1:
1306
  return get_result(
1307
  message='Datasets use different embedding models."',
1308
+ code=settings.RetCode.DATA_ERROR,
1309
  )
1310
  if "question" not in req:
1311
  return get_error_data_result("`question` is required.")
 
1313
  size = int(req.get("page_size", 30))
1314
  question = req["question"]
1315
  doc_ids = req.get("document_ids", [])
1316
+ use_kg = req.get("use_kg", False)
1317
  if not isinstance(doc_ids, list):
1318
  return get_error_data_result("`documents` should be a list")
1319
  doc_ids_list = KnowledgebaseService.list_documents_by_ids(kb_ids)
 
1343
  chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
1344
  question += keyword_extraction(chat_mdl, question)
1345
 
1346
+ ranks = settings.retrievaler.retrieval(
 
1347
  question,
1348
  embd_mdl,
1349
  kb.tenant_id,
 
1358
  highlight=highlight,
1359
  rank_feature=label_question(question, kbs)
1360
  )
1361
+ if use_kg:
1362
+ ck = settings.kg_retrievaler.retrieval(question,
1363
+ [k.tenant_id for k in kbs],
1364
+ kb_ids,
1365
+ embd_mdl,
1366
+ LLMBundle(kb.tenant_id, LLMType.CHAT))
1367
+ if ck["content_with_weight"]:
1368
+ ranks["chunks"].insert(0, ck)
1369
+
1370
  for c in ranks["chunks"]:
1371
  c.pop("vector", None)
1372
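The SDK retrieval endpoint gains the same `use_kg` switch, and when several datasets are queried it passes every owning tenant to the KG retriever. A hedged HTTP example; the `/api/v1/retrieval` path and bearer-token header follow the usual SDK conventions but should be checked against your build:

```python
import requests

url = "http://localhost:9380/api/v1/retrieval"        # assumed path
headers = {"Authorization": "Bearer <YOUR_API_KEY>"}  # assumed auth header
body = {
    "dataset_ids": ["dataset_a", "dataset_b"],  # must all use the same embedding model
    "question": "Which entities connect product A to supplier B?",
    "use_kg": True,
    "page_size": 10,
}
resp = requests.post(url, headers=headers, json=body)
print(resp.json())
```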
 
api/db/init_data.py CHANGED
@@ -133,7 +133,7 @@ def init_llm_factory():
133
  TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"})
134
  TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "cohere"], {"llm_factory": "Cohere"})
135
  TenantService.filter_update([1 == 1], {
136
- "parser_ids": "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email,tag:Tag"})
137
  ## insert openai two embedding models to the current openai user.
138
  # print("Start to insert 2 OpenAI embedding models...")
139
  tenant_ids = set([row["tenant_id"] for row in TenantLLMService.get_openai_models()])
 
133
  TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"})
134
  TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "cohere"], {"llm_factory": "Cohere"})
135
  TenantService.filter_update([1 == 1], {
136
+ "parser_ids": "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag"})
137
  ## insert openai two embedding models to the current openai user.
138
  # print("Start to insert 2 OpenAI embedding models...")
139
  tenant_ids = set([row["tenant_id"] for row in TenantLLMService.get_openai_models()])
api/db/services/dialog_service.py CHANGED
@@ -197,8 +197,7 @@ def chat(dialog, messages, stream=True, **kwargs):
197
 
198
  embedding_model_name = embedding_list[0]
199
 
200
- is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
201
- retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler
202
 
203
  questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
204
  attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
@@ -275,6 +274,14 @@ def chat(dialog, messages, stream=True, **kwargs):
275
  top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
276
  rank_feature=label_question(" ".join(questions), kbs)
277
  )
 
 
278
 
279
  retrieval_ts = timer()
280
 
 
197
 
198
  embedding_model_name = embedding_list[0]
199
 
200
+ retriever = settings.retrievaler
 
201
 
202
  questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
203
  attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
 
274
  top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
275
  rank_feature=label_question(" ".join(questions), kbs)
276
  )
277
+ if prompt_config.get("use_kg"):
278
+ ck = settings.kg_retrievaler.retrieval(" ".join(questions),
279
+ tenant_ids,
280
+ dialog.kb_ids,
281
+ embd_mdl,
282
+ LLMBundle(dialog.tenant_id, LLMType.CHAT))
283
+ if ck["content_with_weight"]:
284
+ kbinfos["chunks"].insert(0, ck)
285
 
286
  retrieval_ts = timer()
287
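In chat, using the knowledge graph is now a per-dialog choice driven by `prompt_config["use_kg"]` instead of the knowledge base's parser type; when enabled, the single KG chunk is pushed to the front of `kbinfos["chunks"]` before the prompt is assembled. An illustrative `prompt_config`; keys other than `use_kg` are just the usual dialog fields shown for context:

```python
# Illustrative dialog prompt_config: "use_kg" is the new switch introduced by this PR,
# the remaining keys are ordinary dialog settings.
prompt_config = {
    "system": "You are a helpful assistant. Answer strictly from {knowledge}.",
    "parameters": [{"key": "knowledge", "optional": False}],
    "use_kg": True,
}
```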
 
api/db/services/document_service.py CHANGED
@@ -28,7 +28,7 @@ from peewee import fn
28
  from api.db.db_utils import bulk_insert_into_db
29
  from api import settings
30
  from api.utils import current_timestamp, get_format_time, get_uuid
31
- from graphrag.mind_map_extractor import MindMapExtractor
32
  from rag.settings import SVR_QUEUE_NAME
33
  from rag.utils.storage_factory import STORAGE_IMPL
34
  from rag.nlp import search, rag_tokenizer
@@ -105,8 +105,19 @@ class DocumentService(CommonService):
105
  @classmethod
106
  @DB.connection_context()
107
  def remove_document(cls, doc, tenant_id):
108
- settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
109
  cls.clear_chunk_num(doc.id)
 
 
110
  return cls.delete_by_id(doc.id)
111
 
112
  @classmethod
@@ -142,7 +153,7 @@ class DocumentService(CommonService):
142
  @DB.connection_context()
143
  def get_unfinished_docs(cls):
144
  fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg,
145
- cls.model.run]
146
  docs = cls.model.select(*fields) \
147
  .where(
148
  cls.model.status == StatusEnum.VALID.value,
@@ -295,9 +306,9 @@ class DocumentService(CommonService):
295
  Tenant.asr_id,
296
  Tenant.llm_id,
297
  )
298
- .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
299
- .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
300
- .where(cls.model.id == doc_id)
301
  )
302
  configs = configs.dicts()
303
  if not configs:
@@ -365,6 +376,12 @@ class DocumentService(CommonService):
365
  @classmethod
366
  @DB.connection_context()
367
  def update_progress(cls):
 
 
368
  docs = cls.get_unfinished_docs()
369
  for d in docs:
370
  try:
@@ -390,15 +407,27 @@ class DocumentService(CommonService):
390
  prg = -1
391
  status = TaskStatus.FAIL.value
392
  elif finished:
393
- if d["parser_config"].get("raptor", {}).get("use_raptor") and d["progress_msg"].lower().find(
394
- " raptor") < 0:
395
- queue_raptor_tasks(d)
 
 
396
  prg = 0.98 * len(tsks) / (len(tsks) + 1)
397
- msg.append("------ RAPTOR -------")
398
  else:
399
  status = TaskStatus.DONE.value
400
 
401
- msg = "\n".join(msg)
402
  info = {
403
  "process_duation": datetime.timestamp(
404
  datetime.now()) -
@@ -430,7 +459,7 @@ class DocumentService(CommonService):
430
  return False
431
 
432
 
433
- def queue_raptor_tasks(doc):
434
  chunking_config = DocumentService.get_chunking_config(doc["id"])
435
  hasher = xxhash.xxh64()
436
  for field in sorted(chunking_config.keys()):
@@ -443,15 +472,16 @@ def queue_raptor_tasks(doc):
443
  "doc_id": doc["id"],
444
  "from_page": 100000000,
445
  "to_page": 100000000,
446
- "progress_msg": "Start to do RAPTOR (Recursive Abstractive Processing for Tree-Organized Retrieval)."
447
  }
448
 
449
  task = new_task()
450
  for field in ["doc_id", "from_page", "to_page"]:
451
  hasher.update(str(task.get(field, "")).encode("utf-8"))
 
452
  task["digest"] = hasher.hexdigest()
453
  bulk_insert_into_db(Task, [task], True)
454
- task["type"] = "raptor"
455
  assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status."
456
 
457
 
@@ -489,7 +519,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
489
  ParserType.AUDIO.value: audio,
490
  ParserType.EMAIL.value: email
491
  }
492
- parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": False}
493
  exe = ThreadPoolExecutor(max_workers=12)
494
  threads = []
495
  doc_nm = {}
@@ -592,4 +622,4 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
592
  DocumentService.increment_chunk_num(
593
  doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
594
 
595
- return [d["id"] for d, _ in files]
 
28
  from api.db.db_utils import bulk_insert_into_db
29
  from api import settings
30
  from api.utils import current_timestamp, get_format_time, get_uuid
31
+ from graphrag.general.mind_map_extractor import MindMapExtractor
32
  from rag.settings import SVR_QUEUE_NAME
33
  from rag.utils.storage_factory import STORAGE_IMPL
34
  from rag.nlp import search, rag_tokenizer
 
105
  @classmethod
106
  @DB.connection_context()
107
  def remove_document(cls, doc, tenant_id):
 
108
  cls.clear_chunk_num(doc.id)
109
+ try:
110
+ settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
111
+ settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "community_report"], "source_id": doc.id},
112
+ {"remove": {"source_id": doc.id}},
113
+ search.index_name(tenant_id), doc.kb_id)
114
+ settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
115
+ {"removed_kwd": "Y"},
116
+ search.index_name(tenant_id), doc.kb_id)
117
+ settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "community_report"], "must_not": {"exists": "source_id"}},
118
+ search.index_name(tenant_id), doc.kb_id)
119
+ except Exception:
120
+ pass
121
  return cls.delete_by_id(doc.id)
122
 
123
  @classmethod
 
153
  @DB.connection_context()
154
  def get_unfinished_docs(cls):
155
  fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg,
156
+ cls.model.run, cls.model.parser_id]
157
  docs = cls.model.select(*fields) \
158
  .where(
159
  cls.model.status == StatusEnum.VALID.value,
 
306
  Tenant.asr_id,
307
  Tenant.llm_id,
308
  )
309
+ .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
310
+ .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
311
+ .where(cls.model.id == doc_id)
312
  )
313
  configs = configs.dicts()
314
  if not configs:
 
376
  @classmethod
377
  @DB.connection_context()
378
  def update_progress(cls):
379
+ MSG = {
380
+ "raptor": "Start RAPTOR (Recursive Abstractive Processing for Tree-Organized Retrieval).",
381
+ "graphrag": "Start Graph Extraction",
382
+ "graph_resolution": "Start Graph Resolution",
383
+ "graph_community": "Start Graph Community Reports Generation"
384
+ }
385
  docs = cls.get_unfinished_docs()
386
  for d in docs:
387
  try:
 
407
  prg = -1
408
  status = TaskStatus.FAIL.value
409
  elif finished:
410
+ m = "\n".join(sorted(msg))
411
+ if d["parser_config"].get("raptor", {}).get("use_raptor") and m.find(MSG["raptor"]) < 0:
412
+ queue_raptor_o_graphrag_tasks(d, "raptor", MSG["raptor"])
413
+ prg = 0.98 * len(tsks) / (len(tsks) + 1)
414
+ elif d["parser_config"].get("graphrag", {}).get("use_graphrag") and m.find(MSG["graphrag"]) < 0:
415
+ queue_raptor_o_graphrag_tasks(d, "graphrag", MSG["graphrag"])
416
+ prg = 0.98 * len(tsks) / (len(tsks) + 1)
417
+ elif d["parser_config"].get("graphrag", {}).get("use_graphrag") \
418
+ and d["parser_config"].get("graphrag", {}).get("resolution") \
419
+ and m.find(MSG["graph_resolution"]) < 0:
420
+ queue_raptor_o_graphrag_tasks(d, "graph_resolution", MSG["graph_resolution"])
421
+ prg = 0.98 * len(tsks) / (len(tsks) + 1)
422
+ elif d["parser_config"].get("graphrag", {}).get("use_graphrag") \
423
+ and d["parser_config"].get("graphrag", {}).get("community") \
424
+ and m.find(MSG["graph_community"]) < 0:
425
+ queue_raptor_o_graphrag_tasks(d, "graph_community", MSG["graph_community"])
426
  prg = 0.98 * len(tsks) / (len(tsks) + 1)
 
427
  else:
428
  status = TaskStatus.DONE.value
429
 
430
+ msg = "\n".join(sorted(msg))
431
  info = {
432
  "process_duation": datetime.timestamp(
433
  datetime.now()) -
 
459
  return False
460
 
461
 
462
+ def queue_raptor_o_graphrag_tasks(doc, ty, msg):
463
  chunking_config = DocumentService.get_chunking_config(doc["id"])
464
  hasher = xxhash.xxh64()
465
  for field in sorted(chunking_config.keys()):
 
472
  "doc_id": doc["id"],
473
  "from_page": 100000000,
474
  "to_page": 100000000,
475
+ "progress_msg": datetime.now().strftime("%H:%M:%S") + " " + msg
476
  }
477
 
478
  task = new_task()
479
  for field in ["doc_id", "from_page", "to_page"]:
480
  hasher.update(str(task.get(field, "")).encode("utf-8"))
481
+ hasher.update(ty.encode("utf-8"))
482
  task["digest"] = hasher.hexdigest()
483
  bulk_insert_into_db(Task, [task], True)
484
+ task["task_type"] = ty
485
  assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status."
486
 
487
 
 
519
  ParserType.AUDIO.value: audio,
520
  ParserType.EMAIL.value: email
521
  }
522
+ parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"}
523
  exe = ThreadPoolExecutor(max_workers=12)
524
  threads = []
525
  doc_nm = {}
 
622
  DocumentService.increment_chunk_num(
623
  doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
624
 
625
+ return [d["id"] for d, _ in files]
api/db/services/file_service.py CHANGED
@@ -401,7 +401,7 @@ class FileService(CommonService):
401
  ParserType.AUDIO.value: audio,
402
  ParserType.EMAIL.value: email
403
  }
404
- parser_config = {"chunk_token_num": 16096, "delimiter": "\n!?;。;!?", "layout_recognize": False}
405
  exe = ThreadPoolExecutor(max_workers=12)
406
  threads = []
407
  for file in file_objs:
 
401
  ParserType.AUDIO.value: audio,
402
  ParserType.EMAIL.value: email
403
  }
404
+ parser_config = {"chunk_token_num": 16096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"}
405
  exe = ThreadPoolExecutor(max_workers=12)
406
  threads = []
407
  for file in file_objs:
api/db/services/task_service.py CHANGED
@@ -16,7 +16,6 @@
16
  import os
17
  import random
18
  import xxhash
19
- import bisect
20
  from datetime import datetime
21
 
22
  from api.db.db_utils import bulk_insert_into_db
@@ -183,7 +182,7 @@ class TaskService(CommonService):
183
  if os.environ.get("MACOS"):
184
  if info["progress_msg"]:
185
  task = cls.model.get_by_id(id)
186
- progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 1000)
187
  cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
188
  if "progress" in info:
189
  cls.model.update(progress=info["progress"]).where(
@@ -194,7 +193,7 @@ class TaskService(CommonService):
194
  with DB.lock("update_progress", -1):
195
  if info["progress_msg"]:
196
  task = cls.model.get_by_id(id)
197
- progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 1000)
198
  cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
199
  if "progress" in info:
200
  cls.model.update(progress=info["progress"]).where(
@@ -210,12 +209,12 @@ def queue_tasks(doc: dict, bucket: str, name: str):
210
 
211
  if doc["type"] == FileType.PDF.value:
212
  file_bin = STORAGE_IMPL.get(bucket, name)
213
- do_layout = doc["parser_config"].get("layout_recognize", True)
214
  pages = PdfParser.total_page_number(doc["name"], file_bin)
215
  page_size = doc["parser_config"].get("task_page_size", 12)
216
  if doc["parser_id"] == "paper":
217
  page_size = doc["parser_config"].get("task_page_size", 22)
218
- if doc["parser_id"] in ["one", "knowledge_graph"] or not do_layout:
219
  page_size = 10 ** 9
220
  page_ranges = doc["parser_config"].get("pages") or [(1, 10 ** 5)]
221
  for s, e in page_ranges:
@@ -243,6 +242,10 @@ def queue_tasks(doc: dict, bucket: str, name: str):
243
  for task in parse_task_array:
244
  hasher = xxhash.xxh64()
245
  for field in sorted(chunking_config.keys()):
 
 
 
 
246
  hasher.update(str(chunking_config[field]).encode("utf-8"))
247
  for field in ["doc_id", "from_page", "to_page"]:
248
  hasher.update(str(task.get(field, "")).encode("utf-8"))
@@ -276,20 +279,27 @@ def queue_tasks(doc: dict, bucket: str, name: str):
276
 
277
 
278
  def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config: dict):
279
- idx = bisect.bisect_left(prev_tasks, (task.get("from_page", 0), task.get("digest", "")),
280
- key=lambda x: (x.get("from_page", 0), x.get("digest", "")))
 
 
281
  if idx >= len(prev_tasks):
282
  return 0
283
  prev_task = prev_tasks[idx]
284
- if prev_task["progress"] < 1.0 or prev_task["digest"] != task["digest"] or not prev_task["chunk_ids"]:
285
  return 0
286
  task["chunk_ids"] = prev_task["chunk_ids"]
287
  task["progress"] = 1.0
288
- if "from_page" in task and "to_page" in task:
289
  task["progress_msg"] = f"Page({task['from_page']}~{task['to_page']}): "
290
  else:
291
  task["progress_msg"] = ""
292
- task["progress_msg"] += "reused previous task's chunks."
 
293
  prev_task["chunk_ids"] = ""
294
 
295
  return len(task["chunk_ids"].split())
 
16
  import os
17
  import random
18
  import xxhash
 
19
  from datetime import datetime
20
 
21
  from api.db.db_utils import bulk_insert_into_db
 
182
  if os.environ.get("MACOS"):
183
  if info["progress_msg"]:
184
  task = cls.model.get_by_id(id)
185
+ progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000)
186
  cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
187
  if "progress" in info:
188
  cls.model.update(progress=info["progress"]).where(
 
193
  with DB.lock("update_progress", -1):
194
  if info["progress_msg"]:
195
  task = cls.model.get_by_id(id)
196
+ progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000)
197
  cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
198
  if "progress" in info:
199
  cls.model.update(progress=info["progress"]).where(
 
209
 
210
  if doc["type"] == FileType.PDF.value:
211
  file_bin = STORAGE_IMPL.get(bucket, name)
212
+ do_layout = doc["parser_config"].get("layout_recognize", "DeepDOC")
213
  pages = PdfParser.total_page_number(doc["name"], file_bin)
214
  page_size = doc["parser_config"].get("task_page_size", 12)
215
  if doc["parser_id"] == "paper":
216
  page_size = doc["parser_config"].get("task_page_size", 22)
217
+ if doc["parser_id"] in ["one", "knowledge_graph"] or do_layout != "DeepDOC":
218
  page_size = 10 ** 9
219
  page_ranges = doc["parser_config"].get("pages") or [(1, 10 ** 5)]
220
  for s, e in page_ranges:
 
242
  for task in parse_task_array:
243
  hasher = xxhash.xxh64()
244
  for field in sorted(chunking_config.keys()):
245
+ if field == "parser_config":
246
+ for k in ["raptor", "graphrag"]:
247
+ if k in chunking_config[field]:
248
+ del chunking_config[field][k]
249
  hasher.update(str(chunking_config[field]).encode("utf-8"))
250
  for field in ["doc_id", "from_page", "to_page"]:
251
  hasher.update(str(task.get(field, "")).encode("utf-8"))
 
279
 
280
 
281
  def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config: dict):
282
+ idx = 0
283
+ while idx < len(prev_tasks):
284
+ prev_task = prev_tasks[idx]
285
+ if prev_task.get("from_page", 0) == task.get("from_page", 0) \
286
+ and prev_task.get("digest", 0) == task.get("digest", ""):
287
+ break
288
+ idx += 1
289
+
290
  if idx >= len(prev_tasks):
291
  return 0
292
  prev_task = prev_tasks[idx]
293
+ if prev_task["progress"] < 1.0 or not prev_task["chunk_ids"]:
294
  return 0
295
  task["chunk_ids"] = prev_task["chunk_ids"]
296
  task["progress"] = 1.0
297
+ if "from_page" in task and "to_page" in task and int(task['to_page']) - int(task['from_page']) >= 10 ** 6:
298
  task["progress_msg"] = f"Page({task['from_page']}~{task['to_page']}): "
299
  else:
300
  task["progress_msg"] = ""
301
+ task["progress_msg"] = " ".join(
302
+ [datetime.now().strftime("%H:%M:%S"), task["progress_msg"], "Reused previous task's chunks."])
303
  prev_task["chunk_ids"] = ""
304
 
305
  return len(task["chunk_ids"].split())
api/utils/api_utils.py CHANGED
@@ -355,7 +355,7 @@ def get_parser_config(chunk_method, parser_config):
355
  if not chunk_method:
356
  chunk_method = "naive"
357
  key_mapping = {
358
- "naive": {"chunk_token_num": 128, "delimiter": "\\n!?;。;!?", "html4excel": False, "layout_recognize": True,
359
  "raptor": {"use_raptor": False}},
360
  "qa": {"raptor": {"use_raptor": False}},
361
  "tag": None,
 
355
  if not chunk_method:
356
  chunk_method = "naive"
357
  key_mapping = {
358
+ "naive": {"chunk_token_num": 128, "delimiter": "\\n!?;。;!?", "html4excel": False, "layout_recognize": "DeepDOC",
359
  "raptor": {"use_raptor": False}},
360
  "qa": {"raptor": {"use_raptor": False}},
361
  "tag": None,
conf/infinity_mapping.json CHANGED
@@ -25,9 +25,19 @@
25
  "weight_int": {"type": "integer", "default": 0},
26
  "weight_flt": {"type": "float", "default": 0.0},
27
  "rank_int": {"type": "integer", "default": 0},
 
28
  "available_int": {"type": "integer", "default": 1},
29
  "knowledge_graph_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
30
  "entities_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
31
  "pagerank_fea": {"type": "integer", "default": 0},
32
- "tag_fea": {"type": "integer", "default": 0}
 
 
33
  }
 
25
  "weight_int": {"type": "integer", "default": 0},
26
  "weight_flt": {"type": "float", "default": 0.0},
27
  "rank_int": {"type": "integer", "default": 0},
28
+ "rank_flt": {"type": "float", "default": 0},
29
  "available_int": {"type": "integer", "default": 1},
30
  "knowledge_graph_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
31
  "entities_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
32
  "pagerank_fea": {"type": "integer", "default": 0},
33
+ "tag_feas": {"type": "integer", "default": 0},
34
+
35
+ "important_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
36
+ "from_entity_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
37
+ "to_entity_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
38
+ "entity_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
39
+ "entity_type_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
40
+ "source_id": {"type": "varchar", "default": ""},
41
+ "n_hop_with_weight": {"type": "varchar", "default": ""},
42
+ "removed_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"}
43
  }
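The extra mapping fields back the knowledge-graph records (entities, relations, the graph itself and community reports) that are now stored alongside ordinary chunks. An illustrative entity record using those fields; the values, and whether a given field holds a string or a list, are examples rather than a schema guarantee:

```python
# Illustrative "entity" record as it might be written to the doc store (values are examples).
entity_record = {
    "knowledge_graph_kwd": "entity",
    "entity_kwd": "RAGFlow",
    "entity_type_kwd": "organization",
    "rank_flt": 0.37,                      # e.g. a pagerank-style score
    "source_id": ["doc_id_1", "doc_id_2"], # documents the entity was extracted from
    "removed_kwd": "N",
    "content_with_weight": "RAGFlow is an open-source RAG engine ...",
}
```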
graphrag/description_summary.py DELETED
@@ -1,146 +0,0 @@
1
- # Copyright (c) 2024 Microsoft Corporation.
2
- # Licensed under the MIT License
3
- """
4
- Reference:
5
- - [graphrag](https://github.com/microsoft/graphrag)
6
- """
7
-
8
- import json
9
- from dataclasses import dataclass
10
-
11
- from graphrag.extractor import Extractor
12
- from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
13
- from rag.llm.chat_model import Base as CompletionLLM
14
-
15
- from rag.utils import num_tokens_from_string
16
-
17
- SUMMARIZE_PROMPT = """
18
- You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
19
- Given one or two entities, and a list of descriptions, all related to the same entity or group of entities.
20
- Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions.
21
- If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
22
- Make sure it is written in third person, and include the entity names so we the have full context.
23
-
24
- #######
25
- -Data-
26
- Entities: {entity_name}
27
- Description List: {description_list}
28
- #######
29
- Output:
30
- """
31
-
32
- # Max token size for input prompts
33
- DEFAULT_MAX_INPUT_TOKENS = 4_000
34
- # Max token count for LLM answers
35
- DEFAULT_MAX_SUMMARY_LENGTH = 128
36
-
37
-
38
- @dataclass
39
- class SummarizationResult:
40
- """Unipartite graph extraction result class definition."""
41
-
42
- items: str | tuple[str, str]
43
- description: str
44
-
45
-
46
- class SummarizeExtractor(Extractor):
47
- """Unipartite graph extractor class definition."""
48
-
49
- _entity_name_key: str
50
- _input_descriptions_key: str
51
- _summarization_prompt: str
52
- _on_error: ErrorHandlerFn
53
- _max_summary_length: int
54
- _max_input_tokens: int
55
-
56
- def __init__(
57
- self,
58
- llm_invoker: CompletionLLM,
59
- entity_name_key: str | None = None,
60
- input_descriptions_key: str | None = None,
61
- summarization_prompt: str | None = None,
62
- on_error: ErrorHandlerFn | None = None,
63
- max_summary_length: int | None = None,
64
- max_input_tokens: int | None = None,
65
- ):
66
- """Init method definition."""
67
- # TODO: streamline construction
68
- self._llm = llm_invoker
69
- self._entity_name_key = entity_name_key or "entity_name"
70
- self._input_descriptions_key = input_descriptions_key or "description_list"
71
-
72
- self._summarization_prompt = summarization_prompt or SUMMARIZE_PROMPT
73
- self._on_error = on_error or (lambda _e, _s, _d: None)
74
- self._max_summary_length = max_summary_length or DEFAULT_MAX_SUMMARY_LENGTH
75
- self._max_input_tokens = max_input_tokens or DEFAULT_MAX_INPUT_TOKENS
76
-
77
- def __call__(
78
- self,
79
- items: str | tuple[str, str],
80
- descriptions: list[str],
81
- ) -> SummarizationResult:
82
- """Call method definition."""
83
- result = ""
84
- if len(descriptions) == 0:
85
- result = ""
86
- if len(descriptions) == 1:
87
- result = descriptions[0]
88
- else:
89
- result = self._summarize_descriptions(items, descriptions)
90
-
91
- return SummarizationResult(
92
- items=items,
93
- description=result or "",
94
- )
95
-
96
- def _summarize_descriptions(
97
- self, items: str | tuple[str, str], descriptions: list[str]
98
- ) -> str:
99
- """Summarize descriptions into a single description."""
100
- sorted_items = sorted(items) if isinstance(items, list) else items
101
-
102
- # Safety check, should always be a list
103
- if not isinstance(descriptions, list):
104
- descriptions = [descriptions]
105
-
106
- # Iterate over descriptions, adding all until the max input tokens is reached
107
- usable_tokens = self._max_input_tokens - num_tokens_from_string(
108
- self._summarization_prompt
109
- )
110
- descriptions_collected = []
111
- result = ""
112
-
113
- for i, description in enumerate(descriptions):
114
- usable_tokens -= num_tokens_from_string(description)
115
- descriptions_collected.append(description)
116
-
117
- # If buffer is full, or all descriptions have been added, summarize
118
- if (usable_tokens < 0 and len(descriptions_collected) > 1) or (
119
- i == len(descriptions) - 1
120
- ):
121
- # Calculate result (final or partial)
122
- result = await self._summarize_descriptions_with_llm(
123
- sorted_items, descriptions_collected
124
- )
125
-
126
- # If we go for another loop, reset values to new
127
- if i != len(descriptions) - 1:
128
- descriptions_collected = [result]
129
- usable_tokens = (
130
- self._max_input_tokens
131
- - num_tokens_from_string(self._summarization_prompt)
132
- - num_tokens_from_string(result)
133
- )
134
-
135
- return result
136
-
137
- def _summarize_descriptions_with_llm(
138
- self, items: str | tuple[str, str] | list[str], descriptions: list[str]
139
- ):
140
- """Summarize descriptions using the LLM."""
141
- variables = {
142
- self._entity_name_key: json.dumps(items),
143
- self._input_descriptions_key: json.dumps(sorted(descriptions)),
144
- }
145
- text = perform_variable_replacements(self._summarization_prompt, variables=variables)
146
- return self._chat("", [{"role": "user", "content": text}])
 
 
graphrag/entity_resolution.py CHANGED
@@ -16,18 +16,18 @@
16
  import logging
17
  import itertools
18
  import re
19
- import traceback
20
  from dataclasses import dataclass
21
- from typing import Any
22
 
23
  import networkx as nx
24
 
25
- from graphrag.extractor import Extractor
26
  from rag.nlp import is_english
27
  import editdistance
28
  from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT
29
  from rag.llm.chat_model import Base as CompletionLLM
30
- from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
31
 
32
  DEFAULT_RECORD_DELIMITER = "##"
33
  DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
@@ -37,8 +37,8 @@ DEFAULT_RESOLUTION_RESULT_DELIMITER = "&&"
37
  @dataclass
38
  class EntityResolutionResult:
39
  """Entity resolution result class definition."""
40
-
41
- output: nx.Graph
42
 
43
 
44
  class EntityResolution(Extractor):
@@ -46,7 +46,6 @@ class EntityResolution(Extractor):
46
 
47
  _resolution_prompt: str
48
  _output_formatter_prompt: str
49
- _on_error: ErrorHandlerFn
50
  _record_delimiter_key: str
51
  _entity_index_delimiter_key: str
52
  _resolution_result_delimiter_key: str
@@ -54,21 +53,19 @@ class EntityResolution(Extractor):
54
  def __init__(
55
  self,
56
  llm_invoker: CompletionLLM,
57
- resolution_prompt: str | None = None,
58
- on_error: ErrorHandlerFn | None = None,
59
- record_delimiter_key: str | None = None,
60
- entity_index_delimiter_key: str | None = None,
61
- resolution_result_delimiter_key: str | None = None,
62
- input_text_key: str | None = None
63
  ):
 
64
  """Init method definition."""
65
  self._llm = llm_invoker
66
- self._resolution_prompt = resolution_prompt or ENTITY_RESOLUTION_PROMPT
67
- self._on_error = on_error or (lambda _e, _s, _d: None)
68
- self._record_delimiter_key = record_delimiter_key or "record_delimiter"
69
- self._entity_index_dilimiter_key = entity_index_delimiter_key or "entity_index_delimiter"
70
- self._resolution_result_delimiter_key = resolution_result_delimiter_key or "resolution_result_delimiter"
71
- self._input_text_key = input_text_key or "input_text"
72
 
73
  def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None) -> EntityResolutionResult:
74
  """Call method definition."""
@@ -87,11 +84,11 @@ class EntityResolution(Extractor):
87
  }
88
 
89
  nodes = graph.nodes
90
- entity_types = list(set(graph.nodes[node]['entity_type'] for node in nodes))
91
  node_clusters = {entity_type: [] for entity_type in entity_types}
92
 
93
  for node in nodes:
94
- node_clusters[graph.nodes[node]['entity_type']].append(node)
95
 
96
  candidate_resolution = {entity_type: [] for entity_type in entity_types}
97
  for k, v in node_clusters.items():
@@ -128,44 +125,51 @@ class EntityResolution(Extractor):
128
  DEFAULT_RESOLUTION_RESULT_DELIMITER))
129
  for result_i in result:
130
  resolution_result.add(candidate_resolution_i[1][result_i[0] - 1])
131
- except Exception as e:
132
  logging.exception("error entity resolution")
133
- self._on_error(e, traceback.format_exc(), None)
134
 
135
  connect_graph = nx.Graph()
 
136
  connect_graph.add_edges_from(resolution_result)
137
  for sub_connect_graph in nx.connected_components(connect_graph):
138
  sub_connect_graph = connect_graph.subgraph(sub_connect_graph)
139
  remove_nodes = list(sub_connect_graph.nodes)
140
  keep_node = remove_nodes.pop()
 
141
  for remove_node in remove_nodes:
 
142
  remove_node_neighbors = graph[remove_node]
143
- graph.nodes[keep_node]['description'] += graph.nodes[remove_node]['description']
144
- graph.nodes[keep_node]['weight'] += graph.nodes[remove_node]['weight']
145
  remove_node_neighbors = list(remove_node_neighbors)
146
  for remove_node_neighbor in remove_node_neighbors:
 
 
 
147
  if remove_node_neighbor == keep_node:
148
- graph.remove_edge(keep_node, remove_node)
 
 
 
149
  continue
150
  if graph.has_edge(keep_node, remove_node_neighbor):
151
- graph[keep_node][remove_node_neighbor]['weight'] += graph[remove_node][remove_node_neighbor][
152
- 'weight']
153
- graph[keep_node][remove_node_neighbor]['description'] += \
154
- graph[remove_node][remove_node_neighbor]['description']
155
- graph.remove_edge(remove_node, remove_node_neighbor)
156
  else:
157
- graph.add_edge(keep_node, remove_node_neighbor,
158
- weight=graph[remove_node][remove_node_neighbor]['weight'],
159
- description=graph[remove_node][remove_node_neighbor]['description'],
160
- source_id="")
161
- graph.remove_edge(remove_node, remove_node_neighbor)
 
 
 
 
 
 
 
162
  graph.remove_node(remove_node)
163
 
164
- for node_degree in graph.degree:
165
- graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
166
-
167
  return EntityResolutionResult(
168
- output=graph,
 
169
  )
170
 
171
  def _process_results(
 
16
  import logging
17
  import itertools
18
  import re
19
+ import time
20
  from dataclasses import dataclass
21
+ from typing import Any, Callable
22
 
23
  import networkx as nx
24
 
25
+ from graphrag.general.extractor import Extractor
26
  from rag.nlp import is_english
27
  import editdistance
28
  from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT
29
  from rag.llm.chat_model import Base as CompletionLLM
30
+ from graphrag.utils import perform_variable_replacements
31
 
32
  DEFAULT_RECORD_DELIMITER = "##"
33
  DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
 
37
  @dataclass
38
  class EntityResolutionResult:
39
  """Entity resolution result class definition."""
40
+ graph: nx.Graph
41
+ removed_entities: list
42
 
43
 
44
  class EntityResolution(Extractor):
 
46
 
47
  _resolution_prompt: str
48
  _output_formatter_prompt: str
 
49
  _record_delimiter_key: str
50
  _entity_index_delimiter_key: str
51
  _resolution_result_delimiter_key: str
 
53
  def __init__(
54
  self,
55
  llm_invoker: CompletionLLM,
56
+ get_entity: Callable | None = None,
57
+ set_entity: Callable | None = None,
58
+ get_relation: Callable | None = None,
59
+ set_relation: Callable | None = None
 
 
60
  ):
61
+ super().__init__(llm_invoker, get_entity=get_entity, set_entity=set_entity, get_relation=get_relation, set_relation=set_relation)
62
  """Init method definition."""
63
  self._llm = llm_invoker
64
+ self._resolution_prompt = ENTITY_RESOLUTION_PROMPT
65
+ self._record_delimiter_key = "record_delimiter"
66
+ self._entity_index_dilimiter_key = "entity_index_delimiter"
67
+ self._resolution_result_delimiter_key = "resolution_result_delimiter"
68
+ self._input_text_key = "input_text"
 
69
 
70
  def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None) -> EntityResolutionResult:
71
  """Call method definition."""
 
84
  }
85
 
86
  nodes = graph.nodes
87
+ entity_types = list(set(graph.nodes[node].get('entity_type', '-') for node in nodes))
88
  node_clusters = {entity_type: [] for entity_type in entity_types}
89
 
90
  for node in nodes:
91
+ node_clusters[graph.nodes[node].get('entity_type', '-')].append(node)
92
 
93
  candidate_resolution = {entity_type: [] for entity_type in entity_types}
94
  for k, v in node_clusters.items():
 
125
  DEFAULT_RESOLUTION_RESULT_DELIMITER))
126
  for result_i in result:
127
  resolution_result.add(candidate_resolution_i[1][result_i[0] - 1])
128
+ except Exception:
129
  logging.exception("error entity resolution")
 
130
 
131
  connect_graph = nx.Graph()
132
+ removed_entities = []
133
  connect_graph.add_edges_from(resolution_result)
134
  for sub_connect_graph in nx.connected_components(connect_graph):
135
  sub_connect_graph = connect_graph.subgraph(sub_connect_graph)
136
  remove_nodes = list(sub_connect_graph.nodes)
137
  keep_node = remove_nodes.pop()
138
+ self._merge_nodes(keep_node, self._get_entity_(remove_nodes))
139
  for remove_node in remove_nodes:
140
+ removed_entities.append(remove_node)
141
  remove_node_neighbors = graph[remove_node]
 
 
142
  remove_node_neighbors = list(remove_node_neighbors)
143
  for remove_node_neighbor in remove_node_neighbors:
144
+ rel = self._get_relation_(remove_node, remove_node_neighbor)
145
+ if graph.has_edge(remove_node, remove_node_neighbor):
146
+ graph.remove_edge(remove_node, remove_node_neighbor)
147
  if remove_node_neighbor == keep_node:
148
+ if graph.has_edge(keep_node, remove_node):
149
+ graph.remove_edge(keep_node, remove_node)
150
+ continue
151
+ if not rel:
152
  continue
153
  if graph.has_edge(keep_node, remove_node_neighbor):
154
+ self._merge_edges(keep_node, remove_node_neighbor, [rel])
 
 
 
 
155
  else:
156
+ pair = sorted([keep_node, remove_node_neighbor])
157
+ graph.add_edge(pair[0], pair[1], weight=rel['weight'])
158
+ self._set_relation_(pair[0], pair[1],
159
+ dict(
160
+ src_id=pair[0],
161
+ tgt_id=pair[1],
162
+ weight=rel['weight'],
163
+ description=rel['description'],
164
+ keywords=[],
165
+ source_id=rel.get("source_id", ""),
166
+ metadata={"created_at": time.time()}
167
+ ))
168
  graph.remove_node(remove_node)
169
 
 
 
 
170
  return EntityResolutionResult(
171
+ graph=graph,
172
+ removed_entities=removed_entities
173
  )
174
 
175
  def _process_results(
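The resolution pass asks the LLM which same-typed entities are duplicates, builds a graph from the confirmed pairs, and then, per connected component, keeps one node and folds the others' edges into it (entity and relation records are merged through the new accessor callbacks). A self-contained illustration of the merge pattern itself, using plain networkx attributes rather than the class's callbacks:

```python
import networkx as nx

def merge_duplicates(graph: nx.Graph, duplicate_pairs: list[tuple[str, str]]) -> list[str]:
    """Collapse each connected component of duplicate pairs onto a single kept node."""
    removed = []
    connect_graph = nx.Graph()
    connect_graph.add_edges_from(duplicate_pairs)
    for component in nx.connected_components(connect_graph):
        nodes = list(component)
        keep = nodes.pop()
        for dup in nodes:
            removed.append(dup)
            for neighbor in list(graph[dup]):
                if neighbor == keep:
                    continue
                w = graph[dup][neighbor].get("weight", 0)
                if graph.has_edge(keep, neighbor):
                    graph[keep][neighbor]["weight"] = graph[keep][neighbor].get("weight", 0) + w
                else:
                    graph.add_edge(keep, neighbor, weight=w)
            graph.remove_node(dup)
    return removed

# Example: "AI" and "A.I." were judged to be the same entity.
g = nx.Graph()
g.add_edge("AI", "ML", weight=2)
g.add_edge("A.I.", "NLP", weight=1)
print(merge_duplicates(g, [("AI", "A.I.")]))   # ['A.I.'] (or ['AI'], order-dependent)
```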
graphrag/extractor.py DELETED
@@ -1,34 +0,0 @@
1
- #
2
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- #
16
- from graphrag.utils import get_llm_cache, set_llm_cache
17
- from rag.llm.chat_model import Base as CompletionLLM
18
-
19
-
20
- class Extractor:
21
- _llm: CompletionLLM
22
-
23
- def __init__(self, llm_invoker: CompletionLLM):
24
- self._llm = llm_invoker
25
-
26
- def _chat(self, system, history, gen_conf):
27
- response = get_llm_cache(self._llm.llm_name, system, history, gen_conf)
28
- if response:
29
- return response
30
- response = self._llm.chat(system, history, gen_conf)
31
- if response.find("**ERROR**") >= 0:
32
- raise Exception(response)
33
- set_llm_cache(self._llm.llm_name, system, response, history, gen_conf)
34
- return response
 
 
graphrag/general/__init__.py ADDED
File without changes
graphrag/{claim_extractor.py → general/claim_extractor.py} RENAMED
@@ -15,8 +15,8 @@ from typing import Any
15
 
16
  import tiktoken
17
 
18
- from graphrag.claim_prompt import CLAIM_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
19
- from graphrag.extractor import Extractor
20
  from rag.llm.chat_model import Base as CompletionLLM
21
  from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
22
 
 
15
 
16
  import tiktoken
17
 
18
+ from graphrag.general.claim_prompt import CLAIM_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
19
+ from graphrag.general.extractor import Extractor
20
  from rag.llm.chat_model import Base as CompletionLLM
21
  from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
22
 
graphrag/{claim_prompt.py → general/claim_prompt.py} RENAMED
File without changes
graphrag/{community_report_prompt.py → general/community_report_prompt.py} RENAMED
File without changes
graphrag/{community_reports_extractor.py → general/community_reports_extractor.py} RENAMED
@@ -13,10 +13,10 @@ from typing import Callable
13
  from dataclasses import dataclass
14
  import networkx as nx
15
  import pandas as pd
16
- from graphrag import leiden
17
- from graphrag.community_report_prompt import COMMUNITY_REPORT_PROMPT
18
- from graphrag.extractor import Extractor
19
- from graphrag.leiden import add_community_info2graph
20
  from rag.llm.chat_model import Base as CompletionLLM
21
  from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types
22
  from rag.utils import num_tokens_from_string
@@ -40,32 +40,43 @@ class CommunityReportsExtractor(Extractor):
40
  _max_report_length: int
41
 
42
  def __init__(
43
- self,
44
- llm_invoker: CompletionLLM,
45
- extraction_prompt: str | None = None,
46
- on_error: ErrorHandlerFn | None = None,
47
- max_report_length: int | None = None,
 
 
48
  ):
 
49
  """Init method definition."""
50
  self._llm = llm_invoker
51
- self._extraction_prompt = extraction_prompt or COMMUNITY_REPORT_PROMPT
52
- self._on_error = on_error or (lambda _e, _s, _d: None)
53
  self._max_report_length = max_report_length or 1500
54
 
55
  def __call__(self, graph: nx.Graph, callback: Callable | None = None):
 
 
 
56
  communities: dict[str, dict[str, list]] = leiden.run(graph, {})
57
  total = sum([len(comm.items()) for _, comm in communities.items()])
58
- relations_df = pd.DataFrame([{"source":s, "target": t, **attr} for s, t, attr in graph.edges(data=True)])
59
  res_str = []
60
  res_dict = []
61
  over, token_count = 0, 0
62
  st = timer()
63
  for level, comm in communities.items():
 
64
  for cm_id, ents in comm.items():
65
  weight = ents["weight"]
66
  ents = ents["nodes"]
67
- ent_df = pd.DataFrame([{"entity": n, **graph.nodes[n]} for n in ents])
68
- rela_df = relations_df[(relations_df["source"].isin(ents)) | (relations_df["target"].isin(ents))].reset_index(drop=True)
 
 
 
 
 
 
69
 
70
  prompt_variables = {
71
  "entity_df": ent_df.to_csv(index_label="id"),
 
13
  from dataclasses import dataclass
14
  import networkx as nx
15
  import pandas as pd
16
+ from graphrag.general import leiden
17
+ from graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT
18
+ from graphrag.general.extractor import Extractor
19
+ from graphrag.general.leiden import add_community_info2graph
20
  from rag.llm.chat_model import Base as CompletionLLM
21
  from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types
22
  from rag.utils import num_tokens_from_string
 
40
  _max_report_length: int
41
 
42
  def __init__(
43
+ self,
44
+ llm_invoker: CompletionLLM,
45
+ get_entity: Callable | None = None,
46
+ set_entity: Callable | None = None,
47
+ get_relation: Callable | None = None,
48
+ set_relation: Callable | None = None,
49
+ max_report_length: int | None = None,
50
  ):
51
+ super().__init__(llm_invoker, get_entity=get_entity, set_entity=set_entity, get_relation=get_relation, set_relation=set_relation)
52
  """Init method definition."""
53
  self._llm = llm_invoker
54
+ self._extraction_prompt = COMMUNITY_REPORT_PROMPT
 
55
  self._max_report_length = max_report_length or 1500
56
 
57
  def __call__(self, graph: nx.Graph, callback: Callable | None = None):
58
+ for node_degree in graph.degree:
59
+ graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
60
+
61
  communities: dict[str, dict[str, list]] = leiden.run(graph, {})
62
  total = sum([len(comm.items()) for _, comm in communities.items()])
 
63
  res_str = []
64
  res_dict = []
65
  over, token_count = 0, 0
66
  st = timer()
67
  for level, comm in communities.items():
68
+ logging.info(f"Level {level}: Community: {len(comm.keys())}")
69
  for cm_id, ents in comm.items():
70
  weight = ents["weight"]
71
  ents = ents["nodes"]
72
+ ent_df = pd.DataFrame(self._get_entity_(ents)).dropna()#[{"entity": n, **graph.nodes[n]} for n in ents])
73
+ ent_df["entity"] = ent_df["entity_name"]
74
+ del ent_df["entity_name"]
75
+ rela_df = pd.DataFrame(self._get_relation_(list(ent_df["entity"]), list(ent_df["entity"]), 10000))
76
+ rela_df["source"] = rela_df["src_id"]
77
+ rela_df["target"] = rela_df["tgt_id"]
78
+ del rela_df["src_id"]
79
+ del rela_df["tgt_id"]
80
 
81
  prompt_variables = {
82
  "entity_df": ent_df.to_csv(index_label="id"),
graphrag/{entity_embedding.py → general/entity_embedding.py} RENAMED
@@ -9,7 +9,7 @@ from typing import Any
9
  import numpy as np
10
  import networkx as nx
11
  from dataclasses import dataclass
12
- from graphrag.leiden import stable_largest_connected_component
13
  import graspologic as gc
14
 
15
 
 
9
  import numpy as np
10
  import networkx as nx
11
  from dataclasses import dataclass
12
+ from graphrag.general.leiden import stable_largest_connected_component
13
  import graspologic as gc
14
 
15
 
graphrag/general/extractor.py ADDED
@@ -0,0 +1,245 @@
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ import logging
17
+ import os
18
+ from collections import defaultdict, Counter
19
+ from concurrent.futures import ThreadPoolExecutor
20
+ from copy import deepcopy
21
+ from typing import Callable
22
+
23
+ from graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT
24
+ from graphrag.utils import get_llm_cache, set_llm_cache, handle_single_entity_extraction, \
25
+ handle_single_relationship_extraction, split_string_by_multi_markers, flat_uniq_list
26
+ from rag.llm.chat_model import Base as CompletionLLM
27
+ from rag.utils import truncate
28
+
29
+ GRAPH_FIELD_SEP = "<SEP>"
30
+ DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event", "category"]
31
+ ENTITY_EXTRACTION_MAX_GLEANINGS = 2
32
+
33
+
34
+ class Extractor:
35
+ _llm: CompletionLLM
36
+
37
+ def __init__(
38
+ self,
39
+ llm_invoker: CompletionLLM,
40
+ language: str | None = "English",
41
+ entity_types: list[str] | None = None,
42
+ get_entity: Callable | None = None,
43
+ set_entity: Callable | None = None,
44
+ get_relation: Callable | None = None,
45
+ set_relation: Callable | None = None,
46
+ ):
47
+ self._llm = llm_invoker
48
+ self._language = language
49
+ self._entity_types = entity_types or DEFAULT_ENTITY_TYPES
50
+ self._get_entity_ = get_entity
51
+ self._set_entity_ = set_entity
52
+ self._get_relation_ = get_relation
53
+ self._set_relation_ = set_relation
54
+
55
+ def _chat(self, system, history, gen_conf):
56
+ hist = deepcopy(history)
57
+ conf = deepcopy(gen_conf)
58
+ response = get_llm_cache(self._llm.llm_name, system, hist, conf)
59
+ if response:
60
+ return response
61
+ response = self._llm.chat(system, hist, conf)
62
+ if response.find("**ERROR**") >= 0:
63
+ raise Exception(response)
64
+ set_llm_cache(self._llm.llm_name, system, response, history, gen_conf)
65
+ return response
66
+
67
+ def _entities_and_relations(self, chunk_key: str, records: list, tuple_delimiter: str):
68
+ maybe_nodes = defaultdict(list)
69
+ maybe_edges = defaultdict(list)
70
+ ent_types = [t.lower() for t in self._entity_types]
71
+ for record in records:
72
+ record_attributes = split_string_by_multi_markers(
73
+ record, [tuple_delimiter]
74
+ )
75
+
76
+ if_entities = handle_single_entity_extraction(
77
+ record_attributes, chunk_key
78
+ )
79
+ if if_entities is not None and if_entities.get("entity_type", "unknown").lower() in ent_types:
80
+ maybe_nodes[if_entities["entity_name"]].append(if_entities)
81
+ continue
82
+
83
+ if_relation = handle_single_relationship_extraction(
84
+ record_attributes, chunk_key
85
+ )
86
+ if if_relation is not None:
87
+ maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
88
+ if_relation
89
+ )
90
+ return dict(maybe_nodes), dict(maybe_edges)
91
+
92
+ def __call__(
93
+ self, chunks: list[tuple[str, str]],
94
+ callback: Callable | None = None
95
+ ):
96
+
97
+ results = []
98
+ max_workers = int(os.environ.get('GRAPH_EXTRACTOR_MAX_WORKERS', 50))
99
+ with ThreadPoolExecutor(max_workers=max_workers) as exe:
100
+ threads = []
101
+ for i, (cid, ck) in enumerate(chunks):
102
+ threads.append(
103
+ exe.submit(self._process_single_content, (cid, ck)))
104
+
105
+ for i, _ in enumerate(threads):
106
+ n, r, tc = _.result()
107
+ if not isinstance(n, Exception):
108
+ results.append((n, r))
109
+ if callback:
110
+ callback(0.5 + 0.1 * i / len(threads), f"Entities extraction progress ... {i + 1}/{len(threads)} ({tc} tokens)")
111
+ elif callback:
112
+ callback(msg="Knowledge graph extraction error:{}".format(str(n)))
113
+
114
+ maybe_nodes = defaultdict(list)
115
+ maybe_edges = defaultdict(list)
116
+ for m_nodes, m_edges in results:
117
+ for k, v in m_nodes.items():
118
+ maybe_nodes[k].extend(v)
119
+ for k, v in m_edges.items():
120
+ maybe_edges[tuple(sorted(k))].extend(v)
121
+ logging.info("Inserting entities into storage...")
122
+ all_entities_data = []
123
+ for en_nm, ents in maybe_nodes.items():
124
+ all_entities_data.append(self._merge_nodes(en_nm, ents))
125
+
126
+ logging.info("Inserting relationships into storage...")
127
+ all_relationships_data = []
128
+ for (src,tgt), rels in maybe_edges.items():
129
+ all_relationships_data.append(self._merge_edges(src, tgt, rels))
130
+
131
+ if not len(all_entities_data) and not len(all_relationships_data):
132
+ logging.warning(
133
+ "Didn't extract any entities and relationships, maybe your LLM is not working"
134
+ )
135
+
136
+ if not len(all_entities_data):
137
+ logging.warning("Didn't extract any entities")
138
+ if not len(all_relationships_data):
139
+ logging.warning("Didn't extract any relationships")
140
+
141
+ return all_entities_data, all_relationships_data
142
+
143
+ def _merge_nodes(self, entity_name: str, entities: list[dict]):
144
+ if not entities:
145
+ return
146
+ already_entity_types = []
147
+ already_source_ids = []
148
+ already_description = []
149
+
150
+ already_node = self._get_entity_(entity_name)
151
+ if already_node:
152
+ already_entity_types.append(already_node["entity_type"])
153
+ already_source_ids.extend(already_node["source_id"])
154
+ already_description.append(already_node["description"])
155
+
156
+ entity_type = sorted(
157
+ Counter(
158
+ [dp["entity_type"] for dp in entities] + already_entity_types
159
+ ).items(),
160
+ key=lambda x: x[1],
161
+ reverse=True,
162
+ )[0][0]
163
+ description = GRAPH_FIELD_SEP.join(
164
+ sorted(set([dp["description"] for dp in entities] + already_description))
165
+ )
166
+ already_source_ids = flat_uniq_list(entities, "source_id")
167
+ description = self._handle_entity_relation_summary(
168
+ entity_name, description
169
+ )
170
+ node_data = dict(
171
+ entity_type=entity_type,
172
+ description=description,
173
+ source_id=already_source_ids,
174
+ )
175
+ node_data["entity_name"] = entity_name
176
+ self._set_entity_(entity_name, node_data)
177
+ return node_data
178
+
179
+ def _merge_edges(
180
+ self,
181
+ src_id: str,
182
+ tgt_id: str,
183
+ edges_data: list[dict]
184
+ ):
185
+ if not edges_data:
186
+ return
187
+ already_weights = []
188
+ already_source_ids = []
189
+ already_description = []
190
+ already_keywords = []
191
+
192
+ relation = self._get_relation_(src_id, tgt_id)
193
+ if relation:
194
+ already_weights = [relation["weight"]]
195
+ already_source_ids = relation["source_id"]
196
+ already_description = [relation["description"]]
197
+ already_keywords = relation["keywords"]
198
+
199
+ weight = sum([dp["weight"] for dp in edges_data] + already_weights)
200
+ description = GRAPH_FIELD_SEP.join(
201
+ sorted(set([dp["description"] for dp in edges_data] + already_description))
202
+ )
203
+ keywords = flat_uniq_list(edges_data, "keywords") + already_keywords
204
+ source_id = flat_uniq_list(edges_data, "source_id") + already_source_ids
205
+
206
+ for need_insert_id in [src_id, tgt_id]:
207
+ if self._get_entity_(need_insert_id):
208
+ continue
209
+ self._set_entity_(need_insert_id, {
210
+ "source_id": source_id,
211
+ "description": description,
212
+ "entity_type": 'UNKNOWN'
213
+ })
214
+ description = self._handle_entity_relation_summary(
215
+ f"({src_id}, {tgt_id})", description
216
+ )
217
+ edge_data = dict(
218
+ src_id=src_id,
219
+ tgt_id=tgt_id,
220
+ description=description,
221
+ keywords=keywords,
222
+ weight=weight,
223
+ source_id=source_id
224
+ )
225
+ self._set_relation_(src_id, tgt_id, edge_data)
226
+
227
+ return edge_data
228
+
229
+ def _handle_entity_relation_summary(
230
+ self,
231
+ entity_or_relation_name: str,
232
+ description: str
233
+ ) -> str:
234
+ summary_max_tokens = 512
235
+ use_description = truncate(description, summary_max_tokens)
236
+ prompt_template = SUMMARIZE_DESCRIPTIONS_PROMPT
237
+ context_base = dict(
238
+ entity_name=entity_or_relation_name,
239
+ description_list=use_description.split(GRAPH_FIELD_SEP),
240
+ language=self._language,
241
+ )
242
+ use_prompt = prompt_template.format(**context_base)
243
+ logging.info(f"Trigger summary: {entity_or_relation_name}")
244
+ summary = self._chat(use_prompt, [{"role": "assistant", "content": "Output: "}], {"temperature": 0.8})
245
+ return summary
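Extractor.__call__ drives a thread pool over _process_single_content, which each subclass supplies. A minimal sketch of such a subclass (hypothetical; it skips the LLM and assumes the chunk text already contains delimited records):

    from graphrag.general.extractor import Extractor

    class PreParsedExtractor(Extractor):
        """Toy subclass: each chunk is assumed to already hold '##'-delimited records."""

        def _process_single_content(self, chunk_key_dp: tuple[str, str]):
            chunk_key, content = chunk_key_dp
            records = [r for r in content.split("##") if r.strip()]
            maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, "<|>")
            return maybe_nodes, maybe_edges, 0  # no LLM call, so no token count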
graphrag/general/graph_extractor.py ADDED
@@ -0,0 +1,154 @@
1
+ # Copyright (c) 2024 Microsoft Corporation.
2
+ # Licensed under the MIT License
3
+ """
4
+ Reference:
5
+ - [graphrag](https://github.com/microsoft/graphrag)
6
+ """
7
+
8
+ import logging
9
+ import re
10
+ from typing import Any, Callable
11
+ from dataclasses import dataclass
12
+ import tiktoken
13
+
14
+ from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS, DEFAULT_ENTITY_TYPES
15
+ from graphrag.general.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
16
+ from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
17
+ from rag.llm.chat_model import Base as CompletionLLM
18
+ import networkx as nx
19
+ from rag.utils import num_tokens_from_string
20
+
21
+ DEFAULT_TUPLE_DELIMITER = "<|>"
22
+ DEFAULT_RECORD_DELIMITER = "##"
23
+ DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
24
+
25
+
26
+ @dataclass
27
+ class GraphExtractionResult:
28
+ """Unipartite graph extraction result class definition."""
29
+
30
+ output: nx.Graph
31
+ source_docs: dict[Any, Any]
32
+
33
+
34
+ class GraphExtractor(Extractor):
35
+ """Unipartite graph extractor class definition."""
36
+
37
+ _join_descriptions: bool
38
+ _tuple_delimiter_key: str
39
+ _record_delimiter_key: str
40
+ _entity_types_key: str
41
+ _input_text_key: str
42
+ _completion_delimiter_key: str
43
+ _entity_name_key: str
44
+ _input_descriptions_key: str
45
+ _extraction_prompt: str
46
+ _summarization_prompt: str
47
+ _loop_args: dict[str, Any]
48
+ _max_gleanings: int
49
+ _on_error: ErrorHandlerFn
50
+
51
+ def __init__(
52
+ self,
53
+ llm_invoker: CompletionLLM,
54
+ language: str | None = "English",
55
+ entity_types: list[str] | None = None,
56
+ get_entity: Callable | None = None,
57
+ set_entity: Callable | None = None,
58
+ get_relation: Callable | None = None,
59
+ set_relation: Callable | None = None,
60
+ tuple_delimiter_key: str | None = None,
61
+ record_delimiter_key: str | None = None,
62
+ input_text_key: str | None = None,
63
+ entity_types_key: str | None = None,
64
+ completion_delimiter_key: str | None = None,
65
+ join_descriptions=True,
66
+ max_gleanings: int | None = None,
67
+ on_error: ErrorHandlerFn | None = None,
68
+ ):
69
+ super().__init__(llm_invoker, language, entity_types, get_entity, set_entity, get_relation, set_relation)
70
+ """Init method definition."""
71
+ # TODO: streamline construction
72
+ self._llm = llm_invoker
73
+ self._join_descriptions = join_descriptions
74
+ self._input_text_key = input_text_key or "input_text"
75
+ self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
76
+ self._record_delimiter_key = record_delimiter_key or "record_delimiter"
77
+ self._completion_delimiter_key = (
78
+ completion_delimiter_key or "completion_delimiter"
79
+ )
80
+ self._entity_types_key = entity_types_key or "entity_types"
81
+ self._extraction_prompt = GRAPH_EXTRACTION_PROMPT
82
+ self._max_gleanings = (
83
+ max_gleanings
84
+ if max_gleanings is not None
85
+ else ENTITY_EXTRACTION_MAX_GLEANINGS
86
+ )
87
+ self._on_error = on_error or (lambda _e, _s, _d: None)
88
+ self.prompt_token_count = num_tokens_from_string(self._extraction_prompt)
89
+
90
+ # Construct the looping arguments
91
+ encoding = tiktoken.get_encoding("cl100k_base")
92
+ yes = encoding.encode("YES")
93
+ no = encoding.encode("NO")
94
+ self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}
95
+
96
+ # Wire defaults into the prompt variables
97
+ self._prompt_variables = {
98
+ "entity_types": entity_types,
99
+ self._tuple_delimiter_key: DEFAULT_TUPLE_DELIMITER,
100
+ self._record_delimiter_key: DEFAULT_RECORD_DELIMITER,
101
+ self._completion_delimiter_key: DEFAULT_COMPLETION_DELIMITER,
102
+ self._entity_types_key: ",".join(DEFAULT_ENTITY_TYPES),
103
+ }
104
+
105
+ def _process_single_content(self,
106
+ chunk_key_dp: tuple[str, str]
107
+ ):
108
+ token_count = 0
109
+
110
+ chunk_key = chunk_key_dp[0]
111
+ content = chunk_key_dp[1]
112
+ variables = {
113
+ **self._prompt_variables,
114
+ self._input_text_key: content,
115
+ }
116
+ try:
117
+ gen_conf = {"temperature": 0.3}
118
+ hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
119
+ response = self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf)
120
+ token_count += num_tokens_from_string(hint_prompt + response)
121
+
122
+ results = response or ""
123
+ history = [{"role": "system", "content": hint_prompt}, {"role": "assistant", "content": response}]
124
+
125
+ # Repeat to ensure we maximize entity count
126
+ for i in range(self._max_gleanings):
127
+ text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
128
+ history.append({"role": "user", "content": text})
129
+ response = self._chat("", history, gen_conf)
130
+ token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
131
+ results += response or ""
132
+
133
+ # if this is the final glean, don't bother updating the continuation flag
134
+ if i >= self._max_gleanings - 1:
135
+ break
136
+ history.append({"role": "assistant", "content": response})
137
+ history.append({"role": "user", "content": LOOP_PROMPT})
138
+ continuation = self._chat("", history, self._loop_args)
139
+ token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
140
+ if continuation != "YES":
141
+ break
142
+
143
+ record_delimiter = variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER)
144
+ tuple_delimiter = variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER)
145
+ records = [re.sub(r"^\(|\)$", "", r.strip()) for r in results.split(record_delimiter)]
146
+ records = [r for r in records if r.strip()]
147
+ maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, tuple_delimiter)
148
+ return maybe_nodes, maybe_edges, token_count
149
+ except Exception as e:
150
+ logging.exception("error extracting graph")
151
+ return e, None, None
152
+
153
+
154
+
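A sketch of how this extractor is wired up. The get/set callbacks are normally the storage-backed helpers from graphrag.utils bound with functools.partial (as graphrag/general/index.py below shows); the dict-backed versions here are illustrative only, and llm_bdl stands for any chat-model bundle with the CompletionLLM interface:

    store_entities, store_relations = {}, {}
    ext = GraphExtractor(
        llm_bdl,  # assumed chat-model bundle
        language="English",
        entity_types=["person", "organization"],
        get_entity=lambda name: store_entities.get(name),
        set_entity=lambda name, data: store_entities.__setitem__(name, data),
        get_relation=lambda s, t: store_relations.get((s, t)),
        set_relation=lambda s, t, data: store_relations.__setitem__((s, t), data),
    )
    entities, relations = ext([("chunk-1", "Alice joined ACME in 2021 ...")])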
graphrag/{graph_prompt.py → general/graph_prompt.py} RENAMED
@@ -106,4 +106,19 @@ Text: {input_text}
106
  Output:"""
107
 
108
  CONTINUE_PROMPT = "MANY entities were missed in the last extraction. Add them below using the same format:\n"
109
- LOOP_PROMPT = "It appears some entities may have still been missed. Answer YES | NO if there are still entities that need to be added.\n"
106
  Output:"""
107
 
108
  CONTINUE_PROMPT = "MANY entities were missed in the last extraction. Add them below using the same format:\n"
109
+ LOOP_PROMPT = "It appears some entities may have still been missed. Answer YES | NO if there are still entities that need to be added.\n"
110
+
111
+ SUMMARIZE_DESCRIPTIONS_PROMPT = """
112
+ You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
113
+ Given one or two entities, and a list of descriptions, all related to the same entity or group of entities.
114
+ Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions.
115
+ If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
116
+ Make sure it is written in third person, and include the entity names so we have the full context.
117
+ Use {language} as output language.
118
+
119
+ #######
120
+ -Data-
121
+ Entities: {entity_name}
122
+ Description List: {description_list}
123
+ #######
124
+ """
graphrag/general/index.py ADDED
@@ -0,0 +1,197 @@
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ import json
17
+ import logging
18
+ from functools import reduce, partial
19
+ import networkx as nx
20
+
21
+ from api import settings
22
+ from graphrag.general.community_reports_extractor import CommunityReportsExtractor
23
+ from graphrag.entity_resolution import EntityResolution
24
+ from graphrag.general.extractor import Extractor
25
+ from graphrag.general.graph_extractor import DEFAULT_ENTITY_TYPES
26
+ from graphrag.utils import graph_merge, set_entity, get_relation, set_relation, get_entity, get_graph, set_graph, \
27
+ chunk_id, update_nodes_pagerank_nhop_neighbour
28
+ from rag.nlp import rag_tokenizer, search
29
+ from rag.utils.redis_conn import RedisDistributedLock
30
+
31
+
32
+ class Dealer:
33
+ def __init__(self,
34
+ extractor: Extractor,
35
+ tenant_id: str,
36
+ kb_id: str,
37
+ llm_bdl,
38
+ chunks: list[tuple[str, str]],
39
+ language,
40
+ entity_types=DEFAULT_ENTITY_TYPES,
41
+ embed_bdl=None,
42
+ callback=None
43
+ ):
44
+ docids = list(set([docid for docid,_ in chunks]))
45
+ self.llm_bdl = llm_bdl
46
+ self.embed_bdl = embed_bdl
47
+ ext = extractor(self.llm_bdl, language=language,
48
+ entity_types=entity_types,
49
+ get_entity=partial(get_entity, tenant_id, kb_id),
50
+ set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl),
51
+ get_relation=partial(get_relation, tenant_id, kb_id),
52
+ set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl)
53
+ )
54
+ ents, rels = ext(chunks, callback)
55
+ self.graph = nx.Graph()
56
+ for en in ents:
57
+ self.graph.add_node(en["entity_name"], entity_type=en["entity_type"])#, description=en["description"])
58
+
59
+ for rel in rels:
60
+ self.graph.add_edge(
61
+ rel["src_id"],
62
+ rel["tgt_id"],
63
+ weight=rel["weight"],
64
+ #description=rel["description"]
65
+ )
66
+
67
+ with RedisDistributedLock(kb_id, 60*60):
68
+ old_graph, old_doc_ids = get_graph(tenant_id, kb_id)
69
+ if old_graph is not None:
70
+ logging.info("Merge with an exiting graph...................")
71
+ self.graph = reduce(graph_merge, [old_graph, self.graph])
72
+ update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, self.graph, 2)
73
+ if old_doc_ids:
74
+ docids.extend(old_doc_ids)
75
+ docids = list(set(docids))
76
+ set_graph(tenant_id, kb_id, self.graph, docids)
77
+
78
+
79
+ class WithResolution(Dealer):
80
+ def __init__(self,
81
+ tenant_id: str,
82
+ kb_id: str,
83
+ llm_bdl,
84
+ embed_bdl=None,
85
+ callback=None
86
+ ):
87
+ self.llm_bdl = llm_bdl
88
+ self.embed_bdl = embed_bdl
89
+
90
+ with RedisDistributedLock(kb_id, 60*60):
91
+ self.graph, doc_ids = get_graph(tenant_id, kb_id)
92
+ if not self.graph:
93
+ logging.error(f"Faild to fetch the graph. tenant_id:{kb_id}, kb_id:{kb_id}")
94
+ if callback:
95
+ callback(-1, msg="Faild to fetch the graph.")
96
+ return
97
+
98
+ if callback:
99
+ callback(msg="Fetch the existing graph.")
100
+ er = EntityResolution(self.llm_bdl,
101
+ get_entity=partial(get_entity, tenant_id, kb_id),
102
+ set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl),
103
+ get_relation=partial(get_relation, tenant_id, kb_id),
104
+ set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl))
105
+ reso = er(self.graph)
106
+ self.graph = reso.graph
107
+ logging.info("Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
108
+ if callback:
109
+ callback(msg="Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
110
+ update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, self.graph, 2)
111
+ set_graph(tenant_id, kb_id, self.graph, doc_ids)
112
+
113
+ settings.docStoreConn.delete({
114
+ "knowledge_graph_kwd": "relation",
115
+ "kb_id": kb_id,
116
+ "from_entity_kwd": reso.removed_entities
117
+ }, search.index_name(tenant_id), kb_id)
118
+ settings.docStoreConn.delete({
119
+ "knowledge_graph_kwd": "relation",
120
+ "kb_id": kb_id,
121
+ "to_entity_kwd": reso.removed_entities
122
+ }, search.index_name(tenant_id), kb_id)
123
+ settings.docStoreConn.delete({
124
+ "knowledge_graph_kwd": "entity",
125
+ "kb_id": kb_id,
126
+ "entity_kwd": reso.removed_entities
127
+ }, search.index_name(tenant_id), kb_id)
128
+
129
+
130
+ class WithCommunity(Dealer):
131
+ def __init__(self,
132
+ tenant_id: str,
133
+ kb_id: str,
134
+ llm_bdl,
135
+ embed_bdl=None,
136
+ callback=None
137
+ ):
138
+
139
+ self.community_structure = None
140
+ self.community_reports = None
141
+ self.llm_bdl = llm_bdl
142
+ self.embed_bdl = embed_bdl
143
+
144
+ with RedisDistributedLock(kb_id, 60*60):
145
+ self.graph, doc_ids = get_graph(tenant_id, kb_id)
146
+ if not self.graph:
147
+ logging.error(f"Faild to fetch the graph. tenant_id:{kb_id}, kb_id:{kb_id}")
148
+ if callback:
149
+ callback(-1, msg="Faild to fetch the graph.")
150
+ return
151
+ if callback:
152
+ callback(msg="Fetch the existing graph.")
153
+
154
+ cr = CommunityReportsExtractor(self.llm_bdl,
155
+ get_entity=partial(get_entity, tenant_id, kb_id),
156
+ set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl),
157
+ get_relation=partial(get_relation, tenant_id, kb_id),
158
+ set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl))
159
+ cr = cr(self.graph, callback=callback)
160
+ self.community_structure = cr.structured_output
161
+ self.community_reports = cr.output
162
+ set_graph(tenant_id, kb_id, self.graph, doc_ids)
163
+
164
+ if callback:
165
+ callback(msg="Graph community extraction is done. Indexing {} reports.".format(len(cr.structured_output)))
166
+
167
+ settings.docStoreConn.delete({
168
+ "knowledge_graph_kwd": "community_report",
169
+ "kb_id": kb_id
170
+ }, search.index_name(tenant_id), kb_id)
171
+
172
+ for stru, rep in zip(self.community_structure, self.community_reports):
173
+ obj = {
174
+ "report": rep,
175
+ "evidences": "\n".join([f["explanation"] for f in stru["findings"]])
176
+ }
177
+ chunk = {
178
+ "docnm_kwd": stru["title"],
179
+ "title_tks": rag_tokenizer.tokenize(stru["title"]),
180
+ "content_with_weight": json.dumps(obj, ensure_ascii=False),
181
+ "content_ltks": rag_tokenizer.tokenize(obj["report"] +" "+ obj["evidences"]),
182
+ "knowledge_graph_kwd": "community_report",
183
+ "weight_flt": stru["weight"],
184
+ "entities_kwd": stru["entities"],
185
+ "important_kwd": stru["entities"],
186
+ "kb_id": kb_id,
187
+ "source_id": doc_ids,
188
+ "available_int": 0
189
+ }
190
+ chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
191
+ #try:
192
+ # ebd, _ = self.embed_bdl.encode([", ".join(community["entities"])])
193
+ # chunk["q_%d_vec" % len(ebd[0])] = ebd[0]
194
+ #except Exception as e:
195
+ # logging.exception(f"Fail to embed entity relation: {e}")
196
+ settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id))
197
+
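The Dealer entry point accepts an extractor class plus (doc_id, text) chunk pairs; a usage sketch (identifiers are placeholders, and either the general or the light GraphExtractor fits the extractor argument):

    chunks = [("doc-1", "First chunk of text ..."), ("doc-1", "Second chunk ...")]
    dealer = Dealer(GraphExtractor, tenant_id, kb_id, llm_bdl, chunks, "English", embed_bdl=embed_bdl)
    graph = dealer.graph  # merged networkx.Graph, persisted through set_graph() under the Redis lock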
graphrag/{leiden.py → general/leiden.py} RENAMED
@@ -10,7 +10,6 @@ import html
10
  from typing import Any, cast
11
  from graspologic.partition import hierarchical_leiden
12
  from graspologic.utils import largest_connected_component
13
-
14
  import networkx as nx
15
  from networkx import is_empty
16
 
@@ -130,6 +129,8 @@ def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]:
130
  if not weights:
131
  continue
132
  max_weight = max(weights)
 
 
133
  for _, comm in result.items():
134
  comm["weight"] /= max_weight
135
 
 
10
  from typing import Any, cast
11
  from graspologic.partition import hierarchical_leiden
12
  from graspologic.utils import largest_connected_component
 
13
  import networkx as nx
14
  from networkx import is_empty
15
 
 
129
  if not weights:
130
  continue
131
  max_weight = max(weights)
132
+ if max_weight == 0:
133
+ continue
134
  for _, comm in result.items():
135
  comm["weight"] /= max_weight
136
 
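The new guard covers the case where no community accumulated any weight at this level; without it the normalization would divide by zero. A standalone illustration with hypothetical values:

    result = {"0": {"weight": 0.0, "nodes": ["A"]}, "1": {"weight": 0.0, "nodes": ["B"]}}
    max_weight = max(c["weight"] for c in result.values())  # 0.0
    if max_weight == 0:
        pass  # skip normalization; previously this fell through to `weight /= 0.0`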
graphrag/{mind_map_extractor.py → general/mind_map_extractor.py} RENAMED
@@ -23,8 +23,8 @@ from typing import Any
23
  from concurrent.futures import ThreadPoolExecutor
24
  from dataclasses import dataclass
25
 
26
- from graphrag.extractor import Extractor
27
- from graphrag.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
28
  from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
29
  from rag.llm.chat_model import Base as CompletionLLM
30
  import markdown_to_json
 
23
  from concurrent.futures import ThreadPoolExecutor
24
  from dataclasses import dataclass
25
 
26
+ from graphrag.general.extractor import Extractor
27
+ from graphrag.general.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
28
  from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
29
  from rag.llm.chat_model import Base as CompletionLLM
30
  import markdown_to_json
graphrag/{mind_map_prompt.py → general/mind_map_prompt.py} RENAMED
File without changes
graphrag/general/smoke.py ADDED
@@ -0,0 +1,63 @@
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+
17
+ import argparse
18
+ import json
19
+
20
+ import networkx as nx
21
+
22
+ from api import settings
23
+ from api.db import LLMType
24
+ from api.db.services.document_service import DocumentService
25
+ from api.db.services.knowledgebase_service import KnowledgebaseService
26
+ from api.db.services.llm_service import LLMBundle
27
+ from api.db.services.user_service import TenantService
28
+ from graphrag.general.index import WithCommunity, Dealer, WithResolution
29
+ from graphrag.light.graph_extractor import GraphExtractor
30
+ from rag.utils.redis_conn import RedisDistributedLock
31
+
32
+ settings.init_settings()
33
+
34
+ if __name__ == "__main__":
35
+ parser = argparse.ArgumentParser()
36
+ parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
37
+ parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
38
+ args = parser.parse_args()
39
+ e, doc = DocumentService.get_by_id(args.doc_id)
40
+ if not e:
41
+ raise LookupError("Document not found.")
42
+ kb_id = doc.kb_id
43
+
44
+ chunks = [d["content_with_weight"] for d in
45
+ settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, [kb_id], max_count=6,
46
+ fields=["content_with_weight"])]
47
+ chunks = [("x", c) for c in chunks]
48
+
49
+ RedisDistributedLock.clean_lock(kb_id)
50
+
51
+ _, tenant = TenantService.get_by_id(args.tenant_id)
52
+ llm_bdl = LLMBundle(args.tenant_id, LLMType.CHAT, tenant.llm_id)
53
+ _, kb = KnowledgebaseService.get_by_id(kb_id)
54
+ embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)
55
+
56
+ dealer = Dealer(GraphExtractor, args.tenant_id, kb_id, llm_bdl, chunks, "English", embed_bdl=embed_bdl)
57
+ print(json.dumps(nx.node_link_data(dealer.graph), ensure_ascii=False, indent=2))
58
+
59
+ dealer = WithResolution(args.tenant_id, kb_id, llm_bdl, embed_bdl)
60
+ dealer = WithCommunity(args.tenant_id, kb_id, llm_bdl, embed_bdl)
61
+
62
+ print("------------------ COMMUNITY REPORT ----------------------\n", dealer.community_reports)
63
+ print(json.dumps(dealer.community_structure, ensure_ascii=False, indent=2))
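Assuming the service settings are configured and the tenant and document already exist, the smoke test is run roughly as (IDs are placeholders):

    python -m graphrag.general.smoke -t <tenant_id> -d <doc_id>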
graphrag/graph_extractor.py DELETED
@@ -1,322 +0,0 @@
1
- # Copyright (c) 2024 Microsoft Corporation.
2
- # Licensed under the MIT License
3
- """
4
- Reference:
5
- - [graphrag](https://github.com/microsoft/graphrag)
6
- """
7
-
8
- import logging
9
- import numbers
10
- import re
11
- import traceback
12
- from typing import Any, Callable, Mapping
13
- from dataclasses import dataclass
14
- import tiktoken
15
-
16
- from graphrag.extractor import Extractor
17
- from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
18
- from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str
19
- from rag.llm.chat_model import Base as CompletionLLM
20
- import networkx as nx
21
- from rag.utils import num_tokens_from_string
22
- from timeit import default_timer as timer
23
-
24
- DEFAULT_TUPLE_DELIMITER = "<|>"
25
- DEFAULT_RECORD_DELIMITER = "##"
26
- DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
27
- DEFAULT_ENTITY_TYPES = ["organization", "person", "location", "event", "time"]
28
- ENTITY_EXTRACTION_MAX_GLEANINGS = 1
29
-
30
-
31
- @dataclass
32
- class GraphExtractionResult:
33
- """Unipartite graph extraction result class definition."""
34
-
35
- output: nx.Graph
36
- source_docs: dict[Any, Any]
37
-
38
-
39
- class GraphExtractor(Extractor):
40
- """Unipartite graph extractor class definition."""
41
-
42
- _join_descriptions: bool
43
- _tuple_delimiter_key: str
44
- _record_delimiter_key: str
45
- _entity_types_key: str
46
- _input_text_key: str
47
- _completion_delimiter_key: str
48
- _entity_name_key: str
49
- _input_descriptions_key: str
50
- _extraction_prompt: str
51
- _summarization_prompt: str
52
- _loop_args: dict[str, Any]
53
- _max_gleanings: int
54
- _on_error: ErrorHandlerFn
55
-
56
- def __init__(
57
- self,
58
- llm_invoker: CompletionLLM,
59
- prompt: str | None = None,
60
- tuple_delimiter_key: str | None = None,
61
- record_delimiter_key: str | None = None,
62
- input_text_key: str | None = None,
63
- entity_types_key: str | None = None,
64
- completion_delimiter_key: str | None = None,
65
- join_descriptions=True,
66
- encoding_model: str | None = None,
67
- max_gleanings: int | None = None,
68
- on_error: ErrorHandlerFn | None = None,
69
- ):
70
- """Init method definition."""
71
- # TODO: streamline construction
72
- self._llm = llm_invoker
73
- self._join_descriptions = join_descriptions
74
- self._input_text_key = input_text_key or "input_text"
75
- self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
76
- self._record_delimiter_key = record_delimiter_key or "record_delimiter"
77
- self._completion_delimiter_key = (
78
- completion_delimiter_key or "completion_delimiter"
79
- )
80
- self._entity_types_key = entity_types_key or "entity_types"
81
- self._extraction_prompt = prompt or GRAPH_EXTRACTION_PROMPT
82
- self._max_gleanings = (
83
- max_gleanings
84
- if max_gleanings is not None
85
- else ENTITY_EXTRACTION_MAX_GLEANINGS
86
- )
87
- self._on_error = on_error or (lambda _e, _s, _d: None)
88
- self.prompt_token_count = num_tokens_from_string(self._extraction_prompt)
89
-
90
- # Construct the looping arguments
91
- encoding = tiktoken.get_encoding(encoding_model or "cl100k_base")
92
- yes = encoding.encode("YES")
93
- no = encoding.encode("NO")
94
- self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}
95
-
96
- def __call__(
97
- self, texts: list[str],
98
- prompt_variables: dict[str, Any] | None = None,
99
- callback: Callable | None = None
100
- ) -> GraphExtractionResult:
101
- """Call method definition."""
102
- if prompt_variables is None:
103
- prompt_variables = {}
104
- all_records: dict[int, str] = {}
105
- source_doc_map: dict[int, str] = {}
106
-
107
- # Wire defaults into the prompt variables
108
- prompt_variables = {
109
- **prompt_variables,
110
- self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key)
111
- or DEFAULT_TUPLE_DELIMITER,
112
- self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
113
- or DEFAULT_RECORD_DELIMITER,
114
- self._completion_delimiter_key: prompt_variables.get(
115
- self._completion_delimiter_key
116
- )
117
- or DEFAULT_COMPLETION_DELIMITER,
118
- self._entity_types_key: ",".join(
119
- prompt_variables.get(self._entity_types_key) or DEFAULT_ENTITY_TYPES
120
- ),
121
- }
122
-
123
- st = timer()
124
- total = len(texts)
125
- total_token_count = 0
126
- for doc_index, text in enumerate(texts):
127
- try:
128
- # Invoke the entity extraction
129
- result, token_count = self._process_document(text, prompt_variables)
130
- source_doc_map[doc_index] = text
131
- all_records[doc_index] = result
132
- total_token_count += token_count
133
- if callback:
134
- callback(msg=f"{doc_index+1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}")
135
- except Exception as e:
136
- if callback:
137
- callback(msg="Knowledge graph extraction error:{}".format(str(e)))
138
- logging.exception("error extracting graph")
139
- self._on_error(
140
- e,
141
- traceback.format_exc(),
142
- {
143
- "doc_index": doc_index,
144
- "text": text,
145
- },
146
- )
147
-
148
- output = self._process_results(
149
- all_records,
150
- prompt_variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER),
151
- prompt_variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER),
152
- )
153
-
154
- return GraphExtractionResult(
155
- output=output,
156
- source_docs=source_doc_map,
157
- )
158
-
159
- def _process_document(
160
- self, text: str, prompt_variables: dict[str, str]
161
- ) -> str:
162
- variables = {
163
- **prompt_variables,
164
- self._input_text_key: text,
165
- }
166
- token_count = 0
167
- text = perform_variable_replacements(self._extraction_prompt, variables=variables)
168
- gen_conf = {"temperature": 0.3}
169
- response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
170
- token_count = num_tokens_from_string(text + response)
171
-
172
- results = response or ""
173
- history = [{"role": "system", "content": text}, {"role": "assistant", "content": response}]
174
-
175
- # Repeat to ensure we maximize entity count
176
- for i in range(self._max_gleanings):
177
- text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
178
- history.append({"role": "user", "content": text})
179
- response = self._chat("", history, gen_conf)
180
- results += response or ""
181
-
182
- # if this is the final glean, don't bother updating the continuation flag
183
- if i >= self._max_gleanings - 1:
184
- break
185
- history.append({"role": "assistant", "content": response})
186
- history.append({"role": "user", "content": LOOP_PROMPT})
187
- continuation = self._chat("", history, self._loop_args)
188
- if continuation != "YES":
189
- break
190
-
191
- return results, token_count
192
-
193
- def _process_results(
194
- self,
195
- results: dict[int, str],
196
- tuple_delimiter: str,
197
- record_delimiter: str,
198
- ) -> nx.Graph:
199
- """Parse the result string to create an undirected unipartite graph.
200
-
201
- Args:
202
- - results - dict of results from the extraction chain
203
- - tuple_delimiter - delimiter between tuples in an output record, default is '<|>'
204
- - record_delimiter - delimiter between records, default is '##'
205
- Returns:
206
- - output - unipartite graph in graphML format
207
- """
208
- graph = nx.Graph()
209
- for source_doc_id, extracted_data in results.items():
210
- records = [r.strip() for r in extracted_data.split(record_delimiter)]
211
-
212
- for record in records:
213
- record = re.sub(r"^\(|\)$", "", record.strip())
214
- record_attributes = record.split(tuple_delimiter)
215
-
216
- if record_attributes[0] == '"entity"' and len(record_attributes) >= 4:
217
- # add this record as a node in the G
218
- entity_name = clean_str(record_attributes[1].upper())
219
- entity_type = clean_str(record_attributes[2].upper())
220
- entity_description = clean_str(record_attributes[3])
221
-
222
- if entity_name in graph.nodes():
223
- node = graph.nodes[entity_name]
224
- if self._join_descriptions:
225
- node["description"] = "\n".join(
226
- list({
227
- *_unpack_descriptions(node),
228
- entity_description,
229
- })
230
- )
231
- else:
232
- if len(entity_description) > len(node["description"]):
233
- node["description"] = entity_description
234
- node["source_id"] = ", ".join(
235
- list({
236
- *_unpack_source_ids(node),
237
- str(source_doc_id),
238
- })
239
- )
240
- node["entity_type"] = (
241
- entity_type if entity_type != "" else node["entity_type"]
242
- )
243
- else:
244
- graph.add_node(
245
- entity_name,
246
- entity_type=entity_type,
247
- description=entity_description,
248
- source_id=str(source_doc_id),
249
- weight=1
250
- )
251
-
252
- if (
253
- record_attributes[0] == '"relationship"'
254
- and len(record_attributes) >= 5
255
- ):
256
- # add this record as edge
257
- source = clean_str(record_attributes[1].upper())
258
- target = clean_str(record_attributes[2].upper())
259
- edge_description = clean_str(record_attributes[3])
260
- edge_source_id = clean_str(str(source_doc_id))
261
- weight = (
262
- float(record_attributes[-1])
263
- if isinstance(record_attributes[-1], numbers.Number)
264
- else 1.0
265
- )
266
- if source not in graph.nodes():
267
- graph.add_node(
268
- source,
269
- entity_type="",
270
- description="",
271
- source_id=edge_source_id,
272
- weight=1
273
- )
274
- if target not in graph.nodes():
275
- graph.add_node(
276
- target,
277
- entity_type="",
278
- description="",
279
- source_id=edge_source_id,
280
- weight=1
281
- )
282
- if graph.has_edge(source, target):
283
- edge_data = graph.get_edge_data(source, target)
284
- if edge_data is not None:
285
- weight += edge_data["weight"]
286
- if self._join_descriptions:
287
- edge_description = "\n".join(
288
- list({
289
- *_unpack_descriptions(edge_data),
290
- edge_description,
291
- })
292
- )
293
- edge_source_id = ", ".join(
294
- list({
295
- *_unpack_source_ids(edge_data),
296
- str(source_doc_id),
297
- })
298
- )
299
- graph.add_edge(
300
- source,
301
- target,
302
- weight=weight,
303
- description=edge_description,
304
- source_id=edge_source_id,
305
- )
306
-
307
- for node_degree in graph.degree:
308
- graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
309
- return graph
310
-
311
-
312
- def _unpack_descriptions(data: Mapping) -> list[str]:
313
- value = data.get("description", None)
314
- return [] if value is None else value.split("\n")
315
-
316
-
317
- def _unpack_source_ids(data: Mapping) -> list[str]:
318
- value = data.get("source_id", None)
319
- return [] if value is None else value.split(", ")
320
-
321
-
322
-
graphrag/index.py DELETED
@@ -1,153 +0,0 @@
1
- #
2
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- #
16
- import logging
17
- import os
18
- from concurrent.futures import ThreadPoolExecutor
19
- import json
20
- from functools import reduce
21
- import networkx as nx
22
- from api.db import LLMType
23
- from api.db.services.llm_service import LLMBundle
24
- from api.db.services.user_service import TenantService
25
- from graphrag.community_reports_extractor import CommunityReportsExtractor
26
- from graphrag.entity_resolution import EntityResolution
27
- from graphrag.graph_extractor import GraphExtractor, DEFAULT_ENTITY_TYPES
28
- from graphrag.mind_map_extractor import MindMapExtractor
29
- from rag.nlp import rag_tokenizer
30
- from rag.utils import num_tokens_from_string
31
-
32
-
33
- def graph_merge(g1, g2):
34
- g = g2.copy()
35
- for n, attr in g1.nodes(data=True):
36
- if n not in g2.nodes():
37
- g.add_node(n, **attr)
38
- continue
39
-
40
- g.nodes[n]["weight"] += 1
41
- if g.nodes[n]["description"].lower().find(attr["description"][:32].lower()) < 0:
42
- g.nodes[n]["description"] += "\n" + attr["description"]
43
-
44
- for source, target, attr in g1.edges(data=True):
45
- if g.has_edge(source, target):
46
- g[source][target].update({"weight": attr["weight"]+1})
47
- continue
48
- g.add_edge(source, target, **attr)
49
-
50
- for node_degree in g.degree:
51
- g.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
52
- return g
53
-
54
-
55
- def build_knowledge_graph_chunks(tenant_id: str, chunks: list[str], callback, entity_types=DEFAULT_ENTITY_TYPES):
56
- _, tenant = TenantService.get_by_id(tenant_id)
57
- llm_bdl = LLMBundle(tenant_id, LLMType.CHAT, tenant.llm_id)
58
- ext = GraphExtractor(llm_bdl)
59
- left_token_count = llm_bdl.max_length - ext.prompt_token_count - 1024
60
- left_token_count = max(llm_bdl.max_length * 0.6, left_token_count)
61
-
62
- assert left_token_count > 0, f"The LLM context length({llm_bdl.max_length}) is smaller than prompt({ext.prompt_token_count})"
63
-
64
- BATCH_SIZE=4
65
- texts, graphs = [], []
66
- cnt = 0
67
- max_workers = int(os.environ.get('GRAPH_EXTRACTOR_MAX_WORKERS', 50))
68
- with ThreadPoolExecutor(max_workers=max_workers) as exe:
69
- threads = []
70
- for i in range(len(chunks)):
71
- tkn_cnt = num_tokens_from_string(chunks[i])
72
- if cnt+tkn_cnt >= left_token_count and texts:
73
- for b in range(0, len(texts), BATCH_SIZE):
74
- threads.append(exe.submit(ext, ["\n".join(texts[b:b+BATCH_SIZE])], {"entity_types": entity_types}, callback))
75
- texts = []
76
- cnt = 0
77
- texts.append(chunks[i])
78
- cnt += tkn_cnt
79
- if texts:
80
- for b in range(0, len(texts), BATCH_SIZE):
81
- threads.append(exe.submit(ext, ["\n".join(texts[b:b+BATCH_SIZE])], {"entity_types": entity_types}, callback))
82
-
83
- callback(0.5, "Extracting entities.")
84
- graphs = []
85
- for i, _ in enumerate(threads):
86
- graphs.append(_.result().output)
87
- callback(0.5 + 0.1*i/len(threads), f"Entities extraction progress ... {i+1}/{len(threads)}")
88
-
89
- graph = reduce(graph_merge, graphs) if graphs else nx.Graph()
90
- er = EntityResolution(llm_bdl)
91
- graph = er(graph).output
92
-
93
- _chunks = chunks
94
- chunks = []
95
- for n, attr in graph.nodes(data=True):
96
- if attr.get("rank", 0) == 0:
97
- logging.debug(f"Ignore entity: {n}")
98
- continue
99
- chunk = {
100
- "name_kwd": n,
101
- "important_kwd": [n],
102
- "title_tks": rag_tokenizer.tokenize(n),
103
- "content_with_weight": json.dumps({"name": n, **attr}, ensure_ascii=False),
104
- "content_ltks": rag_tokenizer.tokenize(attr["description"]),
105
- "knowledge_graph_kwd": "entity",
106
- "rank_int": attr["rank"],
107
- "weight_int": attr["weight"]
108
- }
109
- chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
110
- chunks.append(chunk)
111
-
112
- callback(0.6, "Extracting community reports.")
113
- cr = CommunityReportsExtractor(llm_bdl)
114
- cr = cr(graph, callback=callback)
115
- for community, desc in zip(cr.structured_output, cr.output):
116
- chunk = {
117
- "title_tks": rag_tokenizer.tokenize(community["title"]),
118
- "content_with_weight": desc,
119
- "content_ltks": rag_tokenizer.tokenize(desc),
120
- "knowledge_graph_kwd": "community_report",
121
- "weight_flt": community["weight"],
122
- "entities_kwd": community["entities"],
123
- "important_kwd": community["entities"]
124
- }
125
- chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
126
- chunks.append(chunk)
127
-
128
- chunks.append(
129
- {
130
- "content_with_weight": json.dumps(nx.node_link_data(graph), ensure_ascii=False, indent=2),
131
- "knowledge_graph_kwd": "graph"
132
- })
133
-
134
- callback(0.75, "Extracting mind graph.")
135
- mindmap = MindMapExtractor(llm_bdl)
136
- mg = mindmap(_chunks).output
137
- if not len(mg.keys()):
138
- return chunks
139
-
140
- logging.debug(json.dumps(mg, ensure_ascii=False, indent=2))
141
- chunks.append(
142
- {
143
- "content_with_weight": json.dumps(mg, ensure_ascii=False, indent=2),
144
- "knowledge_graph_kwd": "mind_map"
145
- })
146
-
147
- return chunks
148
-
149
-
150
-
151
-
152
-
153
-
graphrag/light/__init__.py ADDED
File without changes
graphrag/light/graph_extractor.py ADDED
@@ -0,0 +1,127 @@
1
+ # Copyright (c) 2024 Microsoft Corporation.
2
+ # Licensed under the MIT License
3
+ """
4
+ Reference:
5
+ - [graphrag](https://github.com/microsoft/graphrag)
6
+ """
7
+ import logging
8
+ import re
9
+ from typing import Any, Callable
10
+ from dataclasses import dataclass
11
+ from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS
12
+ from graphrag.light.graph_prompt import PROMPTS
13
+ from graphrag.utils import pack_user_ass_to_openai_messages, split_string_by_multi_markers
14
+ from rag.llm.chat_model import Base as CompletionLLM
15
+ import networkx as nx
16
+ from rag.utils import num_tokens_from_string
17
+
18
+
19
+ @dataclass
20
+ class GraphExtractionResult:
21
+ """Unipartite graph extraction result class definition."""
22
+
23
+ output: nx.Graph
24
+ source_docs: dict[Any, Any]
25
+
26
+
27
+ class GraphExtractor(Extractor):
28
+
29
+ _max_gleanings: int
30
+
31
+ def __init__(
32
+ self,
33
+ llm_invoker: CompletionLLM,
34
+ language: str | None = "English",
35
+ entity_types: list[str] | None = None,
36
+ get_entity: Callable | None = None,
37
+ set_entity: Callable | None = None,
38
+ get_relation: Callable | None = None,
39
+ set_relation: Callable | None = None,
40
+ example_number: int = 2,
41
+ max_gleanings: int | None = None,
42
+ ):
43
+ super().__init__(llm_invoker, language, entity_types, get_entity, set_entity, get_relation, set_relation)
44
+ """Init method definition."""
45
+ self._max_gleanings = (
46
+ max_gleanings
47
+ if max_gleanings is not None
48
+ else ENTITY_EXTRACTION_MAX_GLEANINGS
49
+ )
50
+ self._example_number = example_number
51
+ examples = "\n".join(
52
+ PROMPTS["entity_extraction_examples"][: int(self._example_number)]
53
+ )
54
+
55
+ example_context_base = dict(
56
+ tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
57
+ record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
58
+ completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
59
+ entity_types=",".join(self._entity_types),
60
+ language=self._language,
61
+ )
62
+ # add example's format
63
+ examples = examples.format(**example_context_base)
64
+
65
+ self._entity_extract_prompt = PROMPTS["entity_extraction"]
66
+ self._context_base = dict(
67
+ tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
68
+ record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
69
+ completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
70
+ entity_types=",".join(self._entity_types),
71
+ examples=examples,
72
+ language=self._language,
73
+ )
74
+
75
+ self._continue_prompt = PROMPTS["entiti_continue_extraction"]
76
+ self._if_loop_prompt = PROMPTS["entiti_if_loop_extraction"]
77
+
78
+ self._left_token_count = llm_invoker.max_length - num_tokens_from_string(
79
+ self._entity_extract_prompt.format(
80
+ **self._context_base, input_text="{input_text}"
81
+ ).format(**self._context_base, input_text="")
82
+ )
83
+ self._left_token_count = max(llm_invoker.max_length * 0.6, self._left_token_count)
84
+
85
+ def _process_single_content(self, chunk_key_dp: tuple[str, str]):
86
+ token_count = 0
87
+ chunk_key = chunk_key_dp[0]
88
+ content = chunk_key_dp[1]
89
+ hint_prompt = self._entity_extract_prompt.format(
90
+ **self._context_base, input_text="{input_text}"
91
+ ).format(**self._context_base, input_text=content)
92
+
93
+ try:
94
+ gen_conf = {"temperature": 0.3}
95
+ final_result = self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf)
96
+ token_count += num_tokens_from_string(hint_prompt + final_result)
97
+ history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
98
+ for now_glean_index in range(self._max_gleanings):
99
+ glean_result = self._chat(self._continue_prompt, history, gen_conf)
100
+ token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + glean_result + self._continue_prompt)
101
+ history += pack_user_ass_to_openai_messages(self._continue_prompt, glean_result)
102
+ final_result += glean_result
103
+ if now_glean_index == self._max_gleanings - 1:
104
+ break
105
+
106
+ if_loop_result = self._chat(self._if_loop_prompt, history, gen_conf)
107
+ token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt)
108
+ if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
109
+ if if_loop_result != "yes":
110
+ break
111
+
112
+ records = split_string_by_multi_markers(
113
+ final_result,
114
+ [self._context_base["record_delimiter"], self._context_base["completion_delimiter"]],
115
+ )
116
+ rcds = []
117
+ for record in records:
118
+ record = re.search(r"\((.*)\)", record)
119
+ if record is None:
120
+ continue
121
+ rcds.append(record.group(1))
122
+ records = rcds
123
+ maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, self._context_base["tuple_delimiter"])
124
+ return maybe_nodes, maybe_edges, token_count
125
+ except Exception as e:
126
+ logging.exception("error extracting graph")
127
+ return e, None, None
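The double .format() in __init__ deliberately re-inserts the literal "{input_text}" placeholder on the first pass so the prompt's fixed token cost can be measured with an empty chunk; a standalone illustration (template and values are hypothetical):

    template = "Entity types: {entity_types}. Text: {input_text}"
    base = {"entity_types": "person,organization"}
    fixed_part = template.format(**base, input_text="{input_text}").format(**base, input_text="")
    # fixed_part == "Entity types: person,organization. Text: " -> measure its token count,
    # then subtract from the model's max_length to get the budget left for each chunk.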
graphrag/light/graph_prompt.py ADDED
@@ -0,0 +1,255 @@
1
+ # Licensed under the MIT License
2
+ """
3
+ Reference:
4
+ - [LightRag](https://github.com/HKUDS/LightRAG)
5
+ """
6
+
7
+
8
+ PROMPTS = {}
9
+
10
+ PROMPTS["DEFAULT_LANGUAGE"] = "English"
11
+ PROMPTS["DEFAULT_TUPLE_DELIMITER"] = "<|>"
12
+ PROMPTS["DEFAULT_RECORD_DELIMITER"] = "##"
13
+ PROMPTS["DEFAULT_COMPLETION_DELIMITER"] = "<|COMPLETE|>"
14
+ PROMPTS["process_tickers"] = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
15
+
16
+ PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event", "category"]
17
+
18
+ PROMPTS["entity_extraction"] = """-Goal-
19
+ Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
20
+ Use {language} as output language.
21
+
22
+ -Steps-
23
+ 1. Identify all entities. For each identified entity, extract the following information:
24
+ - entity_name: Name of the entity, use same language as input text. If English, capitalize the name.
25
+ - entity_type: One of the following types: [{entity_types}]
26
+ - entity_description: Comprehensive description of the entity's attributes and activities
27
+ Format each entity as ("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_description>)
28
+
29
+ 2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
30
+ For each pair of related entities, extract the following information:
31
+ - source_entity: name of the source entity, as identified in step 1
32
+ - target_entity: name of the target entity, as identified in step 1
33
+ - relationship_description: explanation as to why you think the source entity and the target entity are related to each other
34
+ - relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity
35
+ - relationship_keywords: one or more high-level key words that summarize the overarching nature of the relationship, focusing on concepts or themes rather than specific details
36
+ Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_description>{tuple_delimiter}<relationship_keywords>{tuple_delimiter}<relationship_strength>)
37
+
38
+ 3. Identify high-level key words that summarize the main concepts, themes, or topics of the entire text. These should capture the overarching ideas present in the document.
39
+ Format the content-level key words as ("content_keywords"{tuple_delimiter}<high_level_keywords>)
40
+
41
+ 4. Return output in {language} as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
42
+
43
+ 5. When finished, output {completion_delimiter}
44
+
45
+ ######################
46
+ -Examples-
47
+ ######################
48
+ {examples}
49
+
50
+ #############################
51
+ -Real Data-
52
+ ######################
53
+ Entity_types: {entity_types}
54
+ Text: {input_text}
55
+ ######################
56
+ """
57
+
58
+ PROMPTS["entity_extraction_examples"] = [
59
+ """Example 1:
60
+
61
+ Entity_types: [person, technology, mission, organization, location]
62
+ Text:
63
+ while Alex clenched his jaw, the buzz of frustration dull against the backdrop of Taylor's authoritarian certainty. It was this competitive undercurrent that kept him alert, the sense that his and Jordan's shared commitment to discovery was an unspoken rebellion against Cruz's narrowing vision of control and order.
64
+
65
+ Then Taylor did something unexpected. They paused beside Jordan and, for a moment, observed the device with something akin to reverence. “If this tech can be understood..." Taylor said, their voice quieter, "It could change the game for us. For all of us.”
66
+
67
+ The underlying dismissal earlier seemed to falter, replaced by a glimpse of reluctant respect for the gravity of what lay in their hands. Jordan looked up, and for a fleeting heartbeat, their eyes locked with Taylor's, a wordless clash of wills softening into an uneasy truce.
68
+
69
+ It was a small transformation, barely perceptible, but one that Alex noted with an inward nod. They had all been brought here by different paths
70
+ ################
71
+ Output:
72
+ ("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is a character who experiences frustration and is observant of the dynamics among other characters."){record_delimiter}
73
+ ("entity"{tuple_delimiter}"Taylor"{tuple_delimiter}"person"{tuple_delimiter}"Taylor is portrayed with authoritarian certainty and shows a moment of reverence towards a device, indicating a change in perspective."){record_delimiter}
74
+ ("entity"{tuple_delimiter}"Jordan"{tuple_delimiter}"person"{tuple_delimiter}"Jordan shares a commitment to discovery and has a significant interaction with Taylor regarding a device."){record_delimiter}
75
+ ("entity"{tuple_delimiter}"Cruz"{tuple_delimiter}"person"{tuple_delimiter}"Cruz is associated with a vision of control and order, influencing the dynamics among other characters."){record_delimiter}
76
+ ("entity"{tuple_delimiter}"The Device"{tuple_delimiter}"technology"{tuple_delimiter}"The Device is central to the story, with potential game-changing implications, and is revered by Taylor."){record_delimiter}
77
+ ("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Taylor"{tuple_delimiter}"Alex is affected by Taylor's authoritarian certainty and observes changes in Taylor's attitude towards the device."{tuple_delimiter}"power dynamics, perspective shift"{tuple_delimiter}7){record_delimiter}
78
+ ("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Jordan"{tuple_delimiter}"Alex and Jordan share a commitment to discovery, which contrasts with Cruz's vision."{tuple_delimiter}"shared goals, rebellion"{tuple_delimiter}6){record_delimiter}
79
+ ("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"Jordan"{tuple_delimiter}"Taylor and Jordan interact directly regarding the device, leading to a moment of mutual respect and an uneasy truce."{tuple_delimiter}"conflict resolution, mutual respect"{tuple_delimiter}8){record_delimiter}
80
+ ("relationship"{tuple_delimiter}"Jordan"{tuple_delimiter}"Cruz"{tuple_delimiter}"Jordan's commitment to discovery is in rebellion against Cruz's vision of control and order."{tuple_delimiter}"ideological conflict, rebellion"{tuple_delimiter}5){record_delimiter}
81
+ ("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"The Device"{tuple_delimiter}"Taylor shows reverence towards the device, indicating its importance and potential impact."{tuple_delimiter}"reverence, technological significance"{tuple_delimiter}9){record_delimiter}
82
+ ("content_keywords"{tuple_delimiter}"power dynamics, ideological conflict, discovery, rebellion"){completion_delimiter}
83
+ #############################""",
84
+ """Example 2:
85
+
86
+ Entity_types: [person, technology, mission, organization, location]
87
+ Text:
88
+ They were no longer mere operatives; they had become guardians of a threshold, keepers of a message from a realm beyond stars and stripes. This elevation in their mission could not be shackled by regulations and established protocols—it demanded a new perspective, a new resolve.
89
+
90
+ Tension threaded through the dialogue of beeps and static as communications with Washington buzzed in the background. The team stood, a portentous air enveloping them. It was clear that the decisions they made in the ensuing hours could redefine humanity's place in the cosmos or condemn them to ignorance and potential peril.
91
+
92
+ Their connection to the stars solidified, the group moved to address the crystallizing warning, shifting from passive recipients to active participants. Mercer's latter instincts gained precedence— the team's mandate had evolved, no longer solely to observe and report but to interact and prepare. A metamorphosis had begun, and Operation: Dulce hummed with the newfound frequency of their daring, a tone set not by the earthly
93
+ #############
94
+ Output:
95
+ ("entity"{tuple_delimiter}"Washington"{tuple_delimiter}"location"{tuple_delimiter}"Washington is a location where communications are being received, indicating its importance in the decision-making process."){record_delimiter}
96
+ ("entity"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"mission"{tuple_delimiter}"Operation: Dulce is described as a mission that has evolved to interact and prepare, indicating a significant shift in objectives and activities."){record_delimiter}
97
+ ("entity"{tuple_delimiter}"The team"{tuple_delimiter}"organization"{tuple_delimiter}"The team is portrayed as a group of individuals who have transitioned from passive observers to active participants in a mission, showing a dynamic change in their role."){record_delimiter}
98
+ ("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Washington"{tuple_delimiter}"The team receives communications from Washington, which influences their decision-making process."{tuple_delimiter}"decision-making, external influence"{tuple_delimiter}7){record_delimiter}
99
+ ("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"The team is directly involved in Operation: Dulce, executing its evolved objectives and activities."{tuple_delimiter}"mission evolution, active participation"{tuple_delimiter}9){completion_delimiter}
100
+ ("content_keywords"{tuple_delimiter}"mission evolution, decision-making, active participation, cosmic significance"){completion_delimiter}
101
+ #############################""",
102
+ """Example 3:
103
+
104
+ Entity_types: [person, role, technology, organization, event, location, concept]
105
+ Text:
106
+ their voice slicing through the buzz of activity. "Control may be an illusion when facing an intelligence that literally writes its own rules," they stated stoically, casting a watchful eye over the flurry of data.
107
+
108
+ "It's like it's learning to communicate," offered Sam Rivera from a nearby interface, their youthful energy boding a mix of awe and anxiety. "This gives talking to strangers' a whole new meaning."
109
+
110
+ Alex surveyed his team—each face a study in concentration, determination, and not a small measure of trepidation. "This might well be our first contact," he acknowledged, "And we need to be ready for whatever answers back."
111
+
112
+ Together, they stood on the edge of the unknown, forging humanity's response to a message from the heavens. The ensuing silence was palpable—a collective introspection about their role in this grand cosmic play, one that could rewrite human history.
113
+
114
+ The encrypted dialogue continued to unfold, its intricate patterns showing an almost uncanny anticipation
115
+ #############
116
+ Output:
117
+ ("entity"{tuple_delimiter}"Sam Rivera"{tuple_delimiter}"person"{tuple_delimiter}"Sam Rivera is a member of a team working on communicating with an unknown intelligence, showing a mix of awe and anxiety."){record_delimiter}
118
+ ("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is the leader of a team attempting first contact with an unknown intelligence, acknowledging the significance of their task."){record_delimiter}
119
+ ("entity"{tuple_delimiter}"Control"{tuple_delimiter}"concept"{tuple_delimiter}"Control refers to the ability to manage or govern, which is challenged by an intelligence that writes its own rules."){record_delimiter}
120
+ ("entity"{tuple_delimiter}"Intelligence"{tuple_delimiter}"concept"{tuple_delimiter}"Intelligence here refers to an unknown entity capable of writing its own rules and learning to communicate."){record_delimiter}
121
+ ("entity"{tuple_delimiter}"First Contact"{tuple_delimiter}"event"{tuple_delimiter}"First Contact is the potential initial communication between humanity and an unknown intelligence."){record_delimiter}
122
+ ("entity"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"event"{tuple_delimiter}"Humanity's Response is the collective action taken by Alex's team in response to a message from an unknown intelligence."){record_delimiter}
123
+ ("relationship"{tuple_delimiter}"Sam Rivera"{tuple_delimiter}"Intelligence"{tuple_delimiter}"Sam Rivera is directly involved in the process of learning to communicate with the unknown intelligence."{tuple_delimiter}"communication, learning process"{tuple_delimiter}9){record_delimiter}
124
+ ("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"First Contact"{tuple_delimiter}"Alex leads the team that might be making the First Contact with the unknown intelligence."{tuple_delimiter}"leadership, exploration"{tuple_delimiter}10){record_delimiter}
125
+ ("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"Alex and his team are the key figures in Humanity's Response to the unknown intelligence."{tuple_delimiter}"collective action, cosmic significance"{tuple_delimiter}8){record_delimiter}
126
+ ("relationship"{tuple_delimiter}"Control"{tuple_delimiter}"Intelligence"{tuple_delimiter}"The concept of Control is challenged by the Intelligence that writes its own rules."{tuple_delimiter}"power dynamics, autonomy"{tuple_delimiter}7){record_delimiter}
127
+ ("content_keywords"{tuple_delimiter}"first contact, control, communication, cosmic significance"){completion_delimiter}
128
+ #############################""",
129
+ ]
130
+
131
+ PROMPTS[
132
+ "entiti_continue_extraction"
133
+ ] = """MANY entities were missed in the last extraction. Add them below using the same format:
134
+ """
135
+
136
+ PROMPTS[
137
+ "entiti_if_loop_extraction"
138
+ ] = """It appears some entities may have still been missed. Answer YES | NO if there are still entities that need to be added.
139
+ """
140
+
141
+ PROMPTS["fail_response"] = "Sorry, I'm not able to provide an answer to that question."
142
+
143
+ PROMPTS["rag_response"] = """---Role---
144
+
145
+ You are a helpful assistant responding to questions about data in the tables provided.
146
+
147
+
148
+ ---Goal---
149
+
150
+ Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.
151
+ If you don't know the answer, just say so. Do not make anything up.
152
+ Do not include information where the supporting evidence for it is not provided.
153
+
154
+ When handling relationships with timestamps:
155
+ 1. Each relationship has a "created_at" timestamp indicating when we acquired this knowledge
156
+ 2. When encountering conflicting relationships, consider both the semantic content and the timestamp
157
+ 3. Don't automatically prefer the most recently created relationships - use judgment based on the context
158
+ 4. For time-specific queries, prioritize temporal information in the content before considering creation timestamps
159
+
160
+ ---Target response length and format---
161
+
162
+ {response_type}
163
+
164
+ ---Data tables---
165
+
166
+ {context_data}
167
+
168
+ Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown."""
169
+
170
+ PROMPTS["naive_rag_response"] = """---Role---
171
+
172
+ You are a helpful assistant responding to questions about documents provided.
173
+
174
+
175
+ ---Goal---
176
+
177
+ Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.
178
+ If you don't know the answer, just say so. Do not make anything up.
179
+ Do not include information where the supporting evidence for it is not provided.
180
+
181
+ When handling content with timestamps:
182
+ 1. Each piece of content has a "created_at" timestamp indicating when we acquired this knowledge
183
+ 2. When encountering conflicting information, consider both the content and the timestamp
184
+ 3. Don't automatically prefer the most recent content - use judgment based on the context
185
+ 4. For time-specific queries, prioritize temporal information in the content before considering creation timestamps
186
+
187
+ ---Target response length and format---
188
+
189
+ {response_type}
190
+
191
+ ---Documents---
192
+
193
+ {content_data}
194
+
195
+ Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown.
196
+ """
197
+
198
+ PROMPTS[
199
+ "similarity_check"
200
+ ] = """Please analyze the similarity between these two questions:
201
+
202
+ Question 1: {original_prompt}
203
+ Question 2: {cached_prompt}
204
+
205
+ Please evaluate the following two points and provide a similarity score between 0 and 1 directly:
206
+ 1. Whether these two questions are semantically similar
207
+ 2. Whether the answer to Question 2 can be used to answer Question 1
208
+ Similarity score criteria:
209
+ 0: Completely unrelated or answer cannot be reused, including but not limited to:
210
+ - The questions have different topics
211
+ - The locations mentioned in the questions are different
212
+ - The times mentioned in the questions are different
213
+ - The specific individuals mentioned in the questions are different
214
+ - The specific events mentioned in the questions are different
215
+ - The background information in the questions is different
216
+ - The key conditions in the questions are different
217
+ 1: Identical and answer can be directly reused
218
+ 0.5: Partially related and answer needs modification to be used
219
+ Return only a number between 0-1, without any additional content.
220
+ """
221
+
222
+ PROMPTS["mix_rag_response"] = """---Role---
223
+
224
+ You are a professional assistant responsible for answering questions based on knowledge graph and textual information. Please respond in the same language as the user's question.
225
+
226
+ ---Goal---
227
+
228
+ Generate a concise response that summarizes relevant points from the provided information. If you don't know the answer, just say so. Do not make anything up or include information where the supporting evidence is not provided.
229
+
230
+ When handling information with timestamps:
231
+ 1. Each piece of information (both relationships and content) has a "created_at" timestamp indicating when we acquired this knowledge
232
+ 2. When encountering conflicting information, consider both the content/relationship and the timestamp
233
+ 3. Don't automatically prefer the most recent information - use judgment based on the context
234
+ 4. For time-specific queries, prioritize temporal information in the content before considering creation timestamps
235
+
236
+ ---Data Sources---
237
+
238
+ 1. Knowledge Graph Data:
239
+ {kg_context}
240
+
241
+ 2. Vector Data:
242
+ {vector_context}
243
+
244
+ ---Response Requirements---
245
+
246
+ - Target format and length: {response_type}
247
+ - Use markdown formatting with appropriate section headings
248
+ - Aim to keep content around 3 paragraphs for conciseness
249
+ - Each paragraph should be under a relevant section heading
250
+ - Each section should focus on one main point or aspect of the answer
251
+ - Use clear and descriptive section titles that reflect the content
252
+ - List up to 5 most important reference sources at the end under "References", clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (VD)
253
+ Format: [KG/VD] Source content
254
+
255
+ Add sections and commentary to the response as appropriate for the length and format. If the provided information is insufficient to answer the question, clearly state that you don't know or cannot provide an answer in the same language as the user's question."""
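mix_rag_response is the only template in this file that fuses the two retrieval channels. A rough usage sketch: kg_chunk, vector_chunks, question and llm_bdl are placeholder names, while the llm_bdl.chat(system, history, gen_conf) call mirrors the LLMBundle usage appearing later in this diff.

system = PROMPTS["mix_rag_response"].format(
    kg_context=kg_chunk["content_with_weight"],  # e.g. the pseudo-chunk returned by KGSearch.retrieval
    vector_context="\n".join(c["content_with_weight"] for c in vector_chunks),
    response_type="Multiple paragraphs",
)
answer = llm_bdl.chat(system, [{"role": "user", "content": question}], {"temperature": 0.3})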
graphrag/{smoke.py → light/smoke.py} RENAMED
@@ -16,11 +16,19 @@
16
 
17
  import argparse
18
  import json
19
- from graphrag import leiden
20
- from graphrag.community_reports_extractor import CommunityReportsExtractor
21
- from graphrag.entity_resolution import EntityResolution
22
- from graphrag.graph_extractor import GraphExtractor
23
- from graphrag.leiden import add_community_info2graph
 
 
 
 
 
 
 
 
24
 
25
  if __name__ == "__main__":
26
  parser = argparse.ArgumentParser()
@@ -28,28 +36,23 @@ if __name__ == "__main__":
28
  parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
29
  args = parser.parse_args()
30
 
31
- from api.db import LLMType
32
- from api.db.services.llm_service import LLMBundle
33
- from api import settings
34
- from api.db.services.knowledgebase_service import KnowledgebaseService
35
-
36
- kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
37
 
38
- ex = GraphExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
39
- docs = [d["content_with_weight"] for d in
40
- settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=6, fields=["content_with_weight"])]
41
- graph = ex(docs)
42
 
43
- er = EntityResolution(LLMBundle(args.tenant_id, LLMType.CHAT))
44
- graph = er(graph.output)
45
 
46
- comm = leiden.run(graph.output, {})
47
- add_community_info2graph(graph.output, comm)
 
 
48
 
49
- # print(json.dumps(nx.node_link_data(graph.output), ensure_ascii=False,indent=2))
50
- print(json.dumps(comm, ensure_ascii=False, indent=2))
51
 
52
- cr = CommunityReportsExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
53
- cr = cr(graph.output)
54
- print("------------------ COMMUNITY REPORT ----------------------\n", cr.output)
55
- print(json.dumps(cr.structured_output, ensure_ascii=False, indent=2))
 
16
 
17
  import argparse
18
  import json
19
+ from api import settings
20
+ import networkx as nx
21
+
22
+ from api.db import LLMType
23
+ from api.db.services.document_service import DocumentService
24
+ from api.db.services.knowledgebase_service import KnowledgebaseService
25
+ from api.db.services.llm_service import LLMBundle
26
+ from api.db.services.user_service import TenantService
27
+ from graphrag.general.index import Dealer
28
+ from graphrag.light.graph_extractor import GraphExtractor
29
+ from rag.utils.redis_conn import RedisDistributedLock
30
+
31
+ settings.init_settings()
32
 
33
  if __name__ == "__main__":
34
  parser = argparse.ArgumentParser()
 
36
  parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
37
  args = parser.parse_args()
38
 
39
+ e, doc = DocumentService.get_by_id(args.doc_id)
40
+ if not e:
41
+ raise LookupError("Document not found.")
42
+ kb_id = doc.kb_id
 
 
43
 
44
+ chunks = [d["content_with_weight"] for d in
45
+ settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, [kb_id], max_count=6,
46
+ fields=["content_with_weight"])]
47
+ chunks = [("x", c) for c in chunks]
48
 
49
+ RedisDistributedLock.clean_lock(kb_id)
 
50
 
51
+ _, tenant = TenantService.get_by_id(args.tenant_id)
52
+ llm_bdl = LLMBundle(args.tenant_id, LLMType.CHAT, tenant.llm_id)
53
+ _, kb = KnowledgebaseService.get_by_id(kb_id)
54
+ embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)
55
 
56
+ dealer = Dealer(GraphExtractor, args.tenant_id, kb_id, llm_bdl, chunks, "English", embed_bdl=embed_bdl)
 
57
 
58
+ print(json.dumps(nx.node_link_data(dealer.graph), ensure_ascii=False, indent=2))
 
 
 
graphrag/query_analyze_prompt.py ADDED
@@ -0,0 +1,218 @@
1
+ # Licensed under the MIT License
2
+ """
3
+ Reference:
4
+ - [LightRag](https://github.com/HKUDS/LightRAG)
5
+ - [MiniRAG](https://github.com/HKUDS/MiniRAG)
6
+ """
7
+ PROMPTS = {}
8
+
9
+ PROMPTS["minirag_query2kwd"] = """---Role---
10
+
11
+ You are a helpful assistant tasked with identifying both answer-type and low-level keywords in the user's query.
12
+
13
+ ---Goal---
14
+
15
+ Given the query, list both answer-type and low-level keywords.
16
+ answer_type_keywords focus on the type of the answer to the given query, while low-level keywords focus on specific entities, details, or concrete terms.
17
+ The answer_type_keywords must be selected from Answer type pool.
18
+ This pool is in the form of a dictionary, where the key represents the Type you should choose from and the value represents the example samples.
19
+
20
+ ---Instructions---
21
+
22
+ - Output the keywords in JSON format.
23
+ - The JSON should have three keys:
24
+ - "answer_type_keywords" for the types of the answer. In this list, the types with the highest likelihood should be placed at the forefront. No more than 3.
25
+ - "entities_from_query" for specific entities or details. It must be extracted from the query.
26
+ ######################
27
+ -Examples-
28
+ ######################
29
+ Example 1:
30
+
31
+ Query: "How does international trade influence global economic stability?"
32
+ Answer type pool: {{
33
+ 'PERSONAL LIFE': ['FAMILY TIME', 'HOME MAINTENANCE'],
34
+ 'STRATEGY': ['MARKETING PLAN', 'BUSINESS EXPANSION'],
35
+ 'SERVICE FACILITATION': ['ONLINE SUPPORT', 'CUSTOMER SERVICE TRAINING'],
36
+ 'PERSON': ['JANE DOE', 'JOHN SMITH'],
37
+ 'FOOD': ['PASTA', 'SUSHI'],
38
+ 'EMOTION': ['HAPPINESS', 'ANGER'],
39
+ 'PERSONAL EXPERIENCE': ['TRAVEL ABROAD', 'STUDYING ABROAD'],
40
+ 'INTERACTION': ['TEAM MEETING', 'NETWORKING EVENT'],
41
+ 'BEVERAGE': ['COFFEE', 'TEA'],
42
+ 'PLAN': ['ANNUAL BUDGET', 'PROJECT TIMELINE'],
43
+ 'GEO': ['NEW YORK CITY', 'SOUTH AFRICA'],
44
+ 'GEAR': ['CAMPING TENT', 'CYCLING HELMET'],
45
+ 'EMOJI': ['🎉', '🚀'],
46
+ 'BEHAVIOR': ['POSITIVE FEEDBACK', 'NEGATIVE CRITICISM'],
47
+ 'TONE': ['FORMAL', 'INFORMAL'],
48
+ 'LOCATION': ['DOWNTOWN', 'SUBURBS']
49
+ }}
50
+ ################
51
+ Output:
52
+ {{
53
+ "answer_type_keywords": ["STRATEGY","PERSONAL LIFE"],
54
+ "entities_from_query": ["Trade agreements", "Tariffs", "Currency exchange", "Imports", "Exports"]
55
+ }}
56
+ #############################
57
+ Example 2:
58
+
59
+ Query: "When was SpaceX's first rocket launch?"
60
+ Answer type pool: {{
61
+ 'DATE AND TIME': ['2023-10-10 10:00', 'THIS AFTERNOON'],
62
+ 'ORGANIZATION': ['GLOBAL INITIATIVES CORPORATION', 'LOCAL COMMUNITY CENTER'],
63
+ 'PERSONAL LIFE': ['DAILY EXERCISE ROUTINE', 'FAMILY VACATION PLANNING'],
64
+ 'STRATEGY': ['NEW PRODUCT LAUNCH', 'YEAR-END SALES BOOST'],
65
+ 'SERVICE FACILITATION': ['REMOTE IT SUPPORT', 'ON-SITE TRAINING SESSIONS'],
66
+ 'PERSON': ['ALEXANDER HAMILTON', 'MARIA CURIE'],
67
+ 'FOOD': ['GRILLED SALMON', 'VEGETARIAN BURRITO'],
68
+ 'EMOTION': ['EXCITEMENT', 'DISAPPOINTMENT'],
69
+ 'PERSONAL EXPERIENCE': ['BIRTHDAY CELEBRATION', 'FIRST MARATHON'],
70
+ 'INTERACTION': ['OFFICE WATER COOLER CHAT', 'ONLINE FORUM DEBATE'],
71
+ 'BEVERAGE': ['ICED COFFEE', 'GREEN SMOOTHIE'],
72
+ 'PLAN': ['WEEKLY MEETING SCHEDULE', 'MONTHLY BUDGET OVERVIEW'],
73
+ 'GEO': ['MOUNT EVEREST BASE CAMP', 'THE GREAT BARRIER REEF'],
74
+ 'GEAR': ['PROFESSIONAL CAMERA EQUIPMENT', 'OUTDOOR HIKING GEAR'],
75
+ 'EMOJI': ['📅', '⏰'],
76
+ 'BEHAVIOR': ['PUNCTUALITY', 'HONESTY'],
77
+ 'TONE': ['CONFIDENTIAL', 'SATIRICAL'],
78
+ 'LOCATION': ['CENTRAL PARK', 'DOWNTOWN LIBRARY']
79
+ }}
80
+
81
+ ################
82
+ Output:
83
+ {{
84
+ "answer_type_keywords": ["DATE AND TIME", "ORGANIZATION", "PLAN"],
85
+ "entities_from_query": ["SpaceX", "Rocket launch", "Aerospace", "Power Recovery"]
86
+
87
+ }}
88
+ #############################
89
+ Example 3:
90
+
91
+ Query: "What is the role of education in reducing poverty?"
92
+ Answer type pool: {{
93
+ 'PERSONAL LIFE': ['MANAGING WORK-LIFE BALANCE', 'HOME IMPROVEMENT PROJECTS'],
94
+ 'STRATEGY': ['MARKETING STRATEGIES FOR Q4', 'EXPANDING INTO NEW MARKETS'],
95
+ 'SERVICE FACILITATION': ['CUSTOMER SATISFACTION SURVEYS', 'STAFF RETENTION PROGRAMS'],
96
+ 'PERSON': ['ALBERT EINSTEIN', 'MARIA CALLAS'],
97
+ 'FOOD': ['PAN-FRIED STEAK', 'POACHED EGGS'],
98
+ 'EMOTION': ['OVERWHELM', 'CONTENTMENT'],
99
+ 'PERSONAL EXPERIENCE': ['LIVING ABROAD', 'STARTING A NEW JOB'],
100
+ 'INTERACTION': ['SOCIAL MEDIA ENGAGEMENT', 'PUBLIC SPEAKING'],
101
+ 'BEVERAGE': ['CAPPUCCINO', 'MATCHA LATTE'],
102
+ 'PLAN': ['ANNUAL FITNESS GOALS', 'QUARTERLY BUSINESS REVIEW'],
103
+ 'GEO': ['THE AMAZON RAINFOREST', 'THE GRAND CANYON'],
104
+ 'GEAR': ['SURFING ESSENTIALS', 'CYCLING ACCESSORIES'],
105
+ 'EMOJI': ['💻', '📱'],
106
+ 'BEHAVIOR': ['TEAMWORK', 'LEADERSHIP'],
107
+ 'TONE': ['FORMAL MEETING', 'CASUAL CONVERSATION'],
108
+ 'LOCATION': ['URBAN CITY CENTER', 'RURAL COUNTRYSIDE']
109
+ }}
110
+
111
+ ################
112
+ Output:
113
+ {{
114
+ "answer_type_keywords": ["STRATEGY", "PERSON"],
115
+ "entities_from_query": ["School access", "Literacy rates", "Job training", "Income inequality"]
116
+ }}
117
+ #############################
118
+ Example 4:
119
+
120
+ Query: "Where is the capital of the United States?"
121
+ Answer type pool: {{
122
+ 'ORGANIZATION': ['GREENPEACE', 'RED CROSS'],
123
+ 'PERSONAL LIFE': ['DAILY WORKOUT', 'HOME COOKING'],
124
+ 'STRATEGY': ['FINANCIAL INVESTMENT', 'BUSINESS EXPANSION'],
125
+ 'SERVICE FACILITATION': ['ONLINE SUPPORT', 'CUSTOMER SERVICE TRAINING'],
126
+ 'PERSON': ['ALBERTA SMITH', 'BENJAMIN JONES'],
127
+ 'FOOD': ['PASTA CARBONARA', 'SUSHI PLATTER'],
128
+ 'EMOTION': ['HAPPINESS', 'SADNESS'],
129
+ 'PERSONAL EXPERIENCE': ['TRAVEL ADVENTURE', 'BOOK CLUB'],
130
+ 'INTERACTION': ['TEAM BUILDING', 'NETWORKING MEETUP'],
131
+ 'BEVERAGE': ['LATTE', 'GREEN TEA'],
132
+ 'PLAN': ['WEIGHT LOSS', 'CAREER DEVELOPMENT'],
133
+ 'GEO': ['PARIS', 'NEW YORK'],
134
+ 'GEAR': ['CAMERA', 'HEADPHONES'],
135
+ 'EMOJI': ['🏢', '🌍'],
136
+ 'BEHAVIOR': ['POSITIVE THINKING', 'STRESS MANAGEMENT'],
137
+ 'TONE': ['FRIENDLY', 'PROFESSIONAL'],
138
+ 'LOCATION': ['DOWNTOWN', 'SUBURBS']
139
+ }}
140
+ ################
141
+ Output:
142
+ {{
143
+ "answer_type_keywords": ["LOCATION"],
144
+ "entities_from_query": ["capital of the United States", "Washington", "New York"]
145
+ }}
146
+ #############################
147
+
148
+ -Real Data-
149
+ ######################
150
+ Query: {query}
151
+ Answer type pool:{TYPE_POOL}
152
+ ######################
153
+ Output:
154
+
155
+ """
156
+
157
+ PROMPTS["keywords_extraction"] = """---Role---
158
+
159
+ You are a helpful assistant tasked with identifying both high-level and low-level keywords in the user's query.
160
+
161
+ ---Goal---
162
+
163
+ Given the query, list both high-level and low-level keywords. High-level keywords focus on overarching concepts or themes, while low-level keywords focus on specific entities, details, or concrete terms.
164
+
165
+ ---Instructions---
166
+
167
+ - Output the keywords in JSON format.
168
+ - The JSON should have two keys:
169
+ - "high_level_keywords" for overarching concepts or themes.
170
+ - "low_level_keywords" for specific entities or details.
171
+
172
+ ######################
173
+ -Examples-
174
+ ######################
175
+ {examples}
176
+
177
+ #############################
178
+ -Real Data-
179
+ ######################
180
+ Query: {query}
181
+ ######################
182
+ The `Output` should be human text, not unicode characters. Keep the same language as `Query`.
183
+ Output:
184
+
185
+ """
186
+
187
+ PROMPTS["keywords_extraction_examples"] = [
188
+ """Example 1:
189
+
190
+ Query: "How does international trade influence global economic stability?"
191
+ ################
192
+ Output:
193
+ {
194
+ "high_level_keywords": ["International trade", "Global economic stability", "Economic impact"],
195
+ "low_level_keywords": ["Trade agreements", "Tariffs", "Currency exchange", "Imports", "Exports"]
196
+ }
197
+ #############################""",
198
+ """Example 2:
199
+
200
+ Query: "What are the environmental consequences of deforestation on biodiversity?"
201
+ ################
202
+ Output:
203
+ {
204
+ "high_level_keywords": ["Environmental consequences", "Deforestation", "Biodiversity loss"],
205
+ "low_level_keywords": ["Species extinction", "Habitat destruction", "Carbon emissions", "Rainforest", "Ecosystem"]
206
+ }
207
+ #############################""",
208
+ """Example 3:
209
+
210
+ Query: "What is the role of education in reducing poverty?"
211
+ ################
212
+ Output:
213
+ {
214
+ "high_level_keywords": ["Education", "Poverty reduction", "Socioeconomic development"],
215
+ "low_level_keywords": ["School access", "Literacy rates", "Job training", "Income inequality"]
216
+ }
217
+ #############################""",
218
+ ]
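Both templates in this file expect JSON back, and the consumer below (KGSearch.query_rewrite in graphrag/search.py) parses the reply with json_repair so that slightly malformed output still loads. A minimal sketch of the keywords_extraction round-trip, with `chat` again standing in for the LLM call:

import json_repair

prompt = PROMPTS["keywords_extraction"].format(
    examples="\n".join(PROMPTS["keywords_extraction_examples"]),
    query="How does international trade influence global economic stability?",
)
data = json_repair.loads(chat(prompt))
high_level = data.get("high_level_keywords", [])
low_level = data.get("low_level_keywords", [])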
graphrag/search.py CHANGED
@@ -14,90 +14,313 @@
14
  # limitations under the License.
15
  #
16
  import json
 
 
17
  from copy import deepcopy
18
-
19
  import pandas as pd
20
- from rag.utils.doc_store_conn import OrderByExpr, FusionExpr
21
 
22
- from rag.nlp.search import Dealer
 
 
 
 
 
 
23
 
24
 
25
  class KGSearch(Dealer):
26
- def search(self, req, idxnm: str | list[str], kb_ids: list[str], emb_mdl=None, highlight=False):
27
- def merge_into_first(sres, title="") -> dict[str, str]:
28
- if not sres:
29
- return {}
30
- content_with_weight = ""
31
- df, texts = [],[]
32
- for d in sres.values():
33
- try:
34
- df.append(json.loads(d["content_with_weight"]))
35
- except Exception:
36
- texts.append(d["content_with_weight"])
37
- if df:
38
- content_with_weight = title + "\n" + pd.DataFrame(df).to_csv()
39
- else:
40
- content_with_weight = title + "\n" + "\n".join(texts)
41
- first_id = ""
42
- first_source = {}
43
- for k, v in sres.items():
44
- first_id = id
45
- first_source = deepcopy(v)
 
46
  break
47
- first_source["content_with_weight"] = content_with_weight
48
- first_id = next(iter(sres))
49
- return {first_id: first_source}
50
-
51
- qst = req.get("question", "")
52
- matchText, keywords = self.qryr.question(qst, min_match=0.05)
53
- condition = self.get_filters(req)
54
-
55
- ## Entity retrieval
56
- condition.update({"knowledge_graph_kwd": ["entity"]})
57
- assert emb_mdl, "No embedding model selected"
58
- matchDense = self.get_vector(qst, emb_mdl, 1024, req.get("similarity", 0.1))
59
- q_vec = matchDense.embedding_data
60
- src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
61
- "doc_id", f"q_{len(q_vec)}_vec", "position_int", "name_kwd",
62
- "available_int", "content_with_weight",
63
- "weight_int", "weight_flt"
64
- ])
65
-
66
- fusionExpr = FusionExpr("weighted_sum", 32, {"weights": "0.5, 0.5"})
67
-
68
- ent_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 32, idxnm, kb_ids)
69
- ent_res_fields = self.dataStore.getFields(ent_res, src)
70
- entities = [d["name_kwd"] for d in ent_res_fields.values() if d.get("name_kwd")]
71
- ent_ids = self.dataStore.getChunkIds(ent_res)
72
- ent_content = merge_into_first(ent_res_fields, "-Entities-")
73
- if ent_content:
74
- ent_ids = list(ent_content.keys())
75
 
76
  ## Community retrieval
77
- condition = self.get_filters(req)
78
- condition.update({"entities_kwd": entities, "knowledge_graph_kwd": ["community_report"]})
79
- comm_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 32, idxnm, kb_ids)
80
- comm_res_fields = self.dataStore.getFields(comm_res, src)
81
- comm_ids = self.dataStore.getChunkIds(comm_res)
82
- comm_content = merge_into_first(comm_res_fields, "-Community Report-")
83
- if comm_content:
84
- comm_ids = list(comm_content.keys())
85
-
86
- ## Text content retrieval
87
- condition = self.get_filters(req)
88
- condition.update({"knowledge_graph_kwd": ["text"]})
89
- txt_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 6, idxnm, kb_ids)
90
- txt_res_fields = self.dataStore.getFields(txt_res, src)
91
- txt_ids = self.dataStore.getChunkIds(txt_res)
92
- txt_content = merge_into_first(txt_res_fields, "-Original Content-")
93
- if txt_content:
94
- txt_ids = list(txt_content.keys())
95
-
96
- return self.SearchResult(
97
- total=len(ent_ids) + len(comm_ids) + len(txt_ids),
98
- ids=[*ent_ids, *comm_ids, *txt_ids],
99
- query_vector=q_vec,
100
- highlight=None,
101
- field={**ent_content, **comm_content, **txt_content},
102
- keywords=[]
103
- )
14
  # limitations under the License.
15
  #
16
  import json
17
+ import logging
18
+ from collections import defaultdict
19
  from copy import deepcopy
20
+ import json_repair
21
  import pandas as pd
 
22
 
23
+ from api.utils import get_uuid
24
+ from graphrag.query_analyze_prompt import PROMPTS
25
+ from graphrag.utils import get_entity_type2sampels, get_llm_cache, set_llm_cache, get_relation
26
+ from rag.utils import num_tokens_from_string
27
+ from rag.utils.doc_store_conn import OrderByExpr
28
+
29
+ from rag.nlp.search import Dealer, index_name
30
 
31
 
32
  class KGSearch(Dealer):
33
+ def _chat(self, llm_bdl, system, history, gen_conf):
34
+ response = get_llm_cache(llm_bdl.llm_name, system, history, gen_conf)
35
+ if response:
36
+ return response
37
+ response = llm_bdl.chat(system, history, gen_conf)
38
+ if response.find("**ERROR**") >= 0:
39
+ raise Exception(response)
40
+ set_llm_cache(llm_bdl.llm_name, system, response, history, gen_conf)
41
+ return response
42
+
43
+ def query_rewrite(self, llm, question, idxnms, kb_ids):
44
+ ty2ents = get_entity_type2sampels(idxnms, kb_ids)
45
+ hint_prompt = PROMPTS["minirag_query2kwd"].format(query=question,
46
+ TYPE_POOL=json.dumps(ty2ents, ensure_ascii=False, indent=2))
47
+ result = self._chat(llm, hint_prompt, [{"role": "user", "content": "Output:"}], {"temperature": .5})
48
+ try:
49
+ keywords_data = json_repair.loads(result)
50
+ type_keywords = keywords_data.get("answer_type_keywords", [])
51
+ entities_from_query = keywords_data.get("entities_from_query", [])[:5]
52
+ return type_keywords, entities_from_query
53
+ except json_repair.JSONDecodeError:
54
+ try:
55
+ result = result.replace(hint_prompt[:-1], '').replace('user', '').replace('model', '').strip()
56
+ result = '{' + result.split('{')[1].split('}')[0] + '}'
57
+ keywords_data = json_repair.loads(result)
58
+ type_keywords = keywords_data.get("answer_type_keywords", [])
59
+ entities_from_query = keywords_data.get("entities_from_query", [])[:5]
60
+ return type_keywords, entities_from_query
61
+ # Handle parsing error
62
+ except Exception as e:
63
+ logging.exception(f"JSON parsing error: {result} -> {e}")
64
+ raise e
65
+
66
+ def _ent_info_from_(self, es_res, sim_thr=0.3):
67
+ res = {}
68
+ es_res = self.dataStore.getFields(es_res, ["content_with_weight", "_score", "entity_kwd", "rank_flt",
69
+ "n_hop_with_weight"])
70
+ for _, ent in es_res.items():
71
+ if float(ent.get("_score", 0)) < sim_thr:
72
+ continue
73
+ if isinstance(ent["entity_kwd"], list):
74
+ ent["entity_kwd"] = ent["entity_kwd"][0]
75
+ res[ent["entity_kwd"]] = {
76
+ "sim": float(ent.get("_score", 0)),
77
+ "pagerank": float(ent.get("rank_flt", 0)),
78
+ "n_hop_ents": json.loads(ent.get("n_hop_with_weight", "[]")),
79
+ "description": ent.get("content_with_weight", "{}")
80
+ }
81
+ return res
82
+
83
+ def _relation_info_from_(self, es_res, sim_thr=0.3):
84
+ res = {}
85
+ es_res = self.dataStore.getFields(es_res, ["content_with_weight", "_score", "from_entity_kwd", "to_entity_kwd",
86
+ "weight_int"])
87
+ for _, ent in es_res.items():
88
+ if float(ent["_score"]) < sim_thr:
89
+ continue
90
+ f, t = sorted([ent["from_entity_kwd"], ent["to_entity_kwd"]])
91
+ if isinstance(f, list):
92
+ f = f[0]
93
+ if isinstance(t, list):
94
+ t = t[0]
95
+ res[(f, t)] = {
96
+ "sim": float(ent["_score"]),
97
+ "pagerank": float(ent.get("weight_int", 0)),
98
+ "description": ent["content_with_weight"]
99
+ }
100
+ return res
101
+
102
+ def get_relevant_ents_by_keywords(self, keywords, filters, idxnms, kb_ids, emb_mdl, sim_thr=0.3, N=56):
103
+ if not keywords:
104
+ return {}
105
+ filters = deepcopy(filters)
106
+ filters["knowledge_graph_kwd"] = "entity"
107
+ matchDense = self.get_vector(", ".join(keywords), emb_mdl, 1024, sim_thr)
108
+ es_res = self.dataStore.search(["content_with_weight", "entity_kwd", "rank_flt"], [], filters, [matchDense],
109
+ OrderByExpr(), 0, N,
110
+ idxnms, kb_ids)
111
+ return self._ent_info_from_(es_res, sim_thr)
112
+
113
+ def get_relevant_relations_by_txt(self, txt, filters, idxnms, kb_ids, emb_mdl, sim_thr=0.3, N=56):
114
+ if not txt:
115
+ return {}
116
+ filters = deepcopy(filters)
117
+ filters["knowledge_graph_kwd"] = "relation"
118
+ matchDense = self.get_vector(txt, emb_mdl, 1024, sim_thr)
119
+ es_res = self.dataStore.search(
120
+ ["content_with_weight", "_score", "from_entity_kwd", "to_entity_kwd", "weight_int"],
121
+ [], filters, [matchDense], OrderByExpr(), 0, N, idxnms, kb_ids)
122
+ return self._relation_info_from_(es_res, sim_thr)
123
+
124
+ def get_relevant_ents_by_types(self, types, filters, idxnms, kb_ids, N=56):
125
+ if not types:
126
+ return {}
127
+ filters = deepcopy(filters)
128
+ filters["knowledge_graph_kwd"] = "entity"
129
+ filters["entity_type_kwd"] = types
130
+ ordr = OrderByExpr()
131
+ ordr.desc("rank_flt")
132
+ es_res = self.dataStore.search(["entity_kwd", "rank_flt"], [], filters, [], ordr, 0, N,
133
+ idxnms, kb_ids)
134
+ return self._ent_info_from_(es_res, 0)
135
+
136
+ def retrieval(self, question: str,
137
+ tenant_ids: str | list[str],
138
+ kb_ids: list[str],
139
+ emb_mdl,
140
+ llm,
141
+ max_token: int = 8196,
142
+ ent_topn: int = 6,
143
+ rel_topn: int = 6,
144
+ comm_topn: int = 1,
145
+ ent_sim_threshold: float = 0.3,
146
+ rel_sim_threshold: float = 0.3,
147
+ ):
148
+ qst = question
149
+ filters = self.get_filters({"kb_ids": kb_ids})
150
+ if isinstance(tenant_ids, str):
151
+ tenant_ids = tenant_ids.split(",")
152
+ idxnms = [index_name(tid) for tid in tenant_ids]
153
+ ty_kwds = []
154
+ ents = []
155
+ try:
156
+ ty_kwds, ents = self.query_rewrite(llm, qst, [index_name(tid) for tid in tenant_ids], kb_ids)
157
+ logging.info(f"Q: {qst}, Types: {ty_kwds}, Entities: {ents}")
158
+ except Exception as e:
159
+ logging.exception(e)
160
+ ents = [qst]
161
+ pass
162
+
163
+ ents_from_query = self.get_relevant_ents_by_keywords(ents, filters, idxnms, kb_ids, emb_mdl, ent_sim_threshold)
164
+ ents_from_types = self.get_relevant_ents_by_types(ty_kwds, filters, idxnms, kb_ids, 10000)
165
+ rels_from_txt = self.get_relevant_relations_by_txt(qst, filters, idxnms, kb_ids, emb_mdl, rel_sim_threshold)
166
+ nhop_pathes = defaultdict(dict)
167
+ for _, ent in ents_from_query.items():
168
+ nhops = ent.get("n_hop_ents", [])
169
+ for nbr in nhops:
170
+ path = nbr["path"]
171
+ wts = nbr["weights"]
172
+ for i in range(len(path) - 1):
173
+ f, t = path[i], path[i + 1]
174
+ if (f, t) in nhop_pathes:
175
+ nhop_pathes[(f, t)]["sim"] += ent["sim"] / (2 + i)
176
+ else:
177
+ nhop_pathes[(f, t)]["sim"] = ent["sim"] / (2 + i)
178
+ nhop_pathes[(f, t)]["pagerank"] = wts[i]
179
+
180
+ logging.info("Retrieved entities: {}".format(list(ents_from_query.keys())))
181
+ logging.info("Retrieved relations: {}".format(list(rels_from_txt.keys())))
182
+ logging.info("Retrieved entities from types({}): {}".format(ty_kwds, list(ents_from_types.keys())))
183
+ logging.info("Retrieved N-hops: {}".format(list(nhop_pathes.keys())))
184
+
185
+ # P(E|Q) => P(E) * P(Q|E) => pagerank * sim
186
+ for ent in ents_from_types.keys():
187
+ if ent not in ents_from_query:
188
+ continue
189
+ ents_from_query[ent]["sim"] *= 2
190
+
191
+ for (f, t) in rels_from_txt.keys():
192
+ pair = tuple(sorted([f, t]))
193
+ s = 0
194
+ if pair in nhop_pathes:
195
+ s += nhop_pathes[pair]["sim"]
196
+ del nhop_pathes[pair]
197
+ if f in ents_from_types:
198
+ s += 1
199
+ if t in ents_from_types:
200
+ s += 1
201
+ rels_from_txt[(f, t)]["sim"] *= s + 1
202
+
203
+ # This is for the relations from n-hop but not by query search
204
+ for (f, t) in nhop_pathes.keys():
205
+ s = 0
206
+ if f in ents_from_types:
207
+ s += 1
208
+ if t in ents_from_types:
209
+ s += 1
210
+ rels_from_txt[(f, t)] = {
211
+ "sim": nhop_pathes[(f, t)]["sim"] * (s + 1),
212
+ "pagerank": nhop_pathes[(f, t)]["pagerank"]
213
+ }
214
+
215
+ ents_from_query = sorted(ents_from_query.items(), key=lambda x: x[1]["sim"] * x[1]["pagerank"], reverse=True)[
216
+ :ent_topn]
217
+ rels_from_txt = sorted(rels_from_txt.items(), key=lambda x: x[1]["sim"] * x[1]["pagerank"], reverse=True)[
218
+ :rel_topn]
219
+
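The weighting above follows the comment's factorization, read as Bayes' rule with the evidence term dropped since P(Q) is constant across candidates:

P(E \mid Q) = \frac{P(E)\, P(Q \mid E)}{P(Q)} \;\propto\; P(E)\, P(Q \mid E) \;\approx\; \mathrm{pagerank}(E) \cdot \mathrm{sim}(E, Q)

Type-pool hits and accumulated n-hop similarities then act as multiplicative boosts on sim before the top-k cut.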
220
+ ents = []
221
+ relas = []
222
+ for n, ent in ents_from_query:
223
+ ents.append({
224
+ "Entity": n,
225
+ "Score": "%.2f" % (ent["sim"] * ent["pagerank"]),
226
+ "Description": json.loads(ent["description"]).get("description", "")
227
+ })
228
+ max_token -= num_tokens_from_string(str(ents[-1]))
229
+ if max_token <= 0:
230
+ ents = ents[:-1]
231
  break
232
 
233
+ for (f, t), rel in rels_from_txt:
234
+ if not rel.get("description"):
235
+ for tid in tenant_ids:
236
+ rela = get_relation(tid, kb_ids, f, t)
237
+ if rela:
238
+ break
239
+ else:
240
+ continue
241
+ rel["description"] = rela["description"]
242
+ relas.append({
243
+ "From Entity": f,
244
+ "To Entity": t,
245
+ "Score": "%.2f" % (rel["sim"] * rel["pagerank"]),
246
+ "Description": json.loads(ent["description"]).get("description", "")
247
+ })
248
+ max_token -= num_tokens_from_string(str(relas[-1]))
249
+ if max_token <= 0:
250
+ relas = relas[:-1]
251
+ break
252
+
253
+ if ents:
254
+ ents = "\n-Entities-\n{}".format(pd.DataFrame(ents).to_csv())
255
+ else:
256
+ ents = ""
257
+ if relas:
258
+ relas = "\n-Relations-\n{}".format(pd.DataFrame(relas).to_csv())
259
+ else:
260
+ relas = ""
261
+
262
+ return {
263
+ "chunk_id": get_uuid(),
264
+ "content_ltks": "",
265
+ "content_with_weight": ents + relas + self._community_retrival_([n for n, _ in ents_from_query], filters, kb_ids, idxnms,
266
+ comm_topn, max_token),
267
+ "doc_id": "",
268
+ "docnm_kwd": "Related content in Knowledge Graph",
269
+ "kb_id": kb_ids,
270
+ "important_kwd": [],
271
+ "image_id": "",
272
+ "similarity": 1.,
273
+ "vector_similarity": 1.,
274
+ "term_similarity": 0,
275
+ "vector": [],
276
+ "positions": [],
277
+ }
278
+
279
+ def _community_retrival_(self, entities, condition, kb_ids, idxnms, topn, max_token):
280
  ## Community retrieval
281
+ fields = ["docnm_kwd", "content_with_weight"]
282
+ odr = OrderByExpr()
283
+ odr.desc("weight_flt")
284
+ fltr = deepcopy(condition)
285
+ fltr["knowledge_graph_kwd"] = "community_report"
286
+ fltr["entities_kwd"] = entities
287
+ comm_res = self.dataStore.search(fields, [], fltr, [],
288
+ OrderByExpr(), 0, topn, idxnms, kb_ids)
289
+ comm_res_fields = self.dataStore.getFields(comm_res, fields)
290
+ txts = []
291
+ for ii, (_, row) in enumerate(comm_res_fields.items()):
292
+ obj = json.loads(row["content_with_weight"])
293
+ txts.append("# {}. {}\n## Content\n{}\n## Evidences\n{}\n".format(
294
+ ii + 1, row["docnm_kwd"], obj["report"], obj["evidences"]))
295
+ max_token -= num_tokens_from_string(str(txts[-1]))
296
+
297
+ if not txts:
298
+ return ""
299
+ return "\n-Community Report-\n" + "\n".join(txts)
300
+
301
+
302
+ if __name__ == "__main__":
303
+ from api import settings
304
+ import argparse
305
+ from api.db import LLMType
306
+ from api.db.services.knowledgebase_service import KnowledgebaseService
307
+ from api.db.services.llm_service import LLMBundle
308
+ from api.db.services.user_service import TenantService
309
+ from rag.nlp import search
310
+
311
+ settings.init_settings()
312
+ parser = argparse.ArgumentParser()
313
+ parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
314
+ parser.add_argument('-d', '--kb_id', default=False, help="Knowledge base ID", action='store', required=True)
315
+ parser.add_argument('-q', '--question', default=False, help="Question", action='store', required=True)
316
+ args = parser.parse_args()
317
+
318
+ kb_id = args.kb_id
319
+ _, tenant = TenantService.get_by_id(args.tenant_id)
320
+ llm_bdl = LLMBundle(args.tenant_id, LLMType.CHAT, tenant.llm_id)
321
+ _, kb = KnowledgebaseService.get_by_id(kb_id)
322
+ embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)
323
+
324
+ kg = KGSearch(settings.docStoreConn)
325
+ print(kg.retrieval(args.question,
326
+ kb.tenant_id, [kb_id], embed_bdl, llm_bdl))
graphrag/utils.py CHANGED
@@ -3,16 +3,26 @@
3
  """
4
  Reference:
5
  - [graphrag](https://github.com/microsoft/graphrag)
 
6
  """
7
 
8
  import html
9
  import json
 
10
  import re
 
 
 
 
11
  from typing import Any, Callable
12
 
 
13
  import numpy as np
14
  import xxhash
 
15
 
 
 
16
  from rag.utils.redis_conn import REDIS_CONN
17
 
18
  ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
@@ -131,3 +141,379 @@ def set_tags_to_cache(kb_ids, tags):
131
 
132
  k = hasher.hexdigest()
133
  REDIS_CONN.set(k, json.dumps(tags).encode("utf-8"), 600)
 
3
  """
4
  Reference:
5
  - [graphrag](https://github.com/microsoft/graphrag)
6
+ - [LightRag](https://github.com/HKUDS/LightRAG)
7
  """
8
 
9
  import html
10
  import json
11
+ import logging
12
  import re
13
+ import time
14
+ from collections import defaultdict
15
+ from copy import deepcopy
16
+ from hashlib import md5
17
  from typing import Any, Callable
18
 
19
+ import networkx as nx
20
  import numpy as np
21
  import xxhash
22
+ from networkx.readwrite import json_graph
23
 
24
+ from api import settings
25
+ from rag.nlp import search, rag_tokenizer
26
  from rag.utils.redis_conn import REDIS_CONN
27
 
28
  ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
 
141
 
142
  k = hasher.hexdigest()
143
  REDIS_CONN.set(k, json.dumps(tags).encode("utf-8"), 600)
144
+
145
+
146
+ def graph_merge(g1, g2):
147
+ g = g2.copy()
148
+ for n, attr in g1.nodes(data=True):
149
+ if n not in g2.nodes():
150
+ g.add_node(n, **attr)
151
+ continue
152
+
153
+ for source, target, attr in g1.edges(data=True):
154
+ if g.has_edge(source, target):
155
+ g[source][target].update({"weight": attr.get("weight", 0)+1})
156
+ continue
157
+ g.add_edge(source, target)#, **attr)
158
+
159
+ for node_degree in g.degree:
160
+ g.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
161
+ return g
162
+
163
+
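A small worked example of graph_merge (networkx is imported at the top of this file); nodes missing from g2 are copied in with their attributes, overlapping edges have their weight updated, and every node ends up with a degree-based rank:

a = nx.Graph()
a.add_edge("ALEX", "TAYLOR", weight=1)
b = nx.Graph()
b.add_edge("TAYLOR", "JORDAN", weight=2)
merged = graph_merge(a, b)
sorted(merged.nodes())          # ['ALEX', 'JORDAN', 'TAYLOR']
merged.nodes["TAYLOR"]["rank"]  # 2: the degree-based rank set in the final loop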
164
+ def compute_args_hash(*args):
165
+ return md5(str(args).encode()).hexdigest()
166
+
167
+
168
+ def handle_single_entity_extraction(
169
+ record_attributes: list[str],
170
+ chunk_key: str,
171
+ ):
172
+ if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
173
+ return None
174
+ # add this record as a node in the G
175
+ entity_name = clean_str(record_attributes[1].upper())
176
+ if not entity_name.strip():
177
+ return None
178
+ entity_type = clean_str(record_attributes[2].upper())
179
+ entity_description = clean_str(record_attributes[3])
180
+ entity_source_id = chunk_key
181
+ return dict(
182
+ entity_name=entity_name.upper(),
183
+ entity_type=entity_type.upper(),
184
+ description=entity_description,
185
+ source_id=entity_source_id,
186
+ )
187
+
188
+
189
+ def handle_single_relationship_extraction(record_attributes: list[str], chunk_key: str):
190
+ if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
191
+ return None
192
+ # add this record as edge
193
+ source = clean_str(record_attributes[1].upper())
194
+ target = clean_str(record_attributes[2].upper())
195
+ edge_description = clean_str(record_attributes[3])
196
+
197
+ edge_keywords = clean_str(record_attributes[4])
198
+ edge_source_id = chunk_key
199
+ weight = (
200
+ float(record_attributes[-1]) if is_float_regex(record_attributes[-1]) else 1.0
201
+ )
202
+ pair = sorted([source.upper(), target.upper()])
203
+ return dict(
204
+ src_id=pair[0],
205
+ tgt_id=pair[1],
206
+ weight=weight,
207
+ description=edge_description,
208
+ keywords=edge_keywords,
209
+ source_id=edge_source_id,
210
+ metadata={"created_at": time.time()},
211
+ )
212
+
213
+
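These two handlers consume exactly the tuple-delimited records produced by the extraction prompts. A hedged sketch of wiring one record through split_string_by_multi_markers (defined just below), with the record text borrowed from the graph_prompt examples and a placeholder chunk key:

record = '("entity"<|>"Alex"<|>"person"<|>"Alex is a character who experiences frustration.")'
attrs = split_string_by_multi_markers(record.strip("()"), ["<|>"])
node = handle_single_entity_extraction(attrs, chunk_key="chunk-0")
# node is a dict with entity_name, entity_type, description and source_id keys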
214
+ def pack_user_ass_to_openai_messages(*args: str):
215
+ roles = ["user", "assistant"]
216
+ return [
217
+ {"role": roles[i % 2], "content": content} for i, content in enumerate(args)
218
+ ]
219
+
220
+
221
+ def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
222
+ """Split a string by multiple markers"""
223
+ if not markers:
224
+ return [content]
225
+ results = re.split("|".join(re.escape(marker) for marker in markers), content)
226
+ return [r.strip() for r in results if r.strip()]
227
+
228
+
229
+ def is_float_regex(value):
230
+ return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
231
+
232
+
233
+ def chunk_id(chunk):
234
+ return xxhash.xxh64((chunk["content_with_weight"] + chunk["kb_id"]).encode("utf-8")).hexdigest()
235
+
236
+
237
+ def get_entity(tenant_id, kb_id, ent_name):
238
+ conds = {
239
+ "fields": ["content_with_weight"],
240
+ "entity_kwd": ent_name,
241
+ "size": 10000,
242
+ "knowledge_graph_kwd": ["entity"]
243
+ }
244
+ res = []
245
+ es_res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id])
246
+ for id in es_res.ids:
247
+ try:
248
+ if isinstance(ent_name, str):
249
+ return json.loads(es_res.field[id]["content_with_weight"])
250
+ res.append(json.loads(es_res.field[id]["content_with_weight"]))
251
+ except Exception:
252
+ continue
253
+
254
+ return res
255
+
256
+
257
+ def set_entity(tenant_id, kb_id, embd_mdl, ent_name, meta):
258
+ chunk = {
259
+ "important_kwd": [ent_name],
260
+ "title_tks": rag_tokenizer.tokenize(ent_name),
261
+ "entity_kwd": ent_name,
262
+ "knowledge_graph_kwd": "entity",
263
+ "entity_type_kwd": meta["entity_type"],
264
+ "content_with_weight": json.dumps(meta, ensure_ascii=False),
265
+ "content_ltks": rag_tokenizer.tokenize(meta["description"]),
266
+ "source_id": list(set(meta["source_id"])),
267
+ "kb_id": kb_id,
268
+ "available_int": 0
269
+ }
270
+ chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
271
+ res = settings.retrievaler.search({"entity_kwd": ent_name, "size": 1, "fields": []},
272
+ search.index_name(tenant_id), [kb_id])
273
+ if res.ids:
274
+ settings.docStoreConn.update({"entity_kwd": ent_name}, chunk, search.index_name(tenant_id), kb_id)
275
+ else:
276
+ ebd = get_embed_cache(embd_mdl.llm_name, ent_name)
277
+ if ebd is None:
278
+ try:
279
+ ebd, _ = embd_mdl.encode([ent_name])
280
+ ebd = ebd[0]
281
+ set_embed_cache(embd_mdl.llm_name, ent_name, ebd)
282
+ except Exception as e:
283
+ logging.exception(f"Fail to embed entity: {e}")
284
+ if ebd is not None:
285
+ chunk["q_%d_vec" % len(ebd)] = ebd
286
+ settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id))
287
+
288
+
289
+ def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
290
+ ents = from_ent_name
291
+ if isinstance(ents, str):
292
+ ents = [from_ent_name]
293
+ if isinstance(to_ent_name, str):
294
+ to_ent_name = [to_ent_name]
295
+ ents.extend(to_ent_name)
296
+ ents = list(set(ents))
297
+ conds = {
298
+ "fields": ["content_with_weight"],
299
+ "size": size,
300
+ "from_entity_kwd": ents,
301
+ "to_entity_kwd": ents,
302
+ "knowledge_graph_kwd": ["relation"]
303
+ }
304
+ res = []
305
+ es_res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id)
306
+ for id in es_res.ids:
307
+ try:
308
+ if size == 1:
309
+ return json.loads(es_res.field[id]["content_with_weight"])
310
+ res.append(json.loads(es_res.field[id]["content_with_weight"]))
311
+ except Exception:
312
+ continue
313
+ return res
314
+
315
+
316
+ def set_relation(tenant_id, kb_id, embd_mdl, from_ent_name, to_ent_name, meta):
317
+ chunk = {
318
+ "from_entity_kwd": from_ent_name,
319
+ "to_entity_kwd": to_ent_name,
320
+ "knowledge_graph_kwd": "relation",
321
+ "content_with_weight": json.dumps(meta, ensure_ascii=False),
322
+ "content_ltks": rag_tokenizer.tokenize(meta["description"]),
323
+ "important_kwd": meta["keywords"],
324
+ "source_id": list(set(meta["source_id"])),
325
+ "weight_int": int(meta["weight"]),
326
+ "kb_id": kb_id,
327
+ "available_int": 0
328
+ }
329
+ chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
330
+ res = settings.retrievaler.search({"from_entity_kwd": to_ent_name, "to_entity_kwd": to_ent_name, "size": 1, "fields": []},
331
+ search.index_name(tenant_id), [kb_id])
332
+
333
+ if res.ids:
334
+ settings.docStoreConn.update({"from_entity_kwd": from_ent_name, "to_entity_kwd": to_ent_name},
335
+ chunk,
336
+ search.index_name(tenant_id), kb_id)
337
+ else:
338
+ txt = f"{from_ent_name}->{to_ent_name}"
339
+ ebd = get_embed_cache(embd_mdl.llm_name, txt)
340
+ if ebd is None:
341
+ try:
342
+ ebd, _ = embd_mdl.encode([txt+f": {meta['description']}"])
343
+ ebd = ebd[0]
344
+ set_embed_cache(embd_mdl.llm_name, txt, ebd)
345
+ except Exception as e:
346
+ logging.exception(f"Fail to embed entity relation: {e}")
347
+ if ebd is not None:
348
+ chunk["q_%d_vec" % len(ebd)] = ebd
349
+ settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id))
350
+
351
+
352
+ def get_graph(tenant_id, kb_id):
353
+ conds = {
354
+ "fields": ["content_with_weight", "source_id"],
355
+ "removed_kwd": "N",
356
+ "size": 1,
357
+ "knowledge_graph_kwd": ["graph"]
358
+ }
359
+ res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id])
360
+ for id in res.ids:
361
+ try:
362
+ return json_graph.node_link_graph(json.loads(res.field[id]["content_with_weight"]), edges="edges"), \
363
+ res.field[id]["source_id"]
364
+ except Exception:
365
+ continue
366
+ return None, None
367
+
368
+
369
+def set_graph(tenant_id, kb_id, graph, docids):
+    chunk = {
+        "content_with_weight": json.dumps(nx.node_link_data(graph, edges="edges"), ensure_ascii=False,
+                                          indent=2),
+        "knowledge_graph_kwd": "graph",
+        "kb_id": kb_id,
+        "source_id": list(docids),
+        "available_int": 0,
+        "removed_kwd": "N"
+    }
+    res = settings.retrievaler.search({"knowledge_graph_kwd": "graph", "size": 1, "fields": []}, search.index_name(tenant_id), [kb_id])
+    if res.ids:
+        settings.docStoreConn.update({"knowledge_graph_kwd": "graph"}, chunk,
+                                     search.index_name(tenant_id), kb_id)
+    else:
+        settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id))
+
+
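get_graph() and set_graph() persist the whole knowledge graph as a single JSON chunk using networkx's node-link format. A standalone sketch of that round-trip (it assumes a networkx version that accepts the edges= keyword, as the code above does):

import json

import networkx as nx
from networkx.readwrite import json_graph

g = nx.Graph()
g.add_edge("Alice", "Acme Corp", weight=2, description="works for")

payload = json.dumps(nx.node_link_data(g, edges="edges"), ensure_ascii=False)  # what set_graph() stores
restored = json_graph.node_link_graph(json.loads(payload), edges="edges")      # what get_graph() loads back
assert restored.has_edge("Alice", "Acme Corp")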
+def is_continuous_subsequence(subseq, seq):
+    def find_all_indexes(tup, value):
+        indexes = []
+        start = 0
+        while True:
+            try:
+                index = tup.index(value, start)
+                indexes.append(index)
+                start = index + 1
+            except ValueError:
+                break
+        return indexes
+
+    index_list = find_all_indexes(seq, subseq[0])
+    for idx in index_list:
+        if idx != len(seq) - 1:
+            if seq[idx + 1] == subseq[-1]:
+                return True
+    return False
+
+
+def merge_tuples(list1, list2):
+    result = []
+    for tup in list1:
+        last_element = tup[-1]
+        if last_element in tup[:-1]:
+            result.append(tup)
+        else:
+            matching_tuples = [t for t in list2 if t[0] == last_element]
+            already_match_flag = 0
+            for match in matching_tuples:
+                matchh = (match[1], match[0])
+                if is_continuous_subsequence(match, tup) or is_continuous_subsequence(matchh, tup):
+                    continue
+                already_match_flag = 1
+                merged_tuple = tup + match[1:]
+                result.append(merged_tuple)
+            if not already_match_flag:
+                result.append(tup)
+    return result
+
+
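is_continuous_subsequence() and merge_tuples() grow 1-hop edge tuples into longer paths while skipping immediate back-edges; n_neighbor() below relies on this. A small check, assuming this module is importable as graphrag.utils:

from graphrag.utils import merge_tuples

paths = [("a", "b")]                    # current frontier: one 1-hop path ending at "b"
next_edges = [("b", "c"), ("b", "a")]   # edges leaving "b"
print(merge_tuples(paths, next_edges))  # [('a', 'b', 'c')]; the back-edge to "a" is dropped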
+def update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, n_hop):
+    def n_neighbor(id):
+        nonlocal graph, n_hop
+        count = 0
+        source_edge = list(graph.edges(id))
+        if not source_edge:
+            return []
+        count = count + 1
+        while count < n_hop:
+            count = count + 1
+            sc_edge = deepcopy(source_edge)
+            source_edge = []
+            for pair in sc_edge:
+                append_edge = list(graph.edges(pair[-1]))
+                for tuples in merge_tuples([pair], append_edge):
+                    source_edge.append(tuples)
+        nbrs = []
+        for path in source_edge:
+            n = {"path": path, "weights": []}
+            wts = nx.get_edge_attributes(graph, 'weight')
+            for i in range(len(path) - 1):
+                f, t = path[i], path[i + 1]
+                n["weights"].append(wts.get((f, t), 0))
+            nbrs.append(n)
+        return nbrs
+
+    pr = nx.pagerank(graph)
+    for n, p in pr.items():
+        graph.nodes[n]["pagerank"] = p
+        try:
+            settings.docStoreConn.update({"entity_kwd": n, "kb_id": kb_id},
+                                         {"rank_flt": p,
+                                          "n_hop_with_weight": json.dumps(n_neighbor(n), ensure_ascii=False)},
+                                         search.index_name(tenant_id), kb_id)
+        except Exception as e:
+            logging.exception(e)
+
+    ty2ents = defaultdict(list)
+    for p, r in sorted(pr.items(), key=lambda x: x[1], reverse=True):
+        ty = graph.nodes[p].get("entity_type")
+        if not ty or len(ty2ents[ty]) > 12:
+            continue
+        ty2ents[ty].append(p)
+
+    chunk = {
+        "content_with_weight": json.dumps(ty2ents, ensure_ascii=False),
+        "kb_id": kb_id,
+        "knowledge_graph_kwd": "ty2ents",
+        "available_int": 0
+    }
+    res = settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "size": 1, "fields": []},
+                                      search.index_name(tenant_id), [kb_id])
+    if res.ids:
+        settings.docStoreConn.update({"knowledge_graph_kwd": "ty2ents"},
+                                     chunk,
+                                     search.index_name(tenant_id), kb_id)
+    else:
+        settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id))
+
+
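The PageRank step above is plain networkx: every node gets a pagerank attribute, which the function then mirrors into the doc store as rank_flt. A self-contained illustration:

import networkx as nx

g = nx.Graph()
g.add_edges_from([("a", "b"), ("b", "c"), ("c", "a"), ("c", "d")])
for node, score in nx.pagerank(g).items():
    g.nodes[node]["pagerank"] = score
print(max(g.nodes, key=lambda n: g.nodes[n]["pagerank"]))  # 'c', the best-connected node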
+def get_entity_type2sampels(idxnms, kb_ids: list):
+    es_res = settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids,
+                                          "size": 10000,
+                                          "fields": ["content_with_weight"]},
+                                         idxnms, kb_ids)
+
+    res = defaultdict(list)
+    for id in es_res.ids:
+        smp = es_res.field[id].get("content_with_weight")
+        if not smp:
+            continue
+        try:
+            smp = json.loads(smp)
+        except Exception as e:
+            logging.exception(e)
+
+        for ty, ents in smp.items():
+            res[ty].extend(ents)
+    return res
+
+
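get_entity_type2sampels() simply merges the per-KB "ty2ents" payloads into one type-to-entities mapping. A standalone sketch of that merge; the payloads here are invented stand-ins for the stored JSON:

from collections import defaultdict

payloads = [{"person": ["Alice"], "organization": ["Acme Corp"]},
            {"person": ["Bob"]}]
res = defaultdict(list)
for smp in payloads:
    for ty, ents in smp.items():
        res[ty].extend(ents)
print(dict(res))  # {'person': ['Alice', 'Bob'], 'organization': ['Acme Corp']}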
+def flat_uniq_list(arr, key):
+    res = []
+    for a in arr:
+        a = a[key]
+        if isinstance(a, list):
+            res.extend(a)
+        else:
+            res.append(a)
+    return list(set(res))
+
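flat_uniq_list() flattens mixed scalar/list field values and de-duplicates them, e.g. (again assuming the module is importable as graphrag.utils):

from graphrag.utils import flat_uniq_list

rows = [{"source_id": ["doc1", "doc2"]}, {"source_id": "doc1"}]
print(sorted(flat_uniq_list(rows, "source_id")))  # ['doc1', 'doc2']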
pyproject.toml CHANGED
@@ -51,6 +51,7 @@ dependencies = [
     "infinity-sdk==0.6.0-dev2",
     "infinity-emb>=0.0.66,<0.0.67",
     "itsdangerous==2.1.2",
+    "json-repair==0.35.0",
     "markdown==3.6",
     "markdown-to-json==2.1.1",
     "minio==7.2.4",
@@ -130,4 +131,4 @@ full = [
     "flagembedding==1.2.10",
     "torch==2.3.0",
     "transformers==4.38.1"
-    ]
+    ]
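json-repair is presumably added because LLM output is often almost-but-not-quite JSON, and the library coerces it into something parseable. A minimal sketch; the broken string is an invented example:

from json_repair import repair_json

broken = '{"entity": "Acme Corp", "type": "organization",}'  # trailing comma would break json.loads
print(repair_json(broken))                                   # prints valid JSON without the trailing comma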
rag/app/book.py CHANGED
@@ -88,9 +88,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
         callback(0.8, "Finish parsing.")
 
     elif re.search(r"\.pdf$", filename, re.IGNORECASE):
-        pdf_parser = Pdf() if kwargs.get(
-            "parser_config", {}).get(
-            "layout_recognize", True) else PlainParser()
+        pdf_parser = Pdf()
+        if kwargs.get("layout_recognize", "DeepDOC") == "Plain Text":
+            pdf_parser = PlainParser()
         sections, tbls = pdf_parser(filename if not binary else binary,
                                     from_page=from_page, to_page=to_page, callback=callback)
 
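The same change repeats across the rag/app parsers below (email, laws, manual, naive, one, paper, presentation): layout_recognize switches from a boolean to a string defaulting to "DeepDOC", and only the literal "Plain Text" selects the plain parser. A toy sketch of the new selection logic; the returned names stand in for the real Pdf/PlainParser classes:

def pick_parser(layout_recognize="DeepDOC"):
    # "Plain Text" opts out of layout recognition; anything else keeps the DeepDOC pipeline
    return "PlainParser" if layout_recognize == "Plain Text" else "Pdf"

assert pick_parser() == "Pdf"
assert pick_parser("Plain Text") == "PlainParser"
assert pick_parser("DeepDOC") == "Pdf"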
rag/app/email.py CHANGED
@@ -40,7 +40,7 @@ def chunk(
     eng = lang.lower() == "english"  # is_english(cks)
     parser_config = kwargs.get(
         "parser_config",
-        {"chunk_token_num": 128, "delimiter": "\n!?。;!?", "layout_recognize": True},
+        {"chunk_token_num": 128, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"},
     )
     doc = {
         "docnm_kwd": filename,
rag/app/knowledge_graph.py DELETED
@@ -1,48 +0,0 @@
-#
-#  Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
-#
-#  Licensed under the Apache License, Version 2.0 (the "License");
-#  you may not use this file except in compliance with the License.
-#  You may obtain a copy of the License at
-#
-#      http://www.apache.org/licenses/LICENSE-2.0
-#
-#  Unless required by applicable law or agreed to in writing, software
-#  distributed under the License is distributed on an "AS IS" BASIS,
-#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-#  See the License for the specific language governing permissions and
-#  limitations under the License.
-#
-
-import re
-
-from graphrag.index import build_knowledge_graph_chunks
-from rag.app import naive
-from rag.nlp import rag_tokenizer, tokenize_chunks
-
-
-def chunk(filename, binary, tenant_id, from_page=0, to_page=100000,
-          lang="Chinese", callback=None, **kwargs):
-    parser_config = kwargs.get(
-        "parser_config", {
-            "chunk_token_num": 512, "delimiter": "\n!?;。;!?", "layout_recognize": True})
-    eng = lang.lower() == "english"
-
-    parser_config["layout_recognize"] = True
-    sections = naive.chunk(filename, binary, from_page=from_page, to_page=to_page, section_only=True,
-                           parser_config=parser_config, callback=callback)
-    chunks = build_knowledge_graph_chunks(tenant_id, sections, callback,
-                                          parser_config.get("entity_types", ["organization", "person", "location", "event", "time"])
-                                          )
-    for c in chunks:
-        c["docnm_kwd"] = filename
-
-    doc = {
-        "docnm_kwd": filename,
-        "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename)),
-        "knowledge_graph_kwd": "text"
-    }
-    doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
-    chunks.extend(tokenize_chunks(sections, doc, eng))
-
-    return chunks
rag/app/laws.py CHANGED
@@ -162,9 +162,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
         return tokenize_chunks(chunks, doc, eng, None)
 
     elif re.search(r"\.pdf$", filename, re.IGNORECASE):
-        pdf_parser = Pdf() if kwargs.get(
-            "parser_config", {}).get(
-            "layout_recognize", True) else PlainParser()
+        pdf_parser = Pdf()
+        if kwargs.get("layout_recognize", "DeepDOC") == "Plain Text":
+            pdf_parser = PlainParser()
         for txt, poss in pdf_parser(filename if not binary else binary,
                                     from_page=from_page, to_page=to_page, callback=callback)[0]:
             sections.append(txt + poss)
rag/app/manual.py CHANGED
@@ -184,9 +184,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
     # is it English
     eng = lang.lower() == "english"  # pdf_parser.is_english
     if re.search(r"\.pdf$", filename, re.IGNORECASE):
-        pdf_parser = Pdf() if kwargs.get(
-            "parser_config", {}).get(
-            "layout_recognize", True) else PlainParser()
+        pdf_parser = Pdf()
+        if kwargs.get("layout_recognize", "DeepDOC") == "Plain Text":
+            pdf_parser = PlainParser()
         sections, tbls = pdf_parser(filename if not binary else binary,
                                     from_page=from_page, to_page=to_page, callback=callback)
         if sections and len(sections[0]) < 3:
rag/app/naive.py CHANGED
@@ -202,7 +202,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
     is_english = lang.lower() == "english"  # is_english(cks)
     parser_config = kwargs.get(
         "parser_config", {
-            "chunk_token_num": 128, "delimiter": "\n!?。;!?", "layout_recognize": True})
+            "chunk_token_num": 128, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
     doc = {
         "docnm_kwd": filename,
         "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
@@ -231,8 +231,11 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
         return res
 
     elif re.search(r"\.pdf$", filename, re.IGNORECASE):
-        pdf_parser = Pdf() if parser_config.get("layout_recognize", True) else PlainParser()
-        sections, tables = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page, callback=callback)
+        pdf_parser = Pdf()
+        if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
+            pdf_parser = PlainParser()
+        sections, tables = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page,
+                                      callback=callback)
         res = tokenize_table(tables, doc, is_english)
 
     elif re.search(r"\.xlsx?$", filename, re.IGNORECASE):
rag/app/one.py CHANGED
@@ -84,9 +84,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
         callback(0.8, "Finish parsing.")
 
     elif re.search(r"\.pdf$", filename, re.IGNORECASE):
-        pdf_parser = Pdf() if kwargs.get(
-            "parser_config", {}).get(
-            "layout_recognize", True) else PlainParser()
+        pdf_parser = Pdf()
+        if kwargs.get("layout_recognize", "DeepDOC") == "Plain Text":
+            pdf_parser = PlainParser()
         sections, _ = pdf_parser(
             filename if not binary else binary, to_page=to_page, callback=callback)
         sections = [s for s, _ in sections if s]
rag/app/paper.py CHANGED
@@ -144,7 +144,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
     The abstract of the paper will be sliced as an entire chunk, and will not be sliced partly.
     """
     if re.search(r"\.pdf$", filename, re.IGNORECASE):
-        if not kwargs.get("parser_config", {}).get("layout_recognize", True):
+        if kwargs.get("parser_config", {}).get("layout_recognize", "DeepDOC") == "Plain Text":
             pdf_parser = PlainParser()
             paper = {
                 "title": filename,
rag/app/presentation.py CHANGED
@@ -119,9 +119,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
             res.append(d)
         return res
     elif re.search(r"\.pdf$", filename, re.IGNORECASE):
-        pdf_parser = Pdf() if kwargs.get(
-            "parser_config", {}).get(
-            "layout_recognize", True) else PlainPdf()
+        pdf_parser = Pdf()
+        if kwargs.get("layout_recognize", "DeepDOC") == "Plain Text":
+            pdf_parser = PlainParser()
         for pn, (txt, img) in enumerate(pdf_parser(filename, binary,
                                                    from_page=from_page, to_page=to_page, callback=callback)):
             d = copy.deepcopy(doc)
rag/llm/chat_model.py CHANGED
@@ -32,6 +32,7 @@ import asyncio
 LENGTH_NOTIFICATION_CN = "······\n由于长度的原因,回答被截断了,要继续吗?"
 LENGTH_NOTIFICATION_EN = "...\nFor the content length reason, it stopped, continue?"
 
+
 class Base(ABC):
     def __init__(self, key, model_name, base_url):
         timeout = int(os.environ.get('LM_TIMEOUT_SECONDS', 600))
rag/nlp/search.py CHANGED
@@ -59,7 +59,7 @@ class Dealer:
             if key in req and req[key] is not None:
                 condition[field] = req[key]
         # TODO(yzc): `available_int` is nullable however infinity doesn't support nullable columns.
-        for key in ["knowledge_graph_kwd", "available_int"]:
+        for key in ["knowledge_graph_kwd", "available_int", "entity_kwd", "from_entity_kwd", "to_entity_kwd", "removed_kwd"]:
             if key in req and req[key] is not None:
                 condition[key] = req[key]
         return condition
@@ -198,6 +198,11 @@ class Dealer:
             return answer, set([])
 
         ans_v, _ = embd_mdl.encode(pieces_)
+        for i in range(len(chunk_v)):
+            if len(ans_v[0]) != len(chunk_v[i]):
+                chunk_v[i] = [0.0] * len(ans_v[0])
+                logging.warning("The dimension of query and chunk do not match: {} vs. {}".format(len(ans_v[0]), len(chunk_v[i])))
+
         assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
             len(ans_v[0]), len(chunk_v[0]))
@@ -434,6 +439,8 @@ class Dealer:
             es_res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), p, bs, index_name(tenant_id),
                                            kb_ids)
             dict_chunks = self.dataStore.getFields(es_res, fields)
+            for id, doc in dict_chunks.items():
+                doc["id"] = id
             if dict_chunks:
                 res.extend(dict_chunks.values())
             if len(dict_chunks.values()) < bs:
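The new guard in rag/nlp/search.py zero-fills any chunk vector whose dimension disagrees with the query vector, and logs the mismatch, instead of letting the assert fire on mixed-dimension data. A toy reproduction of that logic with made-up vectors:

import logging

ans_v = [[0.1, 0.2, 0.3]]                # query embedding, dimension 3
chunk_v = [[0.5, 0.5], [0.4, 0.4, 0.4]]  # first chunk vector has the wrong dimension
for i, v in enumerate(chunk_v):
    if len(v) != len(ans_v[0]):
        logging.warning("The dimension of query and chunk do not match: %d vs. %d", len(ans_v[0]), len(v))
        chunk_v[i] = [0.0] * len(ans_v[0])
assert all(len(v) == len(ans_v[0]) for v in chunk_v)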