import html
import warnings
from typing import Any, Dict, Iterable, Optional

import spacy
from PIL import ImageFont
from spacy.tokens import Doc
def get_pil_text_size(text, font_size, font_name):
    """Measure how large ``text`` is when drawn in a TrueType font.

    text (str): Text to measure.
    font_size (int): Font size passed to ``ImageFont.truetype``.
    font_name (str): Font file name or path passed to ``ImageFont.truetype``.
    RETURNS (tuple): ``(width, height)`` in pixels, measured from the drawing
        origin (the same values the legacy ``FreeTypeFont.getsize`` returned).
    """
    font = ImageFont.truetype(font_name, font_size)
    if hasattr(font, "getbbox"):
        # ``getsize`` was removed in Pillow 10. Its result equals the right and
        # bottom edges of the bounding box (extent plus the glyph offset).
        _left, _top, right, bottom = font.getbbox(text)
        return right, bottom
    # Very old Pillow versions without ``getbbox``.
    return font.getsize(text)
def render_arrow(
    label: str,
    start: int,
    end: int,
    direction: str,
    i: int,
    *,
    svg_id: int = 0,
    offset: int = 20,
    y: int = 50,
    y_curve: int = 5,
    stroke: int = 2,
) -> str:
    """Render an individual dependency arrow as an SVG ``<g>`` fragment.

    label (str): Dependency label, drawn along the arc.
    start (int): X-coordinate (px) of the start word.
    end (int): X-coordinate (px) of the end word.
    direction (str): Arrow direction, 'left' or 'right'.
    i (int): Unique ID, typically the arrow index. Together with ``svg_id``
        it forms the element id the label's ``<textPath>`` refers to, so it
        must differ between arrows of the same SVG.
    svg_id (int): ID of the enclosing SVG (default 0).
    offset (int): Horizontal shift added to ``start`` and ``end`` (default 20).
    y (int): Y-coordinate of the arc endpoints (default 50).
    y_curve (int): Y-coordinate of the arc's Bezier control points (default 5).
    stroke (int): Stroke width of the arc in px (default 2).
    RETURNS (str): Rendered SVG markup.
    """
    TPL_DEP_ARCS = """
<g class="displacy-arrow">
    <path class="displacy-arc" id="arrow-{id}-{i}" stroke-width="{stroke}px" d="{arc}" fill="none" stroke="currentColor"/>
    <text dy="1.25em" style="font-size: 0.8em; letter-spacing: 1px">
        <textPath xlink:href="#arrow-{id}-{i}" class="displacy-label" startOffset="50%" side="{label_side}" fill="currentColor" text-anchor="middle">{label}</textPath>
    </text>
    <path class="displacy-arrowhead" d="{head}" fill="currentColor"/>
</g>
"""
    x_start = start + offset
    x_end = end + offset
    arc = get_arc(x_start, y, y_curve, x_end)
    arrowhead = get_arrowhead(direction, x_start, y, x_end)
    # NOTE(review): ``direction`` is documented as 'left'/'right', so this is
    # "left" unless a caller passes "rtl"; kept for backward compatibility.
    label_side = "right" if direction == "rtl" else "left"
    return TPL_DEP_ARCS.format(
        id=svg_id,
        i=i,
        stroke=stroke,
        head=arrowhead,
        label=html.escape(label),
        label_side=label_side,
        arc=arc,
    )
def get_arc(x_start: int, y: int, y_curve: int, x_end: int) -> str:
    """Build the SVG path data for a dependency arc.

    The path is one cubic Bezier curve from ``(x_start, y)`` to ``(x_end, y)``.
    Its two control points are vertically aligned with the endpoints, at
    height ``y_curve``.

    x_start (int): X-coordinate of the arc's start point.
    y (int): Y-coordinate shared by the start and end points.
    y_curve (int): Y-coordinate of both Bezier control points.
    x_end (int): X-coordinate of the arc's end point.
    RETURNS (str): Path definition for an SVG ``d`` attribute.
    """
    move_to = f"M{x_start},{y}"
    curve_to = f"C{x_start},{y_curve} {x_end},{y_curve} {x_end},{y}"
    return f"{move_to} {curve_to}"
