Spaces:
Build error
Build error
File size: 4,330 Bytes
70303d6 efad059 70303d6 cabb7e3 ba0e651 70303d6 efad059 70303d6 efad059 70303d6 456234e 70303d6 456234e 4c871d1 456234e 70303d6 efad059 70303d6 4c871d1 70303d6 33e0532 ba0e651 efad059 70303d6 efad059 70303d6 efad059 70303d6 efad059 70303d6 efad059 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
from typing import List
from transformers import pipeline
from pyvis.network import Network
from functools import lru_cache
import spacy
# Node fill colour per spaCy NER label (same palette as spaCy displaCy's
# DEFAULT_LABEL_COLORS). Labels not listed here are drawn grey ("#666666")
# by generate_knowledge_graph.
DEFAULT_LABEL_COLORS = {
    "ORG": "#7aecec",
    "PRODUCT": "#bfeeb7",
    "GPE": "#feca74",
    "LOC": "#ff9561",
    "PERSON": "#aa9cfc",
    "NORP": "#c887fb",
    # spaCy's English pipelines emit "FAC"; without it facilities were never coloured.
    "FAC": "#9cc9cc",
    # Kept for backward compatibility with code that looks up the long form.
    "FACILITY": "#9cc9cc",
    "EVENT": "#ffeb80",
    "LAW": "#ff8197",
    "LANGUAGE": "#ff8197",
    "WORK_OF_ART": "#f0d0ff",
    "DATE": "#bfe1d9",
    "TIME": "#bfe1d9",
    "MONEY": "#e4e7d2",
    "QUANTITY": "#e4e7d2",
    "ORDINAL": "#e4e7d2",
    "CARDINAL": "#e4e7d2",
    "PERCENT": "#e4e7d2",
}
@lru_cache(maxsize=4)
def _load_spacy_model(name: str = "en_core_web_sm"):
    """Load a spaCy pipeline once per model name and reuse it across calls."""
    return spacy.load(name)


def _entity_labels(texts: List[str]) -> dict:
    """Map each lowercased named-entity string in *texts* to its first spaCy label."""
    nlp = _load_spacy_model()
    # Run NER on the original casing: spaCy's English models use word shape
    # (capitalisation) as a feature, so lowercasing the input first loses many
    # entities. Keys are lowercased afterwards to match the lowercased node ids.
    doc = nlp("\n".join(texts))
    labels = {}
    for ent in doc.ents:
        # First occurrence wins (the previous code used list.index, same rule).
        labels.setdefault(ent.text.lower(), ent.label_)
    return labels


def generate_knowledge_graph(texts: List[str], filename: str):
    """Build an interactive knowledge graph from *texts* and write it to *filename*.

    Each text is sent separately to ``generate_partial_graph`` to obtain
    (head, relation, tail) triplets. Heads and tails become lowercased nodes;
    relations become labelled directed edges (duplicates are dropped). Nodes
    whose name matches a spaCy named entity get that entity label as a tooltip
    and a colour from ``DEFAULT_LABEL_COLORS`` (grey if the label is unknown).

    Args:
        texts: input passages.
        filename: path of the HTML file that pyvis writes.

    Returns:
        The unique lowercased node names, heads first then tails, in
        first-seen order (deterministic, unlike the previous ``set`` order).
    """
    entity_labels = _entity_labels(texts)

    triplets = []
    for text in texts:
        triplets.extend(generate_partial_graph(text))

    heads = [t["head"].lower() for t in triplets]
    tails = [t["tail"].lower() for t in triplets]
    # dict.fromkeys de-duplicates while keeping insertion order.
    nodes = list(dict.fromkeys(heads + tails))

    net = Network(directed=True, width="700px", height="700px")
    for node in nodes:
        label = entity_labels.get(node)
        if label is None:
            net.add_node(node, shape="circle")
        else:
            net.add_node(node, title=label, shape="circle",
                         color=DEFAULT_LABEL_COLORS.get(label, "#666666"))

    seen_edges = set()
    for t in triplets:
        head, tail = t["head"].lower(), t["tail"].lower()
        # Key on the same lowercased ids the edge is drawn between; a tuple
        # avoids collisions that plain string concatenation allowed.
        key = (head, tail, t["type"].lower())
        if key in seen_edges:
            continue
        seen_edges.add(key)
        net.add_edge(head, tail, title=t["type"], label=t["type"])

    net.repulsion(
        node_distance=200,
        central_gravity=0.2,
        spring_length=200,
        spring_strength=0.05,
        damping=0.09,
    )
    net.set_edge_smooth('dynamic')
    # write_html only writes the file. Network.show() in pyvis 0.3.x defaults to
    # notebook=True and is known to fail outside a notebook unless notebook=False.
    net.write_html(filename)
    return nodes
