Emanuel committed on
Commit
6875045
1 Parent(s): 6447205

Added model selection feature

Browse files
Files changed (1) hide show
  1. app.py +48 -13
app.py CHANGED
@@ -1,6 +1,5 @@
1
  import logging
2
  import os
3
- import tempfile
4
  from pathlib import Path
5
  from typing import List, Tuple
6
 
@@ -8,6 +7,7 @@ import gradio as gr
8
  import pandas as pd
9
  import spacy
10
  import torch
 
11
  from transformers import AutoModelForTokenClassification, AutoTokenizer
12
 
13
  from preprocessing import expand_contractions
@@ -17,27 +17,56 @@ try:
17
  except Exception:
18
  os.system("python -m spacy download pt_core_news_sm")
19
  nlp = spacy.load("pt_core_news_sm")
20
-
21
- model = AutoModelForTokenClassification.from_pretrained("Emanuel/porttagger-news-base")
22
- tokenizer = AutoTokenizer.from_pretrained("Emanuel/porttagger-news-base")
 
 
 
 
 
 
 
 
 
 
 
 
23
  logger = logging.getLogger()
24
  logger.setLevel(logging.DEBUG)
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
- def predict(text, nlp, logger=None) -> Tuple[List[str], List[str]]:
28
- doc = nlp(text)
29
- tokens = [token.text for token in doc]
30
 
31
  logger.info("Starting predictions for sentence: {}".format(text))
 
32
 
33
- input_tokens = tokenizer(
34
  tokens,
35
  return_tensors="pt",
36
  is_split_into_words=True,
37
  return_offsets_mapping=True,
38
  return_special_tokens_mask=True,
39
  )
40
- output = model(input_tokens["input_ids"])
41
 
42
  i_token = 0
43
  labels = []
@@ -49,7 +78,7 @@ def predict(text, nlp, logger=None) -> Tuple[List[str], List[str]]:
49
  ):
50
  if is_special_token or off[0] > 0:
51
  continue
52
- label = model.config.__dict__["id2label"][int(pred.argmax(axis=-1))]
53
  if logger is not None:
54
  logger.info("{}, {}, {}".format(off, tokens[i_token], label))
55
  labels.append(label)
@@ -63,7 +92,11 @@ def predict(text, nlp, logger=None) -> Tuple[List[str], List[str]]:
63
 
64
  def text_analysis(text):
65
  text = expand_contractions(text)
