File size: 2,165 Bytes
864a486
 
 
 
 
03f1ac5
 
 
9d06087
3572a2d
 
 
864a486
 
 
 
 
9d06087
864a486
 
 
 
03f1ac5
 
864a486
03f1ac5
9d06087
864a486
 
 
 
 
9d06087
7c8f86e
864a486
9d06087
2895baa
9d06087
2895baa
9d06087
864a486
9d06087
7c8f86e
864a486
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03f1ac5
 
 
 
864a486
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import re
import gradio as gr
from dataclasses import dataclass
from prettytable import PrettyTable

from pytorch_ie.annotations import LabeledSpan, BinaryRelation
from pytorch_ie.auto import AutoPipeline
from pytorch_ie.core import AnnotationList, annotation_field
from pytorch_ie.documents import TextBasedDocument
from pytorch_ie.taskmodules import *
from pytorch_ie.models import *


from typing import List


@dataclass
class ExampleDocument(TextBasedDocument):
    """Text-based document with an entity layer and a relation layer.

    Both fields are annotation layers created via ``annotation_field``; the
    ``target`` argument names what each layer's annotations refer to.
    """

    # Labeled spans, declared with target="text".
    entities: AnnotationList[LabeledSpan] = annotation_field(target="text")
    # Binary relations, declared with target="entities" (the layer above).
    relations: AnnotationList[BinaryRelation] = annotation_field(target="entities")


# Identifiers of the pretrained NER and relation-extraction (RE) models.
ner_model_name_or_path = "pie/example-ner-spanclf-conll03"
re_model_name_or_path = "pie/example-re-textclf-tacred"

# Both pipelines are constructed once at import time and reused by predict().
# NOTE(review): device=-1 presumably selects CPU execution — confirm against
# the pytorch-ie pipeline documentation.
ner_pipeline = AutoPipeline.from_pretrained(
    ner_model_name_or_path,
    device=-1,
    num_workers=0,
)
# The RE pipeline additionally receives create_relation_candidates=True
# through taskmodule_kwargs.
re_pipeline = AutoPipeline.from_pretrained(
    re_model_name_or_path,
    device=-1,
    num_workers=0,
    taskmodule_kwargs={"create_relation_candidates": True},
)


def predict(text):
    """Run NER followed by relation extraction on ``text``.

    Args:
        text: Input string that is wrapped in an ``ExampleDocument``.

    Returns:
        An HTML string: a table with one row per predicted relation
        (head, tail, relation label), wrapped in a scrollable ``<div>``.
    """
    doc = ExampleDocument(text)

    # Step 1: run the NER pipeline on the document.
    ner_pipeline(doc)

    # Log every predicted entity and append a copy of it to the regular
    # (non-prediction) entity layer before the RE pipeline runs.
    print("detected entities:")
    for span in doc.entities.predictions:
        print(f"'{span}', label={span.label}, score={span.score}")
        doc.entities.append(span.copy())

    # Step 2: run the RE pipeline on the same document.
    re_pipeline(doc)

    # Render the predicted relations as a left-aligned table.
    table = PrettyTable()
    table.field_names = ["head", "tail", "relation"]
    table.align = "l"
    for rel in doc.relations.predictions:
        row = [str(rel.head), str(rel.tail), rel.label]
        table.add_row(row)

    # Wrap the table in a size-limited container so long results scroll.
    wrapper_open = "<div style='max-width:100%; max-height:360px; overflow:auto'>"
    wrapper_close = "</div>"
    return wrapper_open + table.get_html_string(format=True) + wrapper_close


# Example sentence pre-filled in the input box.
_example_text = "There is still some uncertainty that Musk - also chief executive of electric car maker Tesla and rocket company SpaceX - will pull off his planned buyout."

# Five-line text box used as the single input component.
# NOTE(review): ``gr.inputs.Textbox`` / ``default=`` look like the legacy
# Gradio component API — confirm the pinned Gradio version before upgrading.
_text_input = gr.inputs.Textbox(lines=5, default=_example_text)

# Wire predict() into the UI (text in, HTML out) and start the app.
iface = gr.Interface(fn=predict, inputs=_text_input, outputs="html")
iface.launch()