import warnings
from functools import lru_cache
from typing import Any, Dict, List, Optional, Tuple

import spacy
from PIL import ImageFont
from spacy.tokens import Doc


@lru_cache(maxsize=8)
def _load_font(font_name, font_size):
    """Load a TrueType font once per (name, size) pair instead of per word."""
    return ImageFont.truetype(font_name, font_size)


@lru_cache(maxsize=1)
def _load_pipeline(name: str):
    """Load a spaCy pipeline once and reuse it across render calls."""
    return spacy.load(name)


def get_pil_text_size(text: str, font_size: int, font_name: str) -> Tuple[int, int]:
    """Measure rendered text with PIL.

    text (str): Text to measure.
    font_size (int): Font size passed to ImageFont.truetype.
    font_name (str): Font file name or path passed to ImageFont.truetype.
    RETURNS (tuple): (width, height) of the rendered text.
    """
    font = _load_font(font_name, font_size)
    if hasattr(font, "getsize"):
        # Older Pillow releases: keep the exact legacy measurement.
        return font.getsize(text)
    # `getsize` no longer exists on newer Pillow releases. The right/bottom
    # edges of the bounding box are the same values `getsize` used to report.
    _, _, right, bottom = font.getbbox(text)
    return (right, bottom)


def render_arrow(
    label: str, start: int, end: int, direction: str, i: int
) -> str:
    """Render individual arrow.

    label (str): Dependency label.
    start (int): X-coordinate of the start word (render_sentence_custom
        passes pixel positions, not word indices).
    end (int): X-coordinate of the end word.
    direction (str): Arrow direction, 'left' or 'right'.
    i (int): Unique ID, typically arrow index.
    RETURNS (str): Rendered SVG markup.
    """
    # NOTE(review): only {label} is referenced by this template as it stands;
    # the remaining fields passed to format() are ignored. The surrounding SVG
    # markup may have been lost at some point -- confirm against the intended
    # template.
    TPL_DEP_ARCS = """ {label} """
    # Arcs are anchored 20px to the right of each word's x position, at y=50,
    # curving up to y=5.
    arc = get_arc(start + 20, 50, 5, end + 20)
    arrowhead = get_arrowhead(direction, start + 20, 50, end + 20)
    # `direction` is the arc direction ('left'/'right' as produced by
    # parse_deps), so this yields "left" unless a caller passes "rtl".
    label_side = "right" if direction == "rtl" else "left"
    return TPL_DEP_ARCS.format(
        id=0,
        i=i,
        stroke=2,
        head=arrowhead,
        label=label,
        label_side=label_side,
        arc=arc,
    )


def get_arc(x_start: int, y: int, y_curve: int, x_end: int) -> str:
    """Render individual arc.

    x_start (int): X-coordinate of arrow start point.
    y (int): Y-coordinate of arrow start and end point.
    y_curve (int): Y-coordinate of Cubic Bézier y_curve point.
    x_end (int): X-coordinate of arrow end point.
    RETURNS (str): Definition of the arc path ('d' attribute).
    """
    template = "M{x},{y} C{x},{c} {e},{c} {e},{y}"
    return template.format(x=x_start, y=y, c=y_curve, e=x_end)


def get_arrowhead(direction: str, x: int, y: int, end: int) -> str:
    """Render individual arrow head.

    direction (str): Arrow direction, 'left' or 'right'.
    x (int): X-coordinate of arrow start point.
    y (int): Y-coordinate of arrow start and end point.
    end (int): X-coordinate of arrow end point.
    RETURNS (str): Definition of the arrow head path ('d' attribute).
    """
    arrow_width = 6
    if direction == "left":
        # Head sits on the start point of the arc.
        p1, p2, p3 = (x, x - arrow_width + 2, x + arrow_width - 2)
    else:
        # Any other direction puts the head on the end point.
        p1, p2, p3 = (end, end + arrow_width - 2, end - arrow_width + 2)
    return f"M{p1},{y + 2} L{p2},{y - arrow_width} {p3},{y - arrow_width}"


# Example of the structure produced by parse_deps (wrapped in a list):
#
# parsed = [{'words': [{'text': 'The', 'tag': 'DET', 'lemma': None},
#                      {'text': 'OnePlus', 'tag': 'PROPN', 'lemma': None},
#                      {'text': '10', 'tag': 'NUM', 'lemma': None},
#                      {'text': 'Pro', 'tag': 'PROPN', 'lemma': None},
#                      {'text': 'is', 'tag': 'AUX', 'lemma': None},
#                      {'text': 'the', 'tag': 'DET', 'lemma': None},
#                      {'text': 'company', 'tag': 'NOUN', 'lemma': None},
#                      {'text': "'s", 'tag': 'PART', 'lemma': None},
#                      {'text': 'first', 'tag': 'ADJ', 'lemma': None},
#                      {'text': 'flagship', 'tag': 'NOUN', 'lemma': None},
#                      {'text': 'phone.', 'tag': 'NOUN', 'lemma': None}],
#            'arcs': [{'start': 0, 'end': 3, 'label': 'det', 'dir': 'left'},
#                     {'start': 1, 'end': 3, 'label': 'nmod', 'dir': 'left'},
#                     {'start': 1, 'end': 2, 'label': 'nummod', 'dir': 'right'},
#                     {'start': 3, 'end': 4, 'label': 'nsubj', 'dir': 'left'},
#                     {'start': 5, 'end': 6, 'label': 'det', 'dir': 'left'},
#                     {'start': 6, 'end': 10, 'label': 'poss', 'dir': 'left'},
#                     {'start': 6, 'end': 7, 'label': 'case', 'dir': 'right'},
#                     {'start': 8, 'end': 10, 'label': 'amod', 'dir': 'left'},
#                     {'start': 9, 'end': 10, 'label': 'compound', 'dir': 'left'},
#                     {'start': 4, 'end': 10, 'label': 'attr', 'dir': 'right'}],
#            'settings': {'lang': 'en', 'direction': 'ltr'}}]