66
- tokens, labels, scores = predict(text, nlp, logger)
 
 
 
 
67
  pos_count = pd.DataFrame(
68
  {
69
  "token": tokens,
@@ -99,7 +132,7 @@ def batch_analysis(input_file):
99
  sent = expand_contractions(sent)
100
  conllu_output.append("# sent_id = {}-{}\n".format(name, i + 1))
101
  conllu_output.append("# text = {}\n".format(sent))
102
- tokens, labels, scores = predict(sent, nlp, logger)
103
  for j, (token, label) in enumerate(zip(tokens, labels)):
104
  conllu_output.append(
105
  "{}\t{}\t_\t{}".format(j + 1, token, label) + "\t_" * 5 + "\n"
@@ -119,6 +152,8 @@ bottom_html = open("bottom.html").read()
119
 
120
  with gr.Blocks(css=css) as demo:
121
  gr.HTML(top_html)
 
 
122
  with gr.Tab("Single sentence"):
123
  text = gr.Textbox(placeholder="Enter your text here...", label="Input")
124
  examples = gr.Examples(
@@ -167,4 +202,4 @@ with gr.Blocks(css=css) as demo:
167
  gr.HTML(bottom_html)
168
 
169
 
170
- demo.launch(debug=True, server_port=15000)
 
1
  import logging
2
  import os
 
3
  from pathlib import Path
4
  from typing import List, Tuple
5
 
 
7
  import pandas as pd
8
  import spacy
9
  import torch
10
+ from dante_tokenizer import DanteTokenizer
11
  from transformers import AutoModelForTokenClassification, AutoTokenizer
12
 
13
  from preprocessing import expand_contractions
 
17
  except Exception:
18
  os.system("python -m spacy download pt_core_news_sm")
19
  nlp = spacy.load("pt_core_news_sm")
20
+ dt_tokenizer = DanteTokenizer()
21
+
22
+ default_model = "News"
23
+ model_choices = {
24
+ "News": "Emanuel/porttagger-news-base",
25
+ "Tweets": "Emanuel/porttagger-tweets-base",
26
+ "Oil and Gas": "Emanuel/porttagger-oilgas-base",
27
+ "Multigenre": "Emanuel/porttagger-base",
28
+ }
29
+ pre_tokenizers = {
30
+ "News": nlp,
31
+ "Tweets": dt_tokenizer.tokenize,
32
+ "Oil and Gas": nlp,
33
+ "Multigenre": nlp,
34
+ }
35
  logger = logging.getLogger()
36
  logger.setLevel(logging.DEBUG)
37
 
38
+ class MyApp:
39
+ def __init__(self) -> None:
40
+ self.model = None
41
+ self.tokenizer = None
42
+ self.pre_tokenizer = None
43
+ self.load_model()
44
+
45
+ def load_model(self, model_name: str = default_model):
46
+ if model_name not in model_choices.keys():
47
+ logger.error("Selected model is not supported, resetting to the default model.")
48
+ model_name = default_model
49
+ self.model = AutoModelForTokenClassification.from_pretrained(model_choices[model_name])
50
+ self.tokenizer = AutoTokenizer.from_pretrained(model_choices[model_name])
51
+ self.pre_tokenizer = pre_tokenizers[model_name]
52
+
53
+ myapp = MyApp()
54
 
55
+ def predict(text, logger=None) -> Tuple[List[str], List[str]]:
56
+ doc = myapp.pre_tokenizer(text)
57
+ tokens = [token.text if not isinstance(token, str) else token for token in doc]
58
 
59
  logger.info("Starting predictions for sentence: {}".format(text))
60
+ print("Using model {}".format(myapp.model.config.__dict__["_name_or_path"]))
61
 
62
+ input_tokens = myapp.tokenizer(
63
  tokens,
64
  return_tensors="pt",
65
  is_split_into_words=True,
66
  return_offsets_mapping=True,
67
  return_special_tokens_mask=True,
68
  )
69
+ output = myapp.model(input_tokens["input_ids"])
70
 
71
  i_token = 0
72
  labels = []
 
78
  ):
79
  if is_special_token or off[0] > 0:
80
  continue
81
+ label = myapp.model.config.__dict__["id2label"][int(pred.argmax(axis=-1))]
82
  if logger is not None:
83
  logger.info("{}, {}, {}".format(off, tokens[i_token], label))
84
  labels.append(label)
 
92
 
93
  def text_analysis(text):
94
  text = expand_contractions(text)
95
+ tokens, labels, scores = predict(text, logger)
96
+ if len(labels) != len(tokens):
97
+ m = len(tokens) - len(labels)
98
+ labels += [None] * m
99
+ scores += [0] * m
100
  pos_count = pd.DataFrame(
101
  {
102
  "token": tokens,
 
132
  sent = expand_contractions(sent)
133
  conllu_output.append("# sent_id = {}-{}\n".format(name, i + 1))
134
  conllu_output.append("# text = {}\n".format(sent))
135
+ tokens, labels, scores = predict(sent, logger)
136
  for j, (token, label) in enumerate(zip(tokens, labels)):
137
  conllu_output.append(
138
  "{}\t{}\t_\t{}".format(j + 1, token, label) + "\t_" * 5 + "\n"
 
152
 
153
  with gr.Blocks(css=css) as demo:
154
  gr.HTML(top_html)
155
+ select_model = gr.Dropdown(choices=list(model_choices.keys()), label="Tagger model", value=default_model)
156
+ select_model.change(myapp.load_model, inputs=[select_model])
157
  with gr.Tab("Single sentence"):
158
  text = gr.Textbox(placeholder="Enter your text here...", label="Input")
159
  examples = gr.Examples(
 
202
  gr.HTML(bottom_html)
203
 
204
 
205
+ demo.launch(debug=True)