@lru_cache(maxsize=1)
def _get_triplet_extractor():
    """Build the REBEL text2text-generation pipeline once and reuse it.

    Constructing the pipeline loads the Babelscape/rebel-large weights, which
    is far too expensive to repeat for every input text.
    """
    return pipeline(
        'text2text-generation',
        model='Babelscape/rebel-large',
        tokenizer='Babelscape/rebel-large',
    )


@lru_cache(maxsize=16)
def generate_partial_graph(text: str):
    """Extract relation triplets from a single *text* with REBEL.

    Args:
        text: the passage to run through the relation extractor.

    Returns:
        A list of ``{'head', 'type', 'tail'}`` dicts as produced by
        ``extract_triplets`` (empty if the model returns nothing). Results are
        memoised per text, so the returned list is shared between calls —
        callers must not mutate it.
    """
    triplet_extractor = _get_triplet_extractor()
    # Ask for raw token ids so the special <triplet>/<subj>/<obj> markers
    # survive decoding (they are needed by extract_triplets).
    triples = triplet_extractor(text, return_tensors=True, return_text=False)
    if not triples:
        return []
    token_ids = [triples[0]["generated_token_ids"]]
    decoded = triplet_extractor.tokenizer.batch_decode(token_ids)
    return extract_triplets(decoded[0])
def extract_triplets(text):
    """Parse REBEL's linearised output into a list of triplet dicts.

    The expected token stream is::

        <triplet> HEAD <subj> TAIL <obj> RELATION [<subj> TAIL <obj> RELATION ...]

    i.e. ``<subj>`` introduces the *tail* entity and ``<obj>`` introduces the
    *relation*; several (tail, relation) pairs may follow one head. The
    ``<s>``, ``</s>`` and ``<pad>`` tokens are ignored.

    A triplet is emitted only when head, tail and relation are all non-empty,
    so malformed fragments are skipped instead of producing empty nodes or
    re-using a stale relation.

    Args:
        text: decoded generator output.

    Returns:
        List of ``{'head': ..., 'type': ..., 'tail': ...}`` dicts in the order
        they appear in *text*. Multi-word parts are re-joined with single spaces.
    """
    triplets = []
    head, tail, relation = '', '', ''
    current = 'x'  # 'x': before any marker, 't': head, 's': tail, 'o': relation

    def emit():
        # Reads the enclosing buffers at call time.
        if head and tail and relation:
            triplets.append(
                {'head': head.strip(), 'type': relation.strip(), 'tail': tail.strip()})

    cleaned = text.replace("<s>", "").replace("<pad>", "").replace("</s>", "")
    for token in cleaned.split():
        if token == "<triplet>":
            # New head: flush the pending triplet and start from scratch.
            emit()
            head, tail, relation = '', '', ''
            current = 't'
        elif token == "<subj>":
            # Another tail for the same head: flush, keep the head.
            emit()
            tail, relation = '', ''
            current = 's'
        elif token == "<obj>":
            relation = ''
            current = 'o'
        elif current == 't':
            head += ' ' + token
        elif current == 's':
            tail += ' ' + token
        elif current == 'o':
            relation += ' ' + token
    emit()
    return triplets
if __name__ == "__main__":
    # Manual smoke run: build a graph from two toy sentences into test.html.
    sample_texts = ["The dog is happy", "The cat is sad"]
    generate_knowledge_graph(sample_texts, filename="test.html")
|