# Hugging Face Spaces app (status when captured: Sleeping)
import functools
import os

import gradio as gr
from pyvis.network import Network
from transformers import pipeline
# Hugging Face Hub checkpoint id loaded as the text2text-generation model
# used by text_to_graph to produce the textual graph.
model_id: str = "DReAMy-lib/t5-base-DreamBank-Generation-Act-Char"
def get_graph_dict(graph_text):
    """Parse the model's textual graph into an ``{(head, tail): relation}`` dict.

    The input is expected to be a list of triples wrapped in one outer pair of
    delimiter characters, with each triple also wrapped in its own pair, e.g.
    ``[(a # r # b) | (b # s # c)]``. Triples are separated by ``" | "`` and
    fields by ``" # "``. The first and last character of the whole string and
    of each triple are dropped without being checked.

    Args:
        graph_text: string generated by the text2text model.

    Returns:
        ``{(head, tail): relation}`` for every triple. A later triple with the
        same ``(head, tail)`` pair overwrites an earlier one. Two sentinel
        results are returned instead when parsing is impossible:

        * ``{("No_Graphs", None): None}`` when *graph_text* is empty;
        * ``{("Error", None): None}`` when any triple does not split into
          exactly three ``" # "``-separated fields.
    """
    if graph_text == "":
        return {("No_Graphs", None): None}

    edge_labels = {}
    try:
        for trpl in graph_text[1:-1].split(" | "):
            h, r, t = trpl[1:-1].split(" # ")
            edge_labels[(h, t)] = r
    except ValueError:
        # Malformed output (wrong number of " # " fields) surfaces as a
        # ValueError from the tuple unpacking. Catching only that, instead of
        # a bare ``except:``, stops KeyboardInterrupt/SystemExit and unrelated
        # programming errors from being silently turned into "Error".
        return {("Error", None): None}
    return edge_labels
@functools.lru_cache(maxsize=1)
def _get_pipeline(model_name):
    """Return the text2text-generation pipeline for *model_name*, built once.

    Constructing a transformers pipeline loads the model, so the pipeline is
    cached here instead of being recreated on every text_to_graph call.
    """
    return pipeline(
        "text2text-generation",
        model=model_name,
        max_length=300,
        min_length=5,
    )


def _build_network(edge_labels):
    """Return a directed pyvis Network drawn from an ``{(head, tail): relation}`` dict."""
    net = Network(directed=True)
    for (h, t), r in edge_labels.items():
        net.add_node(h, shape="circle")
        # The sentinel entries produced by get_graph_dict ("No_Graphs" and
        # "Error") are the only ones with no tail. Testing the tail rather than
        # the head's name keeps a real node that happens to be called "Error"
        # from losing its edge.
        if t is None:
            continue
        net.add_node(t, shape="circle")
        net.add_edge(h, t, title=r, label=r)

    # physics / layout settings
    net.repulsion(
        node_distance=200,
        central_gravity=0.2,
        spring_length=200,
        spring_strength=0.05,
        damping=0.09,
    )
    net.set_edge_smooth('dynamic')
    return net


def _wrap_in_iframe(page_html):
    """Embed a complete HTML document in a sandboxed ``<iframe>`` via ``srcdoc``."""
    from html import escape

    # srcdoc is an HTML attribute value. Escaping &, <, >, " and ' lets the
    # browser decode it back to the exact document. The former ' -> " swap
    # instead corrupted any JavaScript/JSON string containing an apostrophe.
    srcdoc = escape(page_html, quote=True)
    return (
        '<iframe style="width: 100%; height: 600px;margin:0 auto" name="result" '
        'allow="midi; geolocation; microphone; camera; display-capture; encrypted-media;" '
        'sandbox="allow-modals allow-forms allow-scripts allow-same-origin allow-popups '
        'allow-top-navigation-by-user-activation allow-downloads" allowfullscreen="" '
        f'allowpaymentrequest="" frameborder="0" srcdoc="{srcdoc}"></iframe>'
    )


def text_to_graph(text):
    """Generate a relation graph from *text* and render it as embeddable HTML.

    Args:
        text: input passed to the ``model_id`` text2text-generation pipeline.

    Returns:
        A ``(iframe_html, graph_text)`` tuple. ``iframe_html`` is an
        ``<iframe>`` whose ``srcdoc`` holds the interactive pyvis graph, and
        ``graph_text`` is the raw string generated by the model.
    """
    pipe = _get_pipeline(model_id)
    # generate text graph
    graph_text = pipe(text)[0]["generated_text"]
    # parse the " | "-separated "head # relation # tail" triples
    edge_labels = get_graph_dict(graph_text)
    net = _build_network(edge_labels)
    return _wrap_in_iframe(net.generate_html()), graph_text