|
import ast |
|
import os |
|
from copy import deepcopy |
|
|
|
import dhg |
|
import gradio as gr |
|
import matplotlib.pyplot as plt |
|
import pandas as pd |
|
from dhg.visualization.structure.defaults import (default_hypergraph_strength, |
|
default_hypergraph_style, |
|
default_size) |
|
from dhg.visualization.structure.layout import force_layout |
|
from dhg.visualization.structure.utils import draw_circle_edge, draw_vertex |
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
def draw_hypergraph(
    hg: "dhg.Hypergraph",
    e_style="circle",
    v_label=None,
    v_size=1.0,
    v_color="r",
    v_line_width=1.0,
    e_color="gray",
    e_fill_color="whitesmoke",
    e_line_width=1.0,
    font_size=1.0,
    font_family="sans-serif",
    push_v_strength=1.0,
    push_e_strength=1.0,
    pull_e_strength=1.0,
    pull_center_strength=1.0,
):
    """Draw ``hg`` with a force-directed layout and return the matplotlib figure.

    Styling, size and layout-strength arguments are passed through dhg's
    ``default_hypergraph_style``, ``default_size`` and
    ``default_hypergraph_strength`` helpers, which return the effective values
    used for drawing.

    Args:
        hg: The hypergraph to draw; must contain at least one hyperedge.
        e_style: Hyperedge rendering style. Only ``"circle"`` is implemented.
        v_label: Optional vertex labels forwarded to ``draw_vertex``.
        v_size, v_color, v_line_width: Vertex appearance settings.
        e_color, e_fill_color, e_line_width: Hyperedge appearance settings.
        font_size: Currently has no effect (see the NOTE below).
        font_family: Font family for vertex labels.
        push_v_strength, push_e_strength, pull_e_strength, pull_center_strength:
            Force-layout strengths passed through dhg's defaults helper.

    Returns:
        A standalone ``matplotlib.figure.Figure`` with a single axes spanning
        the unit square, with axis decorations turned off.

    Raises:
        ValueError: If ``e_style`` is not ``"circle"`` or ``hg`` has no hyperedges.
    """
    # Validate cheaply before any layout or drawing work is done.
    if e_style != "circle":
        raise ValueError(
            f"Unsupported e_style {e_style!r}; only 'circle' is implemented."
        )
    if hg.num_e == 0:
        raise ValueError("Cannot draw a hypergraph that has no hyperedges.")

    # A standalone Figure is not registered with pyplot. It is therefore
    # garbage-collected once the caller drops it (plt.subplots would leak one
    # figure per call in a long-running server). It also avoids pyplot's
    # process-global "current axes" state, which is unsafe when handlers run
    # concurrently.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(6, 6))
    ax = fig.subplots()

    # The edge list is copied so the drawing helpers cannot mutate hg's storage.
    num_v, e_list = hg.num_v, deepcopy(hg.e[0])

    v_color, e_color, e_fill_color = default_hypergraph_style(
        hg.num_v, hg.num_e, v_color, e_color, e_fill_color
    )
    # NOTE(review): this rebinds font_size, so the caller's font_size argument
    # is ignored. Confirm whether dhg's default_size accepts a font_size scale
    # before forwarding it.
    v_size, v_line_width, e_line_width, font_size = default_size(
        num_v, e_list, v_size, v_line_width, e_line_width
    )
    (
        push_v_strength,
        push_e_strength,
        pull_e_strength,
        pull_center_strength,
    ) = default_hypergraph_strength(
        num_v,
        e_list,
        push_v_strength,
        push_e_strength,
        pull_e_strength,
        pull_center_strength,
    )

    v_coor = force_layout(
        num_v,
        e_list,
        push_v_strength,
        push_e_strength,
        pull_e_strength,
        pull_center_strength,
    )
    draw_circle_edge(
        ax,
        v_coor,
        v_size,
        e_list,
        e_color,
        e_fill_color,
        e_line_width,
    )
    draw_vertex(
        ax,
        v_coor,
        v_label,
        font_size,
        font_family,
        v_size,
        v_color,
        v_line_width,
    )

    # Operate on this figure's own axes rather than pyplot's current axes.
    ax.set_xlim((0, 1.0))
    ax.set_ylim((0, 1.0))
    ax.axis("off")
    fig.tight_layout()
    return fig
