File size: 2,123 Bytes
b77ef40
 
 
 
 
4eb245b
b77ef40
 
8fd6eff
b77ef40
 
 
 
 
8fd6eff
 
abd18ef
8fd6eff
b77ef40
8fd6eff
b77ef40
 
 
 
 
 
 
 
 
ec8dc2b
b77ef40
 
 
 
 
 
 
 
 
 
 
 
 
 
8fd6eff
b77ef40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import html
import os
from functools import lru_cache

import gradio as gr
from pyvis.network import Network
from transformers import pipeline

# Model identifier passed as ``model=`` to the transformers
# "text2text-generation" pipeline used by text_to_graph.
model_id = "DReAMy-lib/t5-base-DreamBank-Generation-Act-Char"

def get_graph_dict(graph_text):
    """Parse generated graph text into a list of ``(head, tail, relation)`` triples.

    The text is expected to look like ``"(h : r : t), (h : r : t)"``.
    Despite its name, this returns a list (not a dict) so callers can unpack
    every entry uniformly as ``for (h, t, r) in ...``.

    Args:
        graph_text: Text produced by the generation model; ``None``, empty or
            whitespace-only text is treated as "no graph".

    Returns:
        A list of 3-tuples ``(head, tail, relation)`` where spaces in the
        relation are replaced by underscores and a literal ``"none"`` tail is
        mapped to the head (a self-loop). Two sentinel entries may appear:

        * ``[("No_Graphs", None, None)]`` when there is nothing to parse.
        * ``("Error", None, None)`` appended after the triples parsed so far
          when a triple does not split into exactly three ``" : "`` parts;
          parsing stops at that point.
    """
    text = (graph_text or "").strip()
    if not text:
        # Same 3-tuple shape as a real edge, so unpacking never fails.
        return [("No_Graphs", None, None)]

    edge_labels = []
    # Drop the outer parentheses, then split between consecutive triples.
    for trpl in text[1:-1].split("), ("):
        try:
            h, r, t = trpl.split(" : ")
        except ValueError:
            # Malformed triple: keep what was parsed and flag the failure.
            edge_labels.append(("Error", None, None))
            break
        if t == "none":
            t = h
        edge_labels.append((h, t, r.replace(" ", "_")))
    return edge_labels

@lru_cache(maxsize=1)
def _get_pipeline(model_name):
    """Build the text2text-generation pipeline for *model_name* once and reuse it.

    Loading the model is expensive, so the pipeline is cached instead of being
    rebuilt on every request.
    """
    return pipeline(
        "text2text-generation",
        model=model_name,
        max_length=300,
        min_length=5,
    )


def _build_network(edge_labels):
    """Return a directed pyvis ``Network`` drawn from ``(head, tail, relation)`` triples."""
    net = Network(directed=True)

    for (h, t, r) in edge_labels:
        net.add_node(h, shape="circle")
        # Sentinel entries ("Error" / "No_Graphs") carry only a node label.
        if h in ("Error", "No_Graphs"):
            continue
        net.add_node(t, shape="circle")
        net.add_edge(h, t, title=r, label=r)

    net.repulsion(
        node_distance=200,
        central_gravity=0.2,
        spring_length=200,
        spring_strength=0.05,
        damping=0.09
    )
    net.set_edge_smooth('dynamic')
    return net


def _wrap_in_iframe(page):
    """Embed a standalone HTML document in a sandboxed iframe via ``srcdoc``."""
    # Escape the whole document (both quote kinds included) so it can sit
    # inside the attribute unchanged; the browser unescapes srcdoc before
    # rendering. Rewriting quotes instead would corrupt labels such as
    # "mother's" inside the page's embedded JSON.
    srcdoc = html.escape(page, quote=True)
    # NOTE(review): allow-scripts together with allow-same-origin lets the
    # framed document lift its own sandbox; consider dropping one if possible.
    return (
        '<iframe style="width: 100%; height: 600px;margin:0 auto" name="result" '
        'allow="midi; geolocation; microphone; camera; display-capture; encrypted-media;" '
        'sandbox="allow-modals allow-forms allow-scripts allow-same-origin allow-popups '
        'allow-top-navigation-by-user-activation allow-downloads" allowfullscreen="" '
        f'allowpaymentrequest="" frameborder="0" srcdoc="{srcdoc}"></iframe>'
    )


def text_to_graph(text):
    """Generate a relation graph for *text* and render it as embeddable HTML.

    Args:
        text: Input passed to the text2text-generation model ``model_id``.

    Returns:
        A tuple ``(iframe_html, graph_text)``:

        * ``iframe_html`` is an ``<iframe>`` whose ``srcdoc`` holds the pyvis
          visualisation.
        * ``graph_text`` is the raw model output that was parsed into edges.
    """
    pipe = _get_pipeline(model_id)

    # The pipeline returns a list of dicts; take the first generation.
    graph_text = pipe(text)[0]["generated_text"]

    edge_labels = get_graph_dict(graph_text)
    net = _build_network(edge_labels)

    return _wrap_in_iframe(net.generate_html()), graph_text