def render_sentence_custom(parsed: str) -> str:
    """Render a sentence with arrows for its 'amod' dependencies only.

    The text is parsed with the 'en_core_web_sm' pipeline; every word is laid
    out left to right using its measured pixel width, and one arrow is drawn
    per 'amod' arc.

    parsed (str): Raw sentence text to parse and render.
    RETURNS (str): Rendered markup for the whole sentence.
    """
    # NOTE(review): these templates only reference {text}/{tag} and {content};
    # the other fields passed to format() are ignored as the templates stand.
    TPL_DEP_WORDS = """ {text} {tag} """
    TPL_DEP_SVG = """ {content} """

    doc = _load_pipeline("en_core_web_sm")(parsed)
    deps = parse_deps(doc)
    words = deps["words"]
    # Keep each arc's index among *all* arcs as its arrow ID.
    amod_arcs = [
        (arc_index, arc)
        for arc_index, arc in enumerate(deps["arcs"])
        if arc["label"] == "amod"
    ]

    # Words lying under an 'amod' arc get extra horizontal room so the arrow
    # has space. May be empty when the sentence has no 'amod' relation.
    spanned = set()
    for _, arc in amod_arcs:
        spanned.update(range(arc["start"], arc["end"] + 1))

    x_cursor = 10
    word_x: List[int] = []
    svg_words: List[str] = []
    for word_index, word in enumerate(words):
        text = word["text"] + " "
        text_width = get_pil_text_size(text, 16, "arial.ttf")[0]
        svg_words.append(TPL_DEP_WORDS.format(text=text, tag="", x=x_cursor, y=70))
        word_x.append(x_cursor)
        if word_index in spanned:
            x_cursor += 50
        x_cursor += text_width + 4

    # Each arrow uses the positions of its own start and end words.
    arcs_svg = [
        render_arrow(
            arc["label"], word_x[arc["start"]], word_x[arc["end"]], arc["dir"], arc_index
        )
        for arc_index, arc in amod_arcs
    ]

    content = "".join(svg_words) + "".join(arcs_svg)
    return TPL_DEP_SVG.format(
        id=0,
        width=1975,
        height=574.5,
        color="#000000",
        bg="#ffffff",
        font="Arial",
        content=content,
        dir="ltr",
        lang="en",
    )


def parse_deps(
    orig_doc: Doc, options: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Generate dependency parse in {'words': [], 'arcs': []} format.

    orig_doc (Doc): Document to parse. It is copied, not modified.
    options (dict): Optional settings. Recognised keys: 'collapse_phrases'
        (default False), 'collapse_punct' (default True), 'fine_grained'
        and 'add_lemma'.
    RETURNS (dict): Generated dependency parse keyed by words, arcs and
        settings.
    """
    if options is None:
        options = {}
    # Work on a copy so retokenization below never touches the caller's Doc.
    doc = Doc(orig_doc.vocab).from_bytes(orig_doc.to_bytes(exclude=["user_data"]))
    if not doc.has_annotation("DEP"):
        warnings.warn(
            "Doc has no dependency annotation ('DEP'); no arcs can be generated."
        )
    if options.get("collapse_phrases", False):
        with doc.retokenize() as retokenizer:
            for chunk in list(doc.noun_chunks):
                attrs = {
                    "tag": chunk.root.tag_,
                    "lemma": chunk.root.lemma_,
                    "ent_type": chunk.root.ent_type_,
                }
                retokenizer.merge(chunk, attrs=attrs)
    if options.get("collapse_punct", True):
        # Merge each non-punctuation token with the punctuation that directly
        # follows it, keeping the token's own tag/lemma/entity type.
        spans = []
        for word in doc[:-1]:
            if word.is_punct or not word.nbor(1).is_punct:
                continue
            start = word.i
            end = word.i + 1
            while end < len(doc) and doc[end].is_punct:
                end += 1
            span = doc[start:end]
            spans.append((span, word.tag_, word.lemma_, word.ent_type_))
        with doc.retokenize() as retokenizer:
            for span, tag, lemma, ent_type in spans:
                attrs = {"tag": tag, "lemma": lemma, "ent_type": ent_type}
                retokenizer.merge(span, attrs=attrs)
    fine_grained = options.get("fine_grained")
    add_lemma = options.get("add_lemma")
    words = [
        {
            "text": w.text,
            "tag": w.tag_ if fine_grained else w.pos_,
            "lemma": w.lemma_ if add_lemma else None,
        }
        for w in doc
    ]
    arcs = []
    for word in doc:
        # 'start' is always the smaller index; 'dir' records on which side of
        # the head the word sits. A word that is its own head gets no arc.
        if word.i < word.head.i:
            arcs.append(
                {"start": word.i, "end": word.head.i, "label": word.dep_, "dir": "left"}
            )
        elif word.i > word.head.i:
            arcs.append(
                {
                    "start": word.head.i,
                    "end": word.i,
                    "label": word.dep_,
                    "dir": "right",
                }
            )
    return {"words": words, "arcs": arcs, "settings": get_doc_settings(orig_doc)}


def get_doc_settings(doc: Doc) -> Dict[str, Any]:
    """Collect rendering settings from a Doc.

    doc (Doc): Document to read settings from.
    RETURNS (dict): The doc's language code under 'lang' and its writing
        direction under 'direction' ('ltr' when the vocab does not specify).
    """
    return {
        "lang": doc.lang_,
        "direction": doc.vocab.writing_system.get("direction", "ltr"),
    }