def get_arrowhead(direction: str, x: int, y: int, end: int) -> str:
    """Build the SVG path data for a triangular arrow head.

    The tip sits 2px below ``y``. The two base corners sit ``arrow_width``
    above ``y`` and ``arrow_width - 2`` to either side of the tip.

    direction (str): Arrow direction. 'left' puts the tip at ``x``; any other
        value puts it at ``end`` with the corners mirrored.
    x (int): X-coordinate of arrow start point.
    y (int): Y-coordinate of arrow start and end point.
    end (int): X-coordinate of arrow end point.
    RETURNS (str): Path definition for an SVG ``d`` attribute.
    """
    arrow_width = 6
    tip_y = y + 2
    base_y = y - arrow_width
    if direction != "left":
        return (
            f"M{end},{tip_y} "
            f"L{end + arrow_width - 2},{base_y} {end - arrow_width + 2},{base_y}"
        )
    return (
        f"M{x},{tip_y} "
        f"L{x - arrow_width + 2},{base_y} {x + arrow_width - 2},{base_y}"
    )
# parsed = [{'words': [{'text': 'The', 'tag': 'DET', 'lemma': None}, {'text': 'OnePlus', 'tag': 'PROPN', 'lemma': None}, {'text': '10', 'tag': 'NUM', 'lemma': None}, {'text': 'Pro', 'tag': 'PROPN', 'lemma': None}, {'text': 'is', 'tag': 'AUX', 'lemma': None}, {'text': 'the', 'tag': 'DET', 'lemma': None}, {'text': 'company', 'tag': 'NOUN', 'lemma': None}, {'text': "'s", 'tag': 'PART', 'lemma': None}, {'text': 'first', 'tag': 'ADJ', 'lemma': None}, {'text': 'flagship', 'tag': 'NOUN', 'lemma': None}, {'text': 'phone.', 'tag': 'NOUN', 'lemma': None}], 'arcs': [{'start': 0, 'end': 3, 'label': 'det', 'dir': 'left'}, {'start': 1, 'end': 3, 'label': 'nmod', 'dir': 'left'}, {'start': 1, 'end': 2, 'label': 'nummod', 'dir': 'right'}, {'start': 3, 'end': 4, 'label': 'nsubj', 'dir': 'left'}, {'start': 5, 'end': 6, 'label': 'det', 'dir': 'left'}, {'start': 6, 'end': 10, 'label': 'poss', 'dir': 'left'}, {'start': 6, 'end': 7, 'label': 'case', 'dir': 'right'}, {'start': 8, 'end': 10, 'label': 'amod', 'dir': 'left'}, {'start': 9, 'end': 10, 'label': 'compound', 'dir': 'left'}, {'start': 4, 'end': 10, 'label': 'attr', 'dir': 'right'}], 'settings': {'lang': 'en', 'direction': 'ltr'}}]
def render_sentence_custom(
    parsed: str,
    *,
    nlp: Any = None,
    labels: Iterable[str] = ("amod",),
    font_path: str = "arial.ttf",
    font_size: int = 16,
) -> str:
    """Parse ``parsed`` and render the words plus selected dependency arcs as SVG.

    Words are laid out left to right, starting at x=10 at y=70. Each word
    advances the cursor by 54px plus its measured width, including a trailing
    space. Only arcs whose label is in ``labels`` are drawn, each between the
    x-positions of its own start and end words.

    parsed (str): Raw text to parse. Despite the name, this is not a
        pre-parsed dict.
    nlp: spaCy pipeline used for parsing. When omitted, ``en_core_web_sm`` is
        loaded on every call, so pass one in when rendering many sentences.
    labels: Dependency labels whose arcs are drawn (default ``("amod",)``).
    font_path (str): TrueType font used to measure word widths.
    font_size (int): Font size used to measure word widths.
    RETURNS (str): Complete SVG markup.
    """
    TPL_DEP_WORDS = """
<text class="displacy-token" fill="currentColor" text-anchor="middle" y="{y}">
    <tspan class="displacy-word" fill="currentColor" x="{x}">{text}</tspan>
    <tspan class="displacy-tag" dy="2em" fill="currentColor" x="{x}">{tag}</tspan>
</text>
"""
    TPL_DEP_SVG = """
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" xml:lang="{lang}" id="{id}" class="displacy" width="{width}" height="{height}" direction="{dir}" style="max-width: none; height: {height}px; color: {color}; background: {bg}; font-family: {font}; direction: {dir}">{content}</svg>
"""
    if nlp is None:
        nlp = spacy.load("en_core_web_sm")
    if isinstance(labels, str):
        # Guard against set("amod") -> {"a", "m", "o", "d"}.
        labels = (labels,)
    wanted_labels = set(labels)

    dep_parse = parse_deps(nlp(parsed))
    settings = dep_parse["settings"]

    svg_words = []
    word_x = []  # x-position of every word, indexed like dep_parse["words"]
    x = 10
    for word in dep_parse["words"]:
        text = word["text"] + " "
        text_width = get_pil_text_size(text, font_size, font_path)[0]
        word_x.append(x)
        svg_words.append(
            TPL_DEP_WORDS.format(text=html.escape(text), tag="", x=x, y=70)
        )
        x += 50 + text_width + 4

    # Each arc uses its own endpoints. If no arc matches, no arrows are
    # drawn, rather than raising an IndexError.
    arcs_svg = [
        render_arrow(
            arc["label"], word_x[arc["start"]], word_x[arc["end"]], arc["dir"], i
        )
        for i, arc in enumerate(dep_parse["arcs"])
        if arc["label"] in wanted_labels
    ]

    content = "".join(svg_words) + "".join(arcs_svg)
    return TPL_DEP_SVG.format(
        id=0,
        width=1975,
        height=574.5,
        color="#000000",
        bg="#ffffff",
        font="Arial",
        content=content,
        dir=settings.get("direction", "ltr"),
        lang=settings.get("lang", "en"),
    )
