lalithadevi commited on
Commit
6bbd283
·
verified ·
1 Parent(s): f4d9674

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -10
app.py CHANGED
@@ -19,20 +19,22 @@ CORS(app)
19
 
20
  logger = get_logger()
21
  logger.warning('Entering application')
22
- os.environ["TOKENIZERS_PARALLELISM"] = "true"
23
 
24
  def load_model():
25
  logger.warning('Entering load transformer')
26
- interpreter = tf.lite.Interpreter(model_path=os.path.join("models/news_classification_hf_distilbert.tflite"))
27
- with open("models/news_classification_labelencoder.bin", "rb") as model_file_obj:
28
  label_encoder = cloudpickle.load(model_file_obj)
29
-
30
- model_checkpoint = "distilbert-base-uncased"
31
- tokenizer = DistilBertTokenizerFast.from_pretrained(model_checkpoint)
 
 
 
32
  logger.warning('Exiting load transformer')
33
- return interpreter, label_encoder, tokenizer
34
 
35
- interpreter, label_encoder, tokenizer = load_model()
36
  vectorizer = TextVectorizer()
37
  collection = get_milvus_collection()
38
  sent_model, ce_model = load_sentence_transformer()
@@ -50,8 +52,8 @@ def update_news():
50
  prediction_db_write = DBWrite(db_type="prediction")
51
  old_news = db_read.read_news_from_db()
52
  new_news = get_news()
53
- news_df, prediction_df, is_db_updation_required = predict_news_category_similar_news(old_news, new_news, interpreter, label_encoder,
54
- tokenizer, collection, vectorizer, sent_model, ce_model)
55
  if news_df is None:
56
  raise Exception('Could not generate category predictions. Aborting the database insertion. No new articles are inserted into the collection.')
57
 
 
19
 
20
  logger = get_logger()
21
  logger.warning('Entering application')
22
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
23
 
24
  def load_model():
25
  logger.warning('Entering load transformer')
26
+ with open("models/label_encoder.bin", "rb") as model_file_obj:
 
27
  label_encoder = cloudpickle.load(model_file_obj)
28
+
29
+ with open("models/calibrated_model.bin", "rb") as model_file_obj:
30
+ calibrated_model = cloudpickle.load(model_file_obj)
31
+
32
+ tflite_model_path = os.path.join("models", "model.tflite")
33
+ calibrated_model.estimator.tflite_model_path = tflite_model_path
34
  logger.warning('Exiting load transformer')
35
+ return calibrated_model, label_encoder
36
 
37
+ calibrated_model, label_encoder = load_model()
38
  vectorizer = TextVectorizer()
39
  collection = get_milvus_collection()
40
  sent_model, ce_model = load_sentence_transformer()
 
52
  prediction_db_write = DBWrite(db_type="prediction")
53
  old_news = db_read.read_news_from_db()
54
  new_news = get_news()
55
+ news_df, prediction_df, is_db_updation_required = predict_news_category_similar_news(old_news, new_news, calibrated_model, label_encoder,
56
+ collection, vectorizer, sent_model, ce_model)
57
  if news_df is None:
58
  raise Exception('Could not generate category predictions. Aborting the database insertion. No new articles are inserted into the collection.')
59