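"""Build an interactive knowledge graph from free text.

spaCy NER labels the entities, the Babelscape/rebel-large model extracts
relation triplets, and pyvis renders the result as an HTML network.

Requires spacy (plus the en_core_web_sm model), transformers with a backend
such as torch, and pyvis; the REBEL weights are downloaded on first use.
"""
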
from typing import List
from transformers import pipeline
from pyvis.network import Network
from functools import lru_cache
import spacy


# Node colour per spaCy entity label (based on displaCy's default palette).
DEFAULT_LABEL_COLORS = {
    "ORG": "#7aecec",
    "PRODUCT": "#bfeeb7",
    "GPE": "#feca74",
    "LOC": "#ff9561",
    "PERSON": "#aa9cfc",
    "NORP": "#c887fb",
    "FACILITY": "#9cc9cc",
    "EVENT": "#ffeb80",
    "LAW": "#ff8197",
    "LANGUAGE": "#ff8197",
    "WORK_OF_ART": "#f0d0ff",
    "DATE": "#bfe1d9",
    "TIME": "#bfe1d9",
    "MONEY": "#e4e7d2",
    "QUANTITY": "#e4e7d2",
    "ORDINAL": "#e4e7d2",
    "CARDINAL": "#e4e7d2",
    "PERCENT": "#e4e7d2",
}


def generate_knowledge_graph(texts: List[str], filename: str):
    """Extract triplets from each text, colour recognised entities by their
    NER label, and write the interactive graph to `filename`."""
    nlp = spacy.load("en_core_web_sm")
    doc = nlp("\n".join(texts).lower())
    NERs = [ent.text for ent in doc.ents]
    NER_types = [ent.label_ for ent in doc.ents]

    triplets = []
    for text in texts:
        triplets.extend(generate_partial_graph(text))
    heads = [t["head"].lower() for t in triplets]
    tails = [t["tail"].lower() for t in triplets]

    nodes = list(set(heads + tails))
    net = Network(directed=True, width="700px", height="700px")

    for n in nodes:
        if n in NERs:
            # Colour recognised entities by NER label; fall back to grey
            # for labels without an assigned colour.
            NER_type = NER_types[NERs.index(n)]
            color = DEFAULT_LABEL_COLORS.get(NER_type, "#666666")
            net.add_node(n, title=NER_type, shape="circle", color=color)
        else:
            net.add_node(n, shape="circle")

    # Deduplicate edges: identical (head, tail, relation type) triplets
    # should produce a single edge. A tuple key avoids the collisions that
    # plain string concatenation can cause.
    unique_triplets = set()

    def edge_key(t):
        return (t["head"].lower(), t["tail"].lower(), t["type"].lower())

    for triplet in triplets:
        if edge_key(triplet) not in unique_triplets:
            net.add_edge(triplet["head"].lower(), triplet["tail"].lower(),
                         title=triplet["type"], label=triplet["type"])
            unique_triplets.add(edge_key(triplet))

    # Layout: repulsion-based physics keeps nodes spread apart so the edge
    # labels stay readable.
    net.repulsion(
        node_distance=200,
        central_gravity=0.2,
        spring_length=200,
        spring_strength=0.05,
        damping=0.09
    )
    net.set_edge_smooth('dynamic')
    # show() writes the HTML file and opens it in the browser; newer pyvis
    # releases may require net.show(filename, notebook=False) here.
    net.show(filename)
    return nodes


@lru_cache(maxsize=16)
def generate_partial_graph(text: str):
    """Extract relation triplets from a single text with REBEL.

    Results are memoised per text; note that the pipeline (and hence the
    model) is still rebuilt on every cache miss.
    """
    triplet_extractor = pipeline(
        'text2text-generation',
        model='Babelscape/rebel-large',
        tokenizer='Babelscape/rebel-large'
    )

    triples = triplet_extractor(
        text,
        return_tensors=True,
        return_text=False)

    if len(triples) == 0:
        return []

    # Decode the generated token ids back into text, keeping the special
    # <triplet>/<subj>/<obj> markers that extract_triplets() parses.
    token_ids = [triples[0]["generated_token_ids"]]
    extracted_text = triplet_extractor.tokenizer.batch_decode(token_ids)
    return extract_triplets(extracted_text[0])

def extract_triplets(text):
    """Parse REBEL's linearised output into {'head', 'type', 'tail'} dicts.

    REBEL emits triplets as "<triplet> subject <subj> object <obj> relation":
    the subject follows <triplet>, the object follows <subj>, and the
    relation follows <obj>.
    """
    triplets = []
    subject, relation, object_ = '', '', ''
    text = text.strip()
    current = 'x'
    for token in text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").split():
        if token == "<triplet>":
            current = 't'
            if relation != '':
                # A new triplet starts: flush the previous one.
                triplets.append(
                    {'head': subject.strip(), 'type': relation.strip(), 'tail': object_.strip()})
                relation = ''
            subject = ''
        elif token == "<subj>":
            current = 's'
            if relation != '':
                # Same subject with another object/relation pair: flush the
                # previous triplet.
                triplets.append(
                    {'head': subject.strip(), 'type': relation.strip(), 'tail': object_.strip()})
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        else:
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token
    if subject != '' and relation != '' and object_ != '':
        triplets.append(
            {'head': subject.strip(), 'type': relation.strip(), 'tail': object_.strip()})

    return triplets
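
# Illustrative parse (a hypothetical string in REBEL's output format, not
# real model output):
#   extract_triplets("<s><triplet> Paris <subj> France <obj> capital of</s>")
#   -> [{'head': 'Paris', 'type': 'capital of', 'tail': 'France'}]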


if __name__ == "__main__":
    # Smoke test: builds a tiny graph from two sentences and writes test.html.
    generate_knowledge_graph(
        ["The dog is happy", "The cat is sad"], "test.html")