File size: 12,630 Bytes
0915ad4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5191cbb
0915ad4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b08870
0915ad4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
import os
import gradio as gr
import subprocess
import sys

def install(package):
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])

install("numpy")
install("torch")
install("transformers")
install("unidecode")

import numpy as np
import torch
from transformers import AutoTokenizer
from transformers import DistilBertForTokenClassification
from collections import Counter
from unidecode import unidecode
import string
import re

auth_token = os.environ.get("NEXT_IT_TOKEN")

tokenizer = AutoTokenizer.from_pretrained("osiria/blaze-it-ner", use_auth_token=auth_token)
model = DistilBertForTokenClassification.from_pretrained("osiria/blaze-it-ner", num_labels = 5, use_auth_token=auth_token)
device = torch.device("cpu")
model = model.to(device)
model.eval()

from transformers import pipeline
ner = pipeline('ner', model=model, tokenizer=tokenizer, device=-1)


header = '''--------------------------------------------------------------------------------------------------

<style>
.vertical-text {
    writing-mode: vertical-lr;
    text-orientation: upright;
    background-color:red;
}
</style>
<center>
<body>
<span class="vertical-text" style="background-color:lightgreen;border-radius: 3px;padding: 3px;"> </span>
<span class="vertical-text" style="background-color:orange;border-radius: 3px;padding: 3px;"> D</span>
<span class="vertical-text" style="background-color:lightblue;border-radius: 3px;padding: 3px;">    E</span>
<span class="vertical-text" style="background-color:tomato;border-radius: 3px;padding: 3px;">    M</span>
<span class="vertical-text" style="background-color:lightgrey;border-radius: 3px;padding: 3px;"> O</span>
<span class="vertical-text" style="background-color:#CF9FFF;border-radius: 3px;padding: 3px;"> </span>
</body>
</center>
<br>

--------------------------------------------------------------------------------------------------'''


paragraph = '''<b>What's BLAZE-IT?</b>

This app is a demo of [BLAZE-IT](https://huggingface.co/osiria/blaze-it), a <b>lightweight</b> and <b>uncased</b> italian language model (<b>55M parameters</b> and <b>220MB</b> size). The model is here fine-tuned for named entity recognition on WikiNER (cross-validated F1 score of 89.53%) plus a custom, hand-crafted dataset of 3.500 manually annotated Wikipedia paragraphs. 

It can recognize entities of the following types (in order to make the most of the color-coding, it is recommended to use the light theme for the interface):

- <span style="background-color:lightgreen;border-radius: 3px;padding: 3px;"><b>ᴘᴇʀ</b> person</span>: names of persons
- <span style="background-color:orange;border-radius: 3px;padding: 3px;"><b>ʟᴏᴄ</b> location</span>: names of places
- <span style="background-color:lightblue;border-radius: 3px;padding: 3px;"><b>ᴏʀɢ</b> organization</span>: names of organizations
- <span style="background-color:tomato;border-radius: 3px;padding: 3px;"><b>ᴍɪsᴄ</b> miscellanea</span>: mixed type entities
- <span style="background-color:lightgrey;border-radius: 3px;padding: 3px;"><b>ᴅᴀᴛᴇ</b> date</span>: regex-based dates
- <span style="background-color:#CF9FFF;border-radius: 3px;padding: 3px;"><b>ᴛᴀɢ</b> tag</span>: most relevant entities, of any type

The <b>ᴍɪsᴄ</b> class has mixed nature, and it mainly covers names of events or products. Occasionally, entities of other classes might be labeled as <b>ᴍɪsᴄ</b> if the model is not confident enough about their identification.

The execution time in this app depends on the availability of the underlying cloud instance, and is not a reflection of the model inference time. 
If unknown tokens are present in the text, they will interfere with the prediction, and the model may behave erratically. In that case, a warning sign will be displayed.
'''

maps = {"O": "NONE", "PER": "PER", "LOC": "LOC", "ORG": "ORG", "MISC": "MISC", "DATE": "DATE"}
reg_month = "(?:gennaio|febbraio|marzo|aprile|maggio|giugno|luglio|agosto|settembre|ottobre|novembre|dicembre|january|february|march|april|may|june|july|august|september|october|november|december)"
reg_date = "(?:\d{1,2}\°{0,1}|primo|\d{1,2}\º{0,1})" + " " + reg_month + " " + "\d{4}|"
reg_date = reg_date + reg_month + " " + "\d{4}|"
reg_date = reg_date + "\d{1,2}" + " " + reg_month
reg_date = reg_date + "\d{1,2}" + "(?:\/|\.)\d{1,2}(?:\/|\.)" + "\d{4}|"
reg_date = reg_date + "(?<=dal )\d{4}|(?<=al )\d{4}|(?<=nel )\d{4}|(?<=anno )\d{4}|(?<=del )\d{4}|"
reg_date = reg_date + "\d{1,5} a\.c\.|\d{1,5} d\.c\."
map_punct = {"’": "'", "«": '"', "»": '"', "”": '"', "“": '"', "–": "-", "$": ""}
unk_tok = 9005