def parse_deps(
    orig_doc: Doc, options: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Generate a dependency parse in ``{'words': [], 'arcs': []}`` format.

    orig_doc (Doc): Document to parse. A new Doc is rebuilt from its serialized
        bytes (excluding ``user_data``). The retokenization below is applied to
        that copy, not to ``orig_doc``.
    options (dict): Optional flags read by this function.
        ``collapse_phrases`` (default False): merge each noun chunk into one
        token carrying the chunk root's tag, lemma and entity type.
        ``collapse_punct`` (default True): merge a non-punctuation token with
        the run of punctuation tokens that follows it.
        ``fine_grained``: report ``tag_`` instead of ``pos_`` as the word tag.
        ``add_lemma``: include ``lemma_`` (otherwise the lemma is None).
    RETURNS (dict): ``{"words": [...], "arcs": [...], "settings": {...}}``.
        Each word is ``{"text", "tag", "lemma"}``. Each arc is
        ``{"start", "end", "label", "dir"}`` with ``start < end``. ``dir`` is
        "left" when the dependent precedes its head and "right" otherwise.
    """
    if options is None:
        # Avoid a shared mutable default; an empty dict means "all defaults".
        options = {}
    doc = Doc(orig_doc.vocab).from_bytes(orig_doc.to_bytes(exclude=["user_data"]))
    if not doc.has_annotation("DEP"):
        warnings.warn(
            "parse_deps: the Doc has no dependency annotation (DEP); "
            "no meaningful arcs can be produced.",
            stacklevel=2,
        )
    if options.get("collapse_phrases", False):
        with doc.retokenize() as retokenizer:
            for chunk in list(doc.noun_chunks):
                attrs = {
                    "tag": chunk.root.tag_,
                    "lemma": chunk.root.lemma_,
                    "ent_type": chunk.root.ent_type_,
                }
                retokenizer.merge(chunk, attrs=attrs)
    if options.get("collapse_punct", True):
        spans = []
        for word in doc[:-1]:
            # A span starts only at a non-punctuation token that is directly
            # followed by punctuation.
            if word.is_punct or not word.nbor(1).is_punct:
                continue
            start = word.i
            end = word.i + 1
            while end < len(doc) and doc[end].is_punct:
                end += 1
            span = doc[start:end]
            spans.append((span, word.tag_, word.lemma_, word.ent_type_))
        with doc.retokenize() as retokenizer:
            for span, tag, lemma, ent_type in spans:
                attrs = {"tag": tag, "lemma": lemma, "ent_type": ent_type}
                retokenizer.merge(span, attrs=attrs)
    fine_grained = options.get("fine_grained")
    add_lemma = options.get("add_lemma")
    words = [
        {
            "text": w.text,
            "tag": w.tag_ if fine_grained else w.pos_,
            "lemma": w.lemma_ if add_lemma else None,
        }
        for w in doc
    ]
    arcs = []
    for word in doc:
        head_i = word.head.i
        if word.i < head_i:
            arcs.append(
                {"start": word.i, "end": head_i, "label": word.dep_, "dir": "left"}
            )
        elif word.i > head_i:
            arcs.append(
                {"start": head_i, "end": word.i, "label": word.dep_, "dir": "right"}
            )
        # A token that is its own head (spaCy's convention for the root)
        # gets no arc.
    return {"words": words, "arcs": arcs, "settings": get_doc_settings(orig_doc)}
def get_doc_settings(doc: Doc) -> Dict[str, Any]:
    """Collect the rendering settings of a Doc.

    doc (Doc): Document whose language and writing direction are read.
    RETURNS (dict): ``{"lang": doc.lang_, "direction": ...}``. The direction
        comes from the vocab's writing system and falls back to "ltr".
    """
    language = doc.lang_
    writing_system = doc.vocab.writing_system
    settings = {"lang": language}
    settings["direction"] = writing_system.get("direction", "ltr")
    return settings