D3V1L1810's picture
Update app.py
05f9e0c verified
from typing import NoReturn
import spacy
import networkx as nx
import matplotlib.pyplot as plt
import io
from PIL import Image
import gradio as gr
def _load_spacy_model(name):
    """Return the spaCy pipeline called *name*, fetching it first if needed.

    A failed ``spacy.load`` surfaces as ``OSError`` (e.g. the model package is
    not installed yet); in that case the model is downloaded with
    ``spacy.cli.download`` and loaded a second time. Any error from that
    second load propagates to the caller.
    """
    try:
        return spacy.load(name)
    except OSError:
        from spacy.cli import download
        download(name)
        return spacy.load(name)


# Shared English pipeline used for dependency parsing and NER below.
nlp = _load_spacy_model("en_core_web_sm")
def extract_entities(text):
    """Run spaCy named-entity recognition over *text*.

    Returns:
        A list of ``(entity_text, entity_label)`` tuples in the order the
        entities appear in ``doc.ents``.
    """
    return [(span.text, span.label_) for span in nlp(text).ents]
# Dependency labels that fill each slot of a (subject, relation, object) triple.
_SUBJECT_DEPS = ("nsubj", "nsubjpass")
_OBJECT_DEPS = ("dobj", "attr", "pobj")
_VERB_DEPS = ("ROOT", "xcomp", "ccomp")


def extract_relationships(text, parser=None):
    """Extract (subject, relation, object) triples from *text* heuristically.

    The text is lower-cased and parsed, then tokens are scanned left to right:

    * ``compound`` modifiers are buffered and joined with the next token, so
      "new york" becomes the single phrase ``"new york"``;
    * subject / object / verb slots are filled from each token's dependency
      label (a token that follows compound modifiers only fills the subject
      or object slot);
    * a ``prep`` token makes the most recent entity the new subject and turns
      the preposition into the relation, prefixed with the root verb when the
      preposition attaches directly to the root (e.g. ``"lives in"``);
    * as soon as all three slots are filled a triple is emitted and the slots
      are cleared.

    Args:
        text: input sentence(s).
        parser: optional callable mapping a string to an iterable of
            spaCy-style tokens (``.text``, ``.dep_``, ``.head``). Defaults to
            the module-level ``nlp`` pipeline.

    Returns:
        ``(relationships, entities)``: ``relationships`` is a list of stripped
        string triples; ``entities`` lists, in order of appearance, every
        object phrase and every compound subject phrase encountered.
    """
    if parser is None:
        parser = nlp
    relationships = []
    entities = []
    subject = verb = obj = None
    modifiers = []  # compound modifiers waiting for their head token
    for token in parser(text.lower()):
        dep = token.dep_
        # Equality, not `dep in ("compound")`: a parenthesised string is not a
        # tuple, so `in` would do a substring test.
        if dep == "compound":
            modifiers.append(token.text)
            continue
        if modifiers:
            # Buffered modifiers plus this token form one noun phrase.
            phrase = " ".join(modifiers + [token.text])
            modifiers = []
            if dep in _SUBJECT_DEPS:
                subject = phrase
                entities.append(subject)
            if dep in _OBJECT_DEPS:
                obj = phrase
                entities.append(obj)
        else:
            if dep in _SUBJECT_DEPS:
                subject = token.text
            if dep in _OBJECT_DEPS:
                obj = token.text
                entities.append(obj)
            if dep in _VERB_DEPS:
                verb = token.text
        if dep == "prep":
            # Chain from the latest entity ("capital of france"); when no
            # entity has been seen yet, keep the current subject rather than
            # indexing an empty list.
            if entities:
                subject = entities[-1]
            if token.head.dep_ == "ROOT":
                verb = token.head.text + " " + token.text
            else:
                verb = token.text
        if subject and verb and obj:
            relationships.append((subject.strip(), verb.strip(), obj.strip()))
            subject = verb = obj = None
    return relationships, entities
def create_knowledge_graph(entities, relationships):
    """Build a directed graph from (subject, relation, object) triples.

    Each triple becomes an edge ``subject -> object`` whose ``label``
    attribute holds the relation text. Only entities that take part in at
    least one relationship become nodes.

    Args:
        entities: not used; kept so existing callers keep working.
        relationships: iterable of ``(subject, relation, object)`` triples.

    Returns:
        A ``networkx.DiGraph``. A DiGraph holds at most one edge per ordered
        pair, so if the same (subject, object) pair appears more than once,
        the last relation's label is kept.
    """
    G = nx.DiGraph()
    # add_edge creates missing endpoints itself, in first-seen order; routing
    # nodes through a set first would make node order depend on string-hash
    # randomization.
    for subj, rel, obj in relationships:
        G.add_edge(subj, obj, label=rel)
    return G
def visualize_graph(G, seed=None):
    """Render *G* as a node-link diagram and return it as a PIL image.

    Node positions come from ``nx.spring_layout``; edges are annotated with
    their ``label`` attribute. The figure is rasterized to PNG in memory and
    opened with PIL.

    Args:
        G: graph to draw.
        seed: optional random seed for the spring layout. ``None`` (default)
            gives a fresh random layout on every call; pass an int for
            reproducible images.

    Returns:
        A ``PIL.Image.Image`` backed by an in-memory PNG buffer.
    """
    pos = nx.spring_layout(G, seed=seed)
    edge_labels = nx.get_edge_attributes(G, "label")
    fig = plt.figure(figsize=(12, 8))
    try:
        # Full-figure axes, the same ones nx.draw would create on an empty
        # figure; passing them explicitly keeps both drawing calls off
        # pyplot's global "current axes".
        ax = fig.add_axes((0, 0, 1, 1))
        nx.draw(G, pos, ax=ax, with_labels=True, node_size=2000,
                node_color="lightblue", font_size=10, font_weight="bold")
        nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, ax=ax)
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # pyplot keeps figures alive until closed; close even if drawing
        # fails so errors do not leak a figure per call.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def process_text(text: str):
    """Gradio callback: turn free text into a rendered knowledge-graph image.

    The text is parsed into (subject, relation, object) triples, the triples
    are assembled into a directed graph, and the rendered graph is returned
    as a PIL image.
    """
    triples, found_entities = extract_relationships(text)
    graph = create_knowledge_graph(found_entities, triples)
    return visualize_graph(graph)
# Gradio UI: one text box in, one rendered graph image out. The app is
# launched as soon as this module runs.
prompt_box = gr.Textbox(placeholder="Enter knowledge prompt here")
graph_view = gr.Image(type="pil")
demo = gr.Interface(
    fn=process_text,
    inputs=prompt_box,
    outputs=graph_view,
    title="Knowledge Graph Generator",
)
demo.launch()