merge_th_1 = 0.8
merge_th_2 = 0.4
min_th = 0.6

def extract(text):

    text = text.strip()
    for mp in map_punct:
        text = text.replace(mp, map_punct[mp])
    text = re.sub("\[\d+\]", "", text)

    warn_flag = False
    
    res_total = []
    out_text = ""
    
    for p_text in text.split("\n"):
    
        if p_text:

            toks = tokenizer.encode(p_text)
            if unk_tok in toks:
                warn_flag = True
    
            res_orig = ner(p_text, aggregation_strategy = "first")
            res_orig = [el for r, el in enumerate(res_orig) if len(el["word"].strip()) > 1]
            res = []
            
            for r, ent in enumerate(res_orig):
                if r > 0 and ent["score"] < merge_th_1 and ent["start"] <= res[-1]["end"] + 1 and ent["score"] <= res[-1]["score"]:
                    res[-1]["word"] = res[-1]["word"] + " " + ent["word"]
                    res[-1]["score"] = merge_th_1*(res[-1]["score"] > merge_th_2)
                    res[-1]["end"] = ent["end"]
                elif r < len(res_orig) - 1 and ent["score"] < merge_th_1 and res_orig[r+1]["start"] <= ent["end"] + 1 and res_orig[r+1]["score"] > ent["score"]:
                    res_orig[r+1]["word"] = ent["word"] + " " + res_orig[r+1]["word"]
                    res_orig[r+1]["score"] = merge_th_1*(res_orig[r+1]["score"] > merge_th_2)
                    res_orig[r+1]["start"] = ent["start"]
                else:
                    res.append(ent)
                    
            res = [el for r, el in enumerate(res) if el["score"] >= min_th]
        
            dates = [{"entity_group": "DATE", "score": 1.0, "word": p_text[el.span()[0]:el.span()[1]], "start": el.span()[0], "end": el.span()[1]} for el in re.finditer(reg_date, p_text, flags = re.IGNORECASE)]
            res.extend(dates)
            res = sorted(res, key = lambda t: t["start"])
            res_total.extend(res)
    
            chunks = [("", "", 0, "NONE")]
    
            for el in res:
                if maps[el["entity_group"]] != "NONE":
                    tag = maps[el["entity_group"]]
                    chunks.append((p_text[el["start"]: el["end"]], p_text[chunks[-1][2]:el["end"]], el["end"], tag))

            if chunks[-1][2] < len(p_text):
                chunks.append(("END", p_text[chunks[-1][2]:], -1, "NONE"))
            chunks = chunks[1:]
            
            n_text = []
    
            for i, chunk in enumerate(chunks):

                rep = chunk[0]
    
                if chunk[3] == "PER":
                    rep = '<span style="background-color:lightgreen;border-radius: 3px;padding: 3px;"><b>ᴘᴇʀ</b> ' + chunk[0] + '</span>'
                elif chunk[3] == "LOC":
                    rep = '<span style="background-color:orange;border-radius: 3px;padding: 3px;"><b>ʟᴏᴄ</b> ' + chunk[0] + '</span>'
                elif chunk[3] == "ORG":
                    rep = '<span style="background-color:lightblue;border-radius: 3px;padding: 3px;"><b>ᴏʀɢ</b> ' + chunk[0] + '</span>'
                elif chunk[3] == "MISC":
                    rep = '<span style="background-color:tomato;border-radius: 3px;padding: 3px;"><b>ᴍɪsᴄ</b> ' + chunk[0] + '</span>'
                elif chunk[3] == "DATE":
                    rep = '<span style="background-color:lightgrey;border-radius: 3px;padding: 3px;"><b>ᴅᴀᴛᴇ</b> ' + chunk[0] + '</span>'
    
                n_text.append(chunk[1].replace(chunk[0], rep))
    
            n_text = "".join(n_text)
            if out_text:
                out_text = out_text + "<br>" + n_text
            else:
                out_text = n_text
    

    tags = [el["word"] for el in res_total if el["entity_group"] not in ['DATE', None]]
    cnt = Counter(tags)
    tags = sorted(list(set([el for el in tags if cnt[el] > 1])), key = lambda t: cnt[t]*np.exp(-tags.index(t)))[::-1]
    tags = [" ".join(re.sub("[^A-Za-z0-9\s]", "", unidecode(tag)).split()) for tag in tags]
    tags = ['<span style="background-color:#CF9FFF;border-radius: 3px;padding: 3px;"><b>ᴛᴀɢ </b> ' + el + '</span>' for el in tags]
    tags = "    ".join(tags)

    if tags:
        out_text = out_text + "<br><br><b>Tags:</b> " + tags

    if warn_flag:
        out_text = out_text + "<br><br><b>Warning ⚠️:</b> Unknown tokens detected in text.  The model might behave erratically"
    
    return out_text



