File size: 2,068 Bytes
b77ef40
 
 
 
 
4eb245b
b77ef40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ec8dc2b
b77ef40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import functools
import html
import os

import gradio as gr
from pyvis.network import Network
from transformers import pipeline

# Model identifier handed to transformers' `pipeline(model=...)` for the
# "text2text-generation" task that produces the linearised graph text.
model_id = "DReAMy-lib/t5-base-DreamBank-Generation-Act-Char"

def get_graph_dict(graph_text):
    """Parse the model's linearised graph string into an edge-label mapping.

    The first and last character of the whole string are stripped, the rest
    is split into triples on ``" | "``, the first and last character of each
    triple are stripped, and the triple is split into head, relation and
    tail on ``" # "``. (The stripped characters are presumably brackets /
    parentheses, e.g. ``"[(h # r # t) | (h # r # t)]"`` -- confirm against
    the model's actual output format.)

    Args:
        graph_text: Text generated by the model. ``None`` or ``""`` means
            the model produced no graph.

    Returns:
        A dict mapping ``(head, tail)`` to the relation label. If the same
        ``(head, tail)`` pair occurs more than once, the last relation wins.
        Two sentinel results are returned instead of real edges:

        * ``{("No_Graphs", None): None}`` when *graph_text* is empty/None;
        * ``{("Error", None): None}`` when any triple does not split into
          exactly three parts.
    """
    if not graph_text:
        return {("No_Graphs", None): None}

    edge_labels = {}
    for triple in graph_text[1:-1].split(" | "):
        try:
            head, relation, tail = triple[1:-1].split(" # ")
        except ValueError:
            # Wrong number of " # " separators: the whole generation is
            # treated as unparsable (no partial graph is returned).
            return {("Error", None): None}
        edge_labels[(head, tail)] = relation
    return edge_labels

@functools.lru_cache(maxsize=1)
def _get_pipeline():
    """Build the text2text-generation pipeline once and reuse it on later calls.

    Constructing the pipeline loads the model named by ``model_id``; caching
    avoids repeating that load on every request.
    """
    return pipeline(
        "text2text-generation",
        model=model_id,
        max_length=300,
        min_length=5,
    )


def text_to_graph(text):
    """Generate a relation graph from *text* and render it as embeddable HTML.

    Args:
        text: Input passed to the text2text-generation pipeline.

    Returns:
        A ``(html_s, graph_text)`` tuple: an ``<iframe>`` string whose
        ``srcdoc`` holds the pyvis-generated page, and the raw text the
        model generated.
    """
    # generate the linearised graph text
    graph_text = _get_pipeline()(text)[0]["generated_text"]

    # (head, tail) -> relation, or a single sentinel entry
    edge_labels = get_graph_dict(graph_text)

    net = Network(directed=True)
    for (h, t), r in edge_labels.items():
        net.add_node(h, shape="circle")
        # The sentinel entries from get_graph_dict ("No_Graphs" / "Error")
        # are the only ones without a tail: show them as a lone node.
        if t is None:
            continue
        net.add_node(t, shape="circle")
        net.add_edge(h, t, title=r, label=r)

    # set structure
    net.repulsion(
        node_distance=200,
        central_gravity=0.2,
        spring_length=200,
        spring_strength=0.05,
        damping=0.09
    )

    net.set_edge_smooth('dynamic')

    # Escape the whole page (both quote characters included) so it can sit
    # inside the srcdoc attribute; the browser unescapes attribute values, so
    # the iframe receives the original markup. Swapping ' for " instead
    # corrupted any label or script string that contained an apostrophe.
    page = html.escape(net.generate_html(), quote=True)

    html_s = f"""<iframe style="width: 100%; height: 600px;margin:0 auto" name="result" allow="midi; geolocation; microphone; camera; 
    display-capture; encrypted-media;" sandbox="allow-modals allow-forms 
    allow-scripts allow-same-origin allow-popups 
    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen="" 
    allowpaymentrequest="" frameborder="0" srcdoc='{page}'></iframe>"""

    return html_s, graph_text