|
|
|
|
|
def plot_dataset(dataset_choice: str, sampling_choice: str, split_choice: str):
    """Download one processed split of a hypergraph dataset and plot it.

    The file ``processed/{sampling_choice}/{split_choice}_df.csv`` is fetched
    from the Hugging Face dataset repo ``SauravMaheshkar/{dataset_choice}``
    into ``./artifacts/``. Each row's ``nodes`` cell is parsed as one hyperedge.

    Args:
        dataset_choice: Dataset repository name under ``SauravMaheshkar``.
        sampling_choice: Sub-directory of ``processed/`` to read from.
        split_choice: Split prefix of the CSV file name.

    Returns:
        The figure produced by ``draw_hypergraph``.

    Raises:
        gr.Error: If the selected split contains no hyperedges.
    """
    os.makedirs("artifacts", exist_ok=True)
    # hf_hub_download returns the local path of the downloaded file. Use it
    # directly instead of re-deriving the path by hand.
    csv_path = hf_hub_download(
        filename=f"processed/{sampling_choice}/{split_choice}_df.csv",
        local_dir="./artifacts/",
        repo_id=f"SauravMaheshkar/{dataset_choice}",
        repo_type="dataset",
    )
    df = pd.read_csv(csv_path)

    # Each "nodes" cell is a Python-literal string listing the vertex ids of
    # one hyperedge. literal_eval only evaluates literals, so it is safe on
    # downloaded data.
    raw_edges = [ast.literal_eval(cell) for cell in df["nodes"]]
    if not raw_edges:
        raise gr.Error(
            f"No hyperedges found for {dataset_choice} "
            f"({sampling_choice}/{split_choice})."
        )

    # len(df) is the number of hyperedges (rows), not vertices. Remap the
    # vertex ids seen in this split onto 0..k-1, in first-seen order. Every id
    # is then a valid index, and no spurious isolated vertices are drawn.
    vertex_index = {}
    edge_list = [
        [vertex_index.setdefault(v, len(vertex_index)) for v in edge]
        for edge in raw_edges
    ]

    hypergraph = dhg.Hypergraph(len(vertex_index), edge_list)
    return draw_hypergraph(hypergraph)
|
|
|
|
|
# Options offered by the UI. The strings are used verbatim to build the
# Hugging Face repo id and the file path inside plot_dataset.
DATASET_CHOICES = [
    "email-Eu",
    "email-Enron",
    "NDC-classes",
    "tags-math-sx",
    "email-Eu-25",
    "NDC-substances",
    "congress-bills",
    "tags-ask-ubuntu",
    "email-Enron-25",
    "NDC-classes-25",
    "threads-ask-ubuntu",
    "contact-high-school",
    "NDC-substances-25",
    "congress-bills-25",
    "contact-primary-school",
]
SAMPLING_CHOICES = ["transductive", "inductive"]
SPLIT_CHOICES = ["train", "valid", "test"]

with gr.Blocks() as demo:
    # The three selectors sit side by side; the plot and button go below them.
    with gr.Row():
        dataset_choices = gr.Dropdown(
            choices=DATASET_CHOICES,
            value="email-Enron-25",
            label="Please choose a dataset",
            interactive=True,
        )
        sampling_choice = gr.Dropdown(
            choices=SAMPLING_CHOICES,
            value="inductive",
            label="Choose sampling type",
            interactive=True,
        )
        split_choice = gr.Dropdown(
            choices=SPLIT_CHOICES,
            value="test",
            label="Choose split",
            interactive=True,
        )

    output_plot = gr.Plot(label="Hypergraph plot")
    btn = gr.Button("Visualise")
    btn.click(
        fn=plot_dataset,
        inputs=[dataset_choices, sampling_choice, split_choice],
        outputs=output_plot,
    )

# Launch only when run as a script (e.g. `python app.py`). Importing the
# module then just builds `demo` without starting a server.
if __name__ == "__main__":
    demo.launch()
|
|