hlydecker's picture
Duplicate from hlydecker/Augmented-Retrieval-qa-ChatGPT
1ce95c4
raw
history blame
730 Bytes
from typing import Iterable, List

from langchain.indexes.graph import *
from langchain.indexes.graph import GraphIndexCreator as OriginalGraphIndexCreator
class GraphIndexCreator(OriginalGraphIndexCreator):
    """``GraphIndexCreator`` that can build a single graph from many texts.

    Adds :meth:`from_texts`, which sends each text through one shared
    ``LLMChain`` and accumulates every parsed triple into one graph instance.
    """

    def from_texts(self, texts: Iterable[str]) -> NetworkxEntityGraph:
        """Create one graph index from several texts.

        A single ``LLMChain`` is built from ``self.llm`` and
        ``KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT`` and reused for every text. The
        output for each text is parsed with ``parse_triples``. Each resulting
        triple is added to one graph created by ``self.graph_type()``.

        Args:
            texts: The texts to index. A bare ``str`` is treated as a single
                text rather than being iterated character by character.
                Texts that are empty or whitespace-only are skipped without
                calling the LLM.

        Returns:
            The graph instance created by ``self.graph_type()``, populated
            with the triples extracted from all texts.

        Raises:
            ValueError: If ``self.llm`` is ``None``.
        """
        if self.llm is None:
            raise ValueError("llm should not be None")
        # A plain string is itself iterable; without this guard it would be
        # processed one character at a time, issuing one LLM call per char.
        if isinstance(texts, str):
            texts = [texts]
        graph = self.graph_type()
        # Build the chain once and reuse it for every text.
        chain = LLMChain(llm=self.llm, prompt=KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT)
        for text in texts:
            if not text or not text.strip():
                # Nothing to extract from; skip the LLM call entirely.
                continue
            output = chain.predict(text=text)
            for triple in parse_triples(output):
                graph.add_triple(triple)
        return graph