ksvmuralidhar committed
Commit b7bcf4d
1 Parent(s): d57f131

Update app.py

Files changed (1)
  1. app.py +8 -6
app.py CHANGED
@@ -60,6 +60,7 @@ ner_model = TFAutoModelForTokenClassification.from_pretrained(NER_CHECKPOINT, nu
 ner_model.load_weights(os.path.join("models", "general_ner_deberta_weights.h5"), by_name=True)
 ner_label_encoder = NERLabelEncoder()
 ner_label_encoder.fit()
+ner_tokenizer = DebertaTokenizerFast.from_pretrained(NER_CHECKPOINT, add_prefix_space=True)
 nlp = spacy.load(os.path.join('.', 'en_core_web_sm-3.6.0'))
 NER_COLOR_MAP = {'GEO': '#DFFF00', 'GPE': '#FFBF00', 'PER': '#9FE2BF',
                  'ORG': '#40E0D0', 'TIM': '#CCCCFF', 'ART': '#FFC0CB', 'NAT': '#FFE4B5', 'EVE': '#DCDCDC'}
@@ -107,9 +108,9 @@ def ner_inference(txt):
     Function that returns model prediction and prediction probabitliy
     '''
     test_data = [txt]
-    tokenizer = DebertaTokenizerFast.from_pretrained(NER_CHECKPOINT, add_prefix_space=True)
-    tokens = tokenizer.tokenize(txt)
-    tokenized_data = tokenizer(test_data, is_split_into_words=True, max_length=NER_N_TOKENS,
+    # tokenizer = DebertaTokenizerFast.from_pretrained(NER_CHECKPOINT, add_prefix_space=True)
+    tokens = ner_tokenizer.tokenize(txt)
+    tokenized_data = ner_tokenizer(test_data, is_split_into_words=True, max_length=NER_N_TOKENS,
                                    truncation=True, padding="max_length")

     token_idx_to_consider = tokenized_data.word_ids()
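Net effect of the two hunks above: the DeBERTa tokenizer is now constructed once at module load and reused inside ner_inference, instead of being rebuilt from NER_CHECKPOINT on every request. A minimal standalone sketch of the resulting pattern (the checkpoint name and NER_N_TOKENS value are defined elsewhere in app.py and not shown in this diff, so both are placeholders here, and ner_tokenize is a hypothetical extraction of just the tokenization step):

    from transformers import DebertaTokenizerFast

    NER_CHECKPOINT = "microsoft/deberta-base"  # placeholder: real value not in this diff
    NER_N_TOKENS = 128                         # placeholder: real value not in this diff

    # Built once at import time (this commit's change) rather than on every call.
    ner_tokenizer = DebertaTokenizerFast.from_pretrained(NER_CHECKPOINT, add_prefix_space=True)

    def ner_tokenize(txt: str):
        """Tokenization step of ner_inference, reusing the shared tokenizer."""
        tokens = ner_tokenizer.tokenize(txt)
        tokenized_data = ner_tokenizer([txt], is_split_into_words=True,
                                       max_length=NER_N_TOKENS,
                                       truncation=True, padding="max_length")
        return tokens, tokenized_data, tokenized_data.word_ids()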
@@ -169,6 +170,7 @@ def get_ner_text(article_txt, ner_result):
 SUMM_CHECKPOINT = "facebook/bart-base"
 SUMM_INPUT_N_TOKENS = 400
 SUMM_TARGET_N_TOKENS = 100
+summ_tokenizer = BartTokenizerFast.from_pretrained(SUMM_CHECKPOINT)
 summ_model = TFAutoModelForSeq2SeqLM.from_pretrained(SUMM_CHECKPOINT)
 summ_model.load_weights(os.path.join("models", "bart_en_summarizer.h5"), by_name=True)

@@ -188,9 +190,9 @@ def summ_preprocess(txt):
     return txt

 def summ_inference_tokenize(input_: list, n_tokens: int):
-    tokenizer = BartTokenizerFast.from_pretrained(SUMM_CHECKPOINT)
-    tokenized_data = tokenizer(text=input_, max_length=SUMM_TARGET_N_TOKENS, truncation=True, padding="max_length", return_tensors="tf")
-    return tokenizer, tokenized_data
+    # tokenizer = BartTokenizerFast.from_pretrained(SUMM_CHECKPOINT)
+    tokenized_data = summ_tokenizer(text=input_, max_length=SUMM_TARGET_N_TOKENS, truncation=True, padding="max_length", return_tensors="tf")
+    return summ_tokenizer, tokenized_data

 def summ_inference(txt: str):
     txt = summ_preprocess(txt)
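The commit message doesn't state the motivation, but the pattern of the edits is the standard one: from_pretrained re-reads tokenizer files on every call, so building tokenizers inside ner_inference and summ_inference_tokenize paid that cost per request. A rough way to see the difference, as a sketch (loop count and sample text are arbitrary, and return_tensors="tf" is dropped so TensorFlow isn't needed):

    import time
    from transformers import BartTokenizerFast

    SUMM_CHECKPOINT = "facebook/bart-base"  # from the diff

    # Before this commit: a fresh tokenizer per call.
    def tokenize_rebuild(texts):
        tok = BartTokenizerFast.from_pretrained(SUMM_CHECKPOINT)
        return tok(text=texts, max_length=100, truncation=True, padding="max_length")

    # After this commit: one module-level tokenizer, reused everywhere.
    summ_tokenizer = BartTokenizerFast.from_pretrained(SUMM_CHECKPOINT)

    def tokenize_reuse(texts):
        return summ_tokenizer(text=texts, max_length=100, truncation=True, padding="max_length")

    for fn in (tokenize_rebuild, tokenize_reuse):
        start = time.perf_counter()
        for _ in range(20):
            fn(["An example article to summarize."])
        print(fn.__name__, f"{time.perf_counter() - start:.2f}s")

One quirk the diff leaves in place: summ_inference_tokenize still ignores its n_tokens parameter and truncates to SUMM_TARGET_N_TOKENS.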
 