ksvmuralidhar committed
Commit b613afc
Parent: c32a726

Update app.py

Files changed (1): app.py (+25, -11)
app.py CHANGED
@@ -56,14 +56,22 @@ class NERLabelEncoder:
 NER_CHECKPOINT = "microsoft/deberta-base"
 NER_N_TOKENS = 50
 NER_N_LABELS = 18
-ner_model = TFAutoModelForTokenClassification.from_pretrained(NER_CHECKPOINT, num_labels=NER_N_LABELS, attention_probs_dropout_prob=0.4, hidden_dropout_prob=0.4)
-ner_model.load_weights(os.path.join("models", "general_ner_deberta_weights.h5"), by_name=True)
-ner_label_encoder = NERLabelEncoder()
-ner_label_encoder.fit()
-ner_tokenizer = DebertaTokenizerFast.from_pretrained(NER_CHECKPOINT, add_prefix_space=True)
-nlp = spacy.load(os.path.join('.', 'en_core_web_sm-3.6.0'))
 NER_COLOR_MAP = {'GEO': '#DFFF00', 'GPE': '#FFBF00', 'PER': '#9FE2BF',
                  'ORG': '#40E0D0', 'TIM': '#CCCCFF', 'ART': '#FFC0CB', 'NAT': '#FFE4B5', 'EVE': '#DCDCDC'}
+
+@st.cache_resource
+def load_ner_models():
+    ner_model = TFAutoModelForTokenClassification.from_pretrained(NER_CHECKPOINT, num_labels=NER_N_LABELS, attention_probs_dropout_prob=0.4, hidden_dropout_prob=0.4)
+    ner_model.load_weights(os.path.join("models", "general_ner_deberta_weights.h5"), by_name=True)
+    ner_label_encoder = NERLabelEncoder()
+    ner_label_encoder.fit()
+    ner_tokenizer = DebertaTokenizerFast.from_pretrained(NER_CHECKPOINT, add_prefix_space=True)
+    nlp = spacy.load(os.path.join('.', 'en_core_web_sm-3.6.0'))
+    print('Loaded NER models')
+    return ner_model, ner_label_encoder, ner_tokenizer, nlp
+
+ner_model, ner_label_encoder, ner_tokenizer, nlp = load_ner_models()
+
 ############ NER MODEL & VARS INITIALIZATION END ####################
 
 ############ NER LOGIC START ####################
@@ -170,9 +178,16 @@ def get_ner_text(article_txt, ner_result):
 SUMM_CHECKPOINT = "facebook/bart-base"
 SUMM_INPUT_N_TOKENS = 400
 SUMM_TARGET_N_TOKENS = 100
-summ_tokenizer = BartTokenizerFast.from_pretrained(SUMM_CHECKPOINT)
-summ_model = TFAutoModelForSeq2SeqLM.from_pretrained(SUMM_CHECKPOINT)
-summ_model.load_weights(os.path.join("models", "bart_en_summarizer.h5"), by_name=True)
+
+@st.cache_resource
+def load_summarizer_models():
+    summ_tokenizer = BartTokenizerFast.from_pretrained(SUMM_CHECKPOINT)
+    summ_model = TFAutoModelForSeq2SeqLM.from_pretrained(SUMM_CHECKPOINT)
+    summ_model.load_weights(os.path.join("models", "bart_en_summarizer.h5"), by_name=True)
+    print('Loaded summarizer models')
+    return summ_tokenizer, summ_model
+
+summ_tokenizer, summ_model = load_summarizer_models()
 
 def summ_preprocess(txt):
     txt = re.sub(r'^By \. [\w\s]+ \. ', ' ', txt) # By . Ellie Zolfagharifard .
@@ -190,7 +205,6 @@ def summ_preprocess(txt):
     return txt
 
 def summ_inference_tokenize(input_: list, n_tokens: int):
-    # tokenizer = BartTokenizerFast.from_pretrained(SUMM_CHECKPOINT)
     tokenized_data = summ_tokenizer(text=input_, max_length=SUMM_TARGET_N_TOKENS, truncation=True, padding="max_length", return_tensors="tf")
     return summ_tokenizer, tokenized_data
 
@@ -207,7 +221,7 @@ def summ_inference(txt: str):
 ############## ENTRY POINT START #######################
 def main():
     st.title("News Summarizer & NER")
-    article_txt = st.text_area("Paste the text of a news article:", "", height=200)
+    article_txt = st.text_area("Paste few sentences of a news article:", "", height=200)
     if st.button("Submit"):
         ner_result = [[ent, label.upper(), np.round(prob, 3)]
                       for ent, label, prob in ner_inference_long_text(article_txt)]
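
Note on the change: Streamlit re-executes the whole script on every widget interaction, so the previous module-level loading re-built the DeBERTa NER model, the BART summarizer, and the spaCy pipeline on each rerun. Wrapping the loading in functions decorated with @st.cache_resource makes Streamlit run each loader once per process and return the same objects on later reruns. A minimal sketch of the pattern, assuming only that streamlit is installed (slow_load and its dummy return value are illustrative stand-ins, not names from app.py):

import time
import streamlit as st

@st.cache_resource  # cache the returned object across reruns and sessions
def slow_load():
    # Stand-in for expensive work such as loading model weights (hypothetical).
    time.sleep(5)
    return {"name": "dummy-model"}

model = slow_load()  # ~5 s on the first run, effectively free afterwards
st.write("Using", model["name"])
if st.button("Rerun"):  # clicking triggers a rerun; slow_load() is not re-executed
    st.write("Script re-ran; the cached model was reused.")

Because st.cache_resource hands back the cached object itself rather than a copy, all user sessions share one instance of each model, which suits read-only inference objects like the TensorFlow models loaded here.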