init_text = '''l'agenzia spaziale europea, nota internazionalmente con l'acronimo esa dalla denominazione inglese european space agency, è un'agenzia internazionale fondata nel 1975 incaricata di coordinare i progetti spaziali di 22 paesi europei. il suo quartier generale si trova a parigi in francia, con uffici a mosca, bruxelles, washington e houston.
attualmente il direttore generale dell'agenzia è l'austriaco josef aschbacher, il quale ha sostituito il tedesco johann-dietrich wörner il primo marzo 2021.

lo spazioporto dell'esa è il centre spatial guyanais a kourou, nella guyana francese, un sito scelto, come tutte le basi di lancio, per via della sua vicinanza con l'equatore. durante gli ultimi anni il lanciatore ariane 5 ha consentito all'esa di raggiungere una posizione di primo piano nei lanci commerciali e l'esa è il principale concorrente della nasa nell'esplorazione spaziale.

le missioni scientifiche dell'esa hanno le loro basi al centro europeo per la ricerca e la tecnologia spaziale (estec) di noordwijk, nei paesi bassi. il centro europeo per le operazioni spaziali (esoc), di darmstadt in germania, è responsabile del controllo dei satelliti esa in orbita.  [...]

l'agenzia spaziale italiana (asi) venne fondata nel 1988 per promuovere, coordinare e condurre le attività spaziali in italia. opera in collaborazione con il ministero dell'università e della ricerca scientifica e coopera in numerosi progetti con entità attive nella ricerca scientifica e nelle attività commerciali legate allo spazio. internazionalmente l'asi fornisce la delegazione italiana per l'agenzia spaziale europea e le sue sussidiarie.'''

init_output = extract(init_text)




with gr.Blocks(theme=gr.themes.Default(text_size="lg", spacing_size="lg")) as interface:
    
    with gr.Row():
        gr.Markdown(header)  
    with gr.Row():
        with gr.Column():
            gr.Markdown(paragraph) 
        with gr.Column():
            incipit = gr.Markdown("<b>Highlighted entities<b>")
            entities = gr.Markdown(init_output)
            
         
    with gr.Row():
        with gr.Column():
            text = gr.Text(label="Extract entities", lines = 10, value = init_text)
        with gr.Column():
            gr.Examples([["aristotele nacque nel 384 a.c. o nel 383 a.c. a stagira, l'attuale stavro, colonia greca situata nella parte nord-orientale della penisola calcidica della tracia. si dice che il padre, nicomaco, sia vissuto presso aminta iii, re dei macedoni, prestandogli i servigi di medico e di amico. aristotele, come figlio del medico reale, doveva pertanto risiedere nella capitale del regno di macedonia"],
                         ["mi chiamo edoardo, vivo a roma e lavoro per l'agenzia spaziale italiana, nella missione prisma"],
                         ["wikipedia è un'enciclopedia online a contenuto libero, collaborativa, multilingue e gratuita, nata nel 2001, sostenuta e ospitata dalla wikimedia foundation, un'organizzazione non a scopo di lucro statunitense.  lanciata da jimmy wales e larry sanger il 15 gennaio 2001, inizialmente nell'edizione in lingua inglese, nei mesi successivi ha aggiunto edizioni in numerose altre lingue"]],
                         inputs=[text])
    with gr.Row():
        button = gr.Button("Extract").style(full_width=False)

    with gr.Row():
        with gr.Column():
            gr.Markdown("<center>The input examples in this demo are extracted from https://it.wikipedia.org</center>") 

    button.click(extract, inputs=[text], outputs = [entities])


interface.launch()