AlzbetaStrompova committed
Commit ce2493d
1 Parent(s): f3898ef

fix output

Files changed (2):
  1. app.py +6 -7
  2. website_script.py +29 -5
app.py CHANGED
@@ -6,20 +6,19 @@ tokenizer, model, gazetteers_for_matching = load()
 print("Loaded model")
 
 examples = [
-    "Masarykova univerzita se nachází v Brně.",
-    "Barack Obama navštívil Prahu minulý týden.",
-    "Angela Merkelová se setkala s francouzským prezidentem v Paříži.",
-    "Karel Čapek napsal knihu R.U.R., která byla poprvé představena v Praze.",
-    "Nobelova cena za fyziku byla udělena týmu vědců z MIT."
+    "Masarykova univerzita se nachází v Brně .",
+    "Barack Obama navštívil Prahu minulý týden .",
+    "Angela Merkelová se setkala s francouzským prezidentem v Paříži .",
+    "Nobelova cena za fyziku byla udělena týmu vědců z MIT ."
 ]
 
 def ner(text):
     result = run(tokenizer, model, gazetteers_for_matching, text)
-    return result
+    return {"text": text, "entities": result}
 
 demo = gr.Interface(ner,
                     gr.Textbox(placeholder="Enter sentence here..."),
-                    gr.HighlightedText(show_legend=True,),
+                    gr.HighlightedText(),
                     examples=examples)
 
 if __name__ == "__main__":
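
The new return value follows the dictionary format that Gradio's HighlightedText component accepts for character-offset highlighting: a "text" string plus an "entities" list whose items carry "entity", "start" and "end" keys, mirroring transformers' NER-pipeline output (the extra "score" and "word" fields produced by run() ride along). A minimal sketch of what ner() now returns; the labels, offsets and scores below are invented:

    # Illustrative only: actual labels and scores depend on the model.
    {
        "text": "Barack Obama navštívil Prahu minulý týden .",
        "entities": [
            {"start": 0, "end": 12, "entity": "PER", "score": 0.98, "word": "Barack Obama"},
            {"start": 23, "end": 28, "entity": "LOC", "score": 0.95, "word": "Prahu"},
        ],
    }
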
website_script.py CHANGED
@@ -24,7 +24,7 @@ def load():
 def run(tokenizer, model, gazetteers_for_matching, text):
 
     tokenized_inputs = tokenizer(
-        text, truncation=True, is_split_into_words=False
+        text, truncation=True, is_split_into_words=False, return_offsets_mapping=True
     )
     matches = gazetteer_matching(text, gazetteers_for_matching)
     new_g = []
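
return_offsets_mapping=True makes a fast tokenizer include, for every token, its character span (start, end) in the original string; special tokens come back as (0, 0), which is why the merging loop in the next hunk skips tokens with pos[0] == pos[1]. A quick sketch of the shape, assuming a HuggingFace fast tokenizer (the checkpoint name here is a placeholder, not necessarily the one this Space uses):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")  # placeholder checkpoint
    enc = tok("Barack Obama", return_offsets_mapping=True)
    print(enc.offset_mapping)
    # e.g. [(0, 0), (0, 6), (7, 12), (0, 0)] -- the (0, 0) entries are [CLS]/[SEP]
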
@@ -48,12 +48,36 @@ def run(tokenizer, model, gazetteers_for_matching, text):
     softmax = torch.nn.Softmax(dim=2)
     scores = softmax(output).squeeze(0).tolist()
     result = []
+    temp = {
+        "start": 0,
+        "end": 0,
+        "entity": "O",
+        "score": 0,
+        "word": "",
+        "count": 0
+    }
     for pos, entity, score in zip(tokenized_inputs.offset_mapping, predicted_tags[0], scores):
-        result.append({
+        if pos[0] == pos[1] or entity == "O":
+            continue
+        if temp["entity"] == entity[2:]:  # same entity
+            space = " " if pos[0] - temp["end"] >= 1 else ""
+            temp["end"] = pos[1]
+            temp["word"] += space + text[pos[0]:pos[1]]
+            temp["count"] += 1
+            temp["score"] += max(score)
+        else:  # new entity
+            if temp["count"] > 0:
+                temp["score"] /= temp.pop("count")
+                result.append(temp)
+            temp = {
                 "start": pos[0],
                 "end": pos[1],
-                "entity": entity,
-                "score": max(score),
+                "entity": entity[2:],
+                "score": 0,
                 "word": text[pos[0]:pos[1]],
-        })
+                "count": 1
+            }
+    if temp["count"] > 0:
+        temp["score"] /= temp.pop("count")
+        result.append(temp)
     return result
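
Taken together, the rewritten loop collapses consecutive subword tokens of the same entity type into one span, averages their per-token confidences, and drops special tokens and O tags, so app.py can hand the result straight to HighlightedText. A self-contained sketch of the same merging idea on invented inputs (note this version also counts a span's first token toward the score average, whereas the commit initializes "score" to 0):

    def merge_spans(text, offsets, tags, scores):
        # Fold per-token BIO tags into character-level entity spans.
        result = []
        temp = {"start": 0, "end": 0, "entity": "O", "score": 0.0, "word": "", "count": 0}
        for (start, end), tag, score in zip(offsets, tags, scores):
            if start == end or tag == "O":        # skip special tokens and non-entities
                continue
            if temp["entity"] == tag[2:]:         # same type: extend the open span
                space = " " if start - temp["end"] >= 1 else ""
                temp["end"] = end
                temp["word"] += space + text[start:end]
                temp["count"] += 1
                temp["score"] += score
            else:                                 # new type: flush, then start a new span
                if temp["count"] > 0:
                    temp["score"] /= temp.pop("count")
                    result.append(temp)
                temp = {"start": start, "end": end, "entity": tag[2:],
                        "score": score, "word": text[start:end], "count": 1}
        if temp["count"] > 0:                     # flush the last open span
            temp["score"] /= temp.pop("count")
            result.append(temp)
        return result

    text = "Barack Obama navštívil Prahu ."
    offsets = [(0, 0), (0, 6), (7, 12), (13, 22), (23, 28), (29, 30), (0, 0)]
    tags = ["O", "B-PER", "I-PER", "O", "B-LOC", "O", "O"]
    scores = [0.9, 0.99, 0.97, 0.9, 0.95, 0.9, 0.9]
    print(merge_spans(text, offsets, tags, scores))
    # [{'start': 0, 'end': 12, 'entity': 'PER', 'score': 0.98, 'word': 'Barack Obama'},
    #  {'start': 23, 'end': 28, 'entity': 'LOC', 'score': 0.95, 'word': 'Prahu'}]

One side effect of the type-based comparison: a fresh B- tag that directly follows an entity of the same type is merged into the previous span instead of opening a new one.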