kevin-yang commited on
Commit
b1e91c5
·
1 Parent(s): 4259675

add model cache and fix font

Browse files
Files changed (1) hide show
  1. app.py +33 -18
app.py CHANGED
@@ -6,6 +6,8 @@ import seaborn
6
  import matplotlib
7
  import platform
8
 
 
 
9
  if platform.system() == "Darwin":
10
  print("MacOS")
11
  matplotlib.use('Agg')
@@ -14,20 +16,33 @@ import io
14
  from PIL import Image
15
 
16
  import matplotlib.font_manager as fm
 
17
 
18
 
 
 
 
 
 
19
 
 
 
 
 
 
 
20
 
21
- import util
22
-
23
- font_path = r'NanumGothicCoding.ttf'
24
- fontprop = fm.FontProperties(fname=font_path, size=18)
25
 
26
- plt.rcParams["font.family"] = 'NanumGothic'
 
 
 
 
27
 
28
 
29
  def visualize_attention(sent, attention_matrix, n_words=10):
30
  def draw(data, x, y, ax):
 
31
  seaborn.heatmap(data,
32
  xticklabels=x, square=True, yticklabels=y, vmin=0.0, vmax=1.0,
33
  cbar=False, ax=ax)
@@ -42,22 +57,27 @@ def visualize_attention(sent, attention_matrix, n_words=10):
42
 
43
  fig.tight_layout()
44
  plt.close()
45
-
46
  return fig
47
 
48
 
 
 
 
 
 
49
 
50
- def predict(model_name, text):
51
 
52
- tokenizer = AutoTokenizer.from_pretrained(model_name)
53
- model = AutoModelForSequenceClassification.from_pretrained(model_name)
54
- config = AutoConfig.from_pretrained(model_name)
55
- print(config.id2label)
 
 
 
56
 
57
  tokenized_text = tokenizer([text], return_tensors='pt')
58
 
59
  input_tokens = tokenizer.convert_ids_to_tokens(tokenized_text.input_ids[0])
60
- print(input_tokens)
61
  input_tokens = util.bytetokens_to_unicdode(input_tokens) if config.model_type in ['roberta', 'gpt', 'gpt2'] else input_tokens
62
 
63
  model.eval()
@@ -73,12 +93,7 @@ def predict(model_name, text):
73
 
74
 
75
  if __name__ == '__main__':
76
-
77
- model_name = 'jason9693/SoongsilBERT-beep-base'
78
  text = '읿딴걸 홍볿글 읿랉곭 쌑젩낄고 앉앟있냩'
79
- # output = predict(model_name, text)
80
-
81
- # print(output)
82
 
83
  model_name_list = [
84
  'jason9693/SoongsilBERT-beep-base'
@@ -88,7 +103,7 @@ if __name__ == '__main__':
88
  app = gr.Interface(
89
  fn=predict,
90
  inputs=[gr.inputs.Dropdown(model_name_list, label="Model Name"), 'text'], outputs=['label', 'plot'],
91
- examples = [[model_name, text]],
92
  title="한국어 혐오성 발화 분류기 (Korean Hate Speech Classifier)",
93
  description="Korean Hate Speech Classifier with Several Pretrained LM\nCurrent Supported Model:\n1. SoongsilBERT"
94
  )
 
6
  import matplotlib
7
  import platform
8
 
9
+ from transformers.file_utils import ModelOutput
10
+
11
  if platform.system() == "Darwin":
12
  print("MacOS")
13
  matplotlib.use('Agg')
 
16
  from PIL import Image
17
 
18
  import matplotlib.font_manager as fm
19
+ import util
20
 
21
 
22
+ # global var
23
+ MODEL_NAME = 'jason9693/SoongsilBERT-beep-base'
24
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
25
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
26
+ config = AutoConfig.from_pretrained(MODEL_NAME)
27
 
28
+ MODEL_BUF = {
29
+ "name": MODEL_NAME,
30
+ "tokenizer": tokenizer,
31
+ "model": model,
32
+ "config": config
33
+ }
34
 
 
 
 
 
35
 
36
+ font_dir = ['./']
37
+ for font in fm.findSystemFonts(font_dir):
38
+ print(font)
39
+ fm.fontManager.addfont(font)
40
+ plt.rcParams["font.family"] = 'NanumGothicCoding'
41
 
42
 
43
  def visualize_attention(sent, attention_matrix, n_words=10):
44
  def draw(data, x, y, ax):
45
+
46
  seaborn.heatmap(data,
47
  xticklabels=x, square=True, yticklabels=y, vmin=0.0, vmax=1.0,
48
  cbar=False, ax=ax)
 
57
 
58
  fig.tight_layout()
59
  plt.close()
 
60
  return fig
61
 
62
 
63
+ def change_model_name(name):
64
+ MODEL_BUF["name"] = name
65
+ MODEL_BUF["tokenizer"] = AutoTokenizer.from_pretrained(name)
66
+ MODEL_BUF["model"] = AutoModelForSequenceClassification.from_pretrained(name)
67
+ MODEL_BUF["config"] = AutoConfig.from_pretrained(name)
68
 
 
69
 
70
+ def predict(model_name, text):
71
+ if model_name != MODEL_NAME:
72
+ change_model_name(model_name)
73
+
74
+ tokenizer = MODEL_BUF["tokenizer"]
75
+ model = MODEL_BUF["model"]
76
+ config = MODEL_BUF["config"]
77
 
78
  tokenized_text = tokenizer([text], return_tensors='pt')
79
 
80
  input_tokens = tokenizer.convert_ids_to_tokens(tokenized_text.input_ids[0])
 
81
  input_tokens = util.bytetokens_to_unicdode(input_tokens) if config.model_type in ['roberta', 'gpt', 'gpt2'] else input_tokens
82
 
83
  model.eval()
 
93
 
94
 
95
  if __name__ == '__main__':
 
 
96
  text = '읿딴걸 홍볿글 읿랉곭 쌑젩낄고 앉앟있냩'
 
 
 
97
 
98
  model_name_list = [
99
  'jason9693/SoongsilBERT-beep-base'
 
103
  app = gr.Interface(
104
  fn=predict,
105
  inputs=[gr.inputs.Dropdown(model_name_list, label="Model Name"), 'text'], outputs=['label', 'plot'],
106
+ examples = [[MODEL_BUF["name"], text]],
107
  title="한국어 혐오성 발화 분류기 (Korean Hate Speech Classifier)",
108
  description="Korean Hate Speech Classifier with Several Pretrained LM\nCurrent Supported Model:\n1. SoongsilBERT"
109
  )