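# NOTE: "Spaces: Sleeping Sleeping" appears to be status text from the hosting page captured with the source; it is not part of the program.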
import io
from typing import NoReturn

import gradio as gr
import matplotlib.pyplot as plt
import networkx as nx
import spacy
from PIL import Image
# Shared spaCy pipeline used by the extraction functions below.
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    # Loading failed with OSError (presumably the model package is not
    # installed): download it once, then load it again.
    from spacy.cli import download

    download("en_core_web_sm")
    nlp = spacy.load("en_core_web_sm")
# Function to extract entities using NER
def extract_entities(text):
    """Return the named entities found in *text*.

    The text is run through the module-level spaCy pipeline ``nlp`` and each
    entity span in ``doc.ents`` is reported as a ``(text, label)`` tuple, in
    the order spaCy yields them.
    """
    found = []
    for span in nlp(text).ents:
        found.append((span.text, span.label_))
    return found
# Function to extract relationships dynamically from the text
def extract_relationships(text):
    """Extract (subject, relation, object) triples from *text*.

    The text is lower-cased, parsed with the module-level spaCy pipeline
    ``nlp`` and scanned token by token in document order. Dependency labels
    decide which slot a token fills:

    * ``nsubj`` / ``nsubjpass``      -> subject
    * ``dobj`` / ``attr`` / ``pobj`` -> object
    * ``ROOT`` / ``xcomp`` / ``ccomp`` -> relation (verb)
    * ``prep`` -> starts a new relation anchored on the most recently
      recorded entity; the relation text is "<root verb> <preposition>" when
      the preposition hangs off the ROOT, otherwise just the preposition.

    Tokens labelled ``compound`` are buffered and prefixed to the next token,
    so a multi-word noun such as "machine learning" is kept whole.

    Args:
        text: The sentence(s) to analyse.

    Returns:
        A ``(relationships, entities)`` tuple. ``relationships`` is a list of
        ``(subject, relation, object)`` string triples, each part stripped of
        surrounding whitespace. ``entities`` lists, in order of appearance,
        every object plus every subject that carried a compound prefix
        (single-word subjects are not recorded there).
    """
    subject_deps = ("nsubj", "nsubjpass")
    object_deps = ("dobj", "attr", "pobj")
    verb_deps = ("ROOT", "xcomp", "ccomp")

    relationships = []
    entities = []
    subject = verb = obj = None
    modifiers = []  # "compound" tokens waiting for the noun they modify

    for token in nlp(text.lower()):
        dep = token.dep_

        # Exact comparison: `dep in ("compound")` would be a substring test
        # against a plain string, not tuple membership.
        if dep == "compound":
            modifiers.append(token.text)
            continue

        if modifiers:
            # The token right after a run of compounds is treated as their
            # head: join modifiers and head into one phrase instead of
            # dropping the head noun.
            phrase = " ".join([*modifiers, token.text])
            modifiers = []
            if dep in subject_deps:
                subject = phrase
                entities.append(subject)
            if dep in object_deps:
                obj = phrase
                entities.append(obj)
        else:
            if dep in subject_deps:
                subject = token.text
            if dep in object_deps:
                obj = token.text
                entities.append(obj)
            if dep in verb_deps:
                verb = token.text

        if dep == "prep":
            # Chain the prepositional relation onto the latest entity. When
            # nothing has been recorded yet (e.g. "he lives in paris", where
            # the single-word subject is not in `entities`), keep the current
            # subject rather than indexing an empty list.
            if entities:
                subject = entities[-1]
            if token.head.dep_ == "ROOT":
                verb = f"{token.head.text} {token.text}"
            else:
                verb = token.text

        # Emit a triple as soon as all three slots are filled, then start over.
        if subject and verb and obj:
            relationships.append((subject.strip(), verb.strip(), obj.strip()))
            subject = verb = obj = None

    return relationships, entities
# Function to create the knowledge graph
def create_knowledge_graph(entities, relationships):
    """Build a directed graph from (subject, relation, object) triples.

    Every subject and object that appears in *relationships* becomes a node,
    and each triple becomes an edge subject -> object whose ``label``
    attribute holds the relation text.

    Args:
        entities: Accepted for interface compatibility; not used when
            building the graph.
        relationships: Iterable of ``(subject, relation, object)`` triples.

    Returns:
        The populated ``networkx.DiGraph``.
    """
    graph = nx.DiGraph()

    # Collect the distinct endpoints first, then register them as nodes.
    endpoints = set()
    for head, _, tail in relationships:
        endpoints.add(head)
        endpoints.add(tail)
    graph.add_nodes_from(endpoints)

    graph.add_edges_from(
        (head, tail, {"label": relation}) for head, relation, tail in relationships
    )
    return graph
# Function to visualize the graph
def visualize_graph(G):
    """Draw *G* with matplotlib and return the rendering as a PIL image.

    Nodes are placed with ``nx.spring_layout`` (no seed is passed) and edges
    are annotated with their ``label`` attribute. The figure is saved as PNG
    into an in-memory buffer, which is then opened with ``Image.open``.

    Args:
        G: The networkx graph to draw.

    Returns:
        A PIL image backed by the in-memory PNG buffer.
    """
    fig = plt.figure(figsize=(12, 8))
    try:
        pos = nx.spring_layout(G)
        edge_labels = nx.get_edge_attributes(G, 'label')
        nx.draw(
            G,
            pos,
            with_labels=True,
            node_size=2000,
            node_color="lightblue",
            font_size=10,
            font_weight="bold",
        )
        nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # Always release this specific figure, even if layout, drawing or
        # saving raised; otherwise open figures pile up across calls.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
# Function to process input and generate output
def process_text(text: str):
    """Turn *text* into a rendered knowledge-graph image.

    Runs the three stages in sequence: relationship extraction, graph
    construction, and drawing. Returns whatever ``visualize_graph`` returns.
    """
    triples, mentions = extract_relationships(text)
    graph = create_knowledge_graph(mentions, triples)
    rendered = visualize_graph(graph)
    return rendered
# Gradio Interface: one text box in, one PIL image out, launched as soon as
# this module is executed.
demo = gr.Interface(
    fn=process_text,
    inputs=gr.Textbox(placeholder="Enter knowledge prompt here"),
    outputs=gr.Image(type="pil"),
    title="Knowledge Graph Generator",
)
demo.launch()