File size: 2,082 Bytes
1d7e221
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import re
import gradio as gr
from dataclasses import dataclass
from prettytable import PrettyTable

from pytorch_ie import AnnotationList, BinaryRelation, Span, LabeledSpan, Pipeline, TextDocument, annotation_field
from pytorch_ie.models import TransformerSpanClassificationModel, TransformerTextClassificationModel
from pytorch_ie.taskmodules import TransformerSpanClassificationTaskModule, TransformerRETextClassificationTaskModule

from typing import List


@dataclass
class ExampleDocument(TextDocument):
    """Text document with an entity layer and a relation layer for the NER -> RE demo."""

    # Labeled spans anchored on the document's ``text`` field.
    entities: AnnotationList[LabeledSpan] = annotation_field(target="text")
    # Binary relations (head/tail) between annotations of the ``entities`` layer.
    relations: AnnotationList[BinaryRelation] = annotation_field(target="entities")


# --- Stage 1: named-entity recognition (span classification) ----------------
model_name_or_path = "pie/example-ner-spanclf-conll03"
ner_taskmodule, ner_model = (
    TransformerSpanClassificationTaskModule.from_pretrained(model_name_or_path),
    TransformerSpanClassificationModel.from_pretrained(model_name_or_path),
)
# NOTE(review): device=-1 presumably selects CPU and num_workers=0 presumably
# disables worker processes — confirm against pytorch_ie.Pipeline.
ner_pipeline = Pipeline(
    model=ner_model,
    taskmodule=ner_taskmodule,
    device=-1,
    num_workers=0,
)

# --- Stage 2: relation extraction (text classification) ---------------------
# Rebinding the shared name leaves it pointing at the RE checkpoint afterwards.
model_name_or_path = "pie/example-re-textclf-tacred"
re_taskmodule, re_model = (
    TransformerRETextClassificationTaskModule.from_pretrained(model_name_or_path),
    TransformerTextClassificationModel.from_pretrained(model_name_or_path),
)
re_pipeline = Pipeline(
    model=re_model,
    taskmodule=re_taskmodule,
    device=-1,
    num_workers=0,
)


def predict(text):
    """Extract entities and relations from *text* and render them as an HTML table.

    The NER pipeline runs first. Its predicted entities are copied into the
    document's main ``entities`` layer, and then the RE pipeline runs. Every
    predicted relation becomes one ``head | tail | relation`` row. The table
    is wrapped in a scrollable ``<div>`` for display.

    Args:
        text: Raw input text (supplied by the Gradio textbox).

    Returns:
        str: HTML fragment containing the relation table.
    """
    doc = ExampleDocument(text)

    # NER writes its output to ``doc.entities.predictions``.
    ner_pipeline(doc, predict_field="entities")
    # Promote the predicted spans into the main entity layer. Presumably the
    # RE taskmodule reads candidate arguments from this layer — confirm.
    predicted_spans = doc.entities.predictions
    for span in predicted_spans:
        doc.entities.append(span)

    re_pipeline(doc, predict_field="relations")

    rows = [
        [str(rel.head), str(rel.tail), rel.label]
        for rel in doc.relations.predictions
    ]

    table = PrettyTable()
    table.field_names = ["head", "tail", "relation"]
    table.align = "l"  # applied to every column named above
    for row in rows:
        table.add_row(row)

    table_html = table.get_html_string(format=True)
    # Cap the rendered height so long result tables scroll instead of growing.
    return "".join(
        (
            "<div style='max-width:100%; max-height:360px; overflow:auto'>",
            table_html,
            "</div>",
        )
    )


# Web UI: one free-text box in, the relation table (rendered HTML) out.
iface = gr.Interface(
    fn=predict,
    inputs="textbox",
    outputs="html",
)

# Start the server only when this file is executed as a script. Importing the
# module (e.g. from tests or other tooling) then does not launch the UI.
if __name__ == "__main__":
    iface.launch()