|
from news_extractor import get_news |
|
from db_operations.db_write import DBWrite |
|
from db_operations.db_read import DBRead |
|
from news_category_similar_news_prediction import predict_news_category_similar_news |
|
import json |
|
from flask import Flask, Response |
|
from flask_cors import cross_origin, CORS |
|
import logging |
|
import tensorflow as tf |
|
import cloudpickle |
|
from transformers import DistilBertTokenizerFast |
|
import os |
|
from logger import get_logger |
|
import gc |
|
from find_similar_news import TextVectorizer, get_milvus_collection, load_sentence_transformer |
|
|
|
# Flask application object; routes below are registered against it.
app = Flask(__name__)

# Allow cross-origin requests from any origin for all routes.
CORS(app)



# Project-wide logger (see logger.get_logger). WARNING level is used for
# ordinary progress messages throughout this file.
logger = get_logger()

logger.warning('Entering application')

# Disable the HF tokenizers' fork-based parallelism warning/deadlock risk.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Force transformers/TF to use legacy (tf.keras v2) Keras.
# NOTE(review): both variables are set AFTER `tensorflow` and `transformers`
# are imported at the top of the file; they are typically read at import
# time — confirm they still take effect here.
os.environ["TF_USE_LEGACY_KERAS"] = "1"
|
|
|
def load_model():
    """Load the calibrated news-category classifier and its label encoder.

    Deserializes two cloudpickle artifacts from the ``models`` directory
    and attaches the path of the TFLite weights file to the calibrated
    estimator so it can load the interpreter itself at prediction time
    (the TFLite model is not part of the pickle).

    Returns:
        tuple: ``(calibrated_model, label_encoder)``.

    Raises:
        FileNotFoundError: if any of the model artifacts is missing.
    """
    logger.warning('Entering load transformer')

    # Build every artifact path the same way (previously two were
    # hard-coded strings while the third used os.path.join).
    models_dir = "models"

    with open(os.path.join(models_dir, "label_encoder.bin"), "rb") as model_file_obj:
        label_encoder = cloudpickle.load(model_file_obj)

    with open(os.path.join(models_dir, "calibrated_model.bin"), "rb") as model_file_obj:
        calibrated_model = cloudpickle.load(model_file_obj)

    # Point the wrapped estimator at the TFLite weights on disk.
    calibrated_model.estimator.tflite_model_path = os.path.join(models_dir, "model.tflite")

    logger.warning('Exiting load transformer')
    return calibrated_model, label_encoder
|
|
|
# Load all heavyweight ML artifacts once at import time so every request
# handled by this worker process reuses the same in-memory objects.
calibrated_model, label_encoder = load_model()

# Text vectorizer used by the similar-news search (see find_similar_news).
vectorizer = TextVectorizer()

# Handle to the Milvus vector collection storing article embeddings.
collection = get_milvus_collection()

# Sentence-transformer models; presumably a bi-encoder (sent_model) and a
# cross-encoder re-ranker (ce_model) — confirm in find_similar_news.
sent_model, ce_model = load_sentence_transformer()
|
|
|
|
|
@app.route("/")
@cross_origin()
def update_news():
    """Fetch fresh news, run the prediction pipeline and persist results.

    Reads existing articles from the database, scrapes new ones via
    ``get_news()``, runs category/similar-news prediction over both, and
    — when the pipeline signals an update is required — inserts the
    articles and their predictions into the production and prediction
    databases respectively.

    Returns:
        flask.Response: JSON body ``{"status": ..., "message": ...}``
        with HTTP 200 on success and HTTP 500 on any failure.
    """
    logger.warning('Entering update_news()')
    # Optimistic default; replaced in the except-branch on failure.
    status_payload = {"status": "success", "message": "success"}
    status_code = 200
    try:
        db_read = DBRead()
        db_write = DBWrite(db_type="production")
        prediction_db_write = DBWrite(db_type="prediction")

        old_news = db_read.read_news_from_db()
        new_news = get_news()

        news_df, prediction_df, is_db_updation_required = predict_news_category_similar_news(
            old_news, new_news, calibrated_model, label_encoder,
            collection, vectorizer, sent_model, ce_model)

        # A None DataFrame signals the pipeline failed; abort before any
        # partial database write happens.
        if news_df is None:
            raise Exception('Could not generate category predictions. Aborting the database insertion. No new articles are inserted into the collection.')

        if is_db_updation_required:
            # DataFrame rows -> list of per-article dicts for the DB layer.
            news_json = [*json.loads(news_df.reset_index(drop=True).to_json(orient="index")).values()]
            prediction_json = [*json.loads(prediction_df.reset_index(drop=True).to_json(orient="index")).values()]

            db_write.insert_news_into_db(news_json)
            prediction_db_write.insert_news_into_db(prediction_json)
        else:
            logger.warning('DB is not updated as it is not required.')
    except Exception as e:
        # Top-level boundary handler: report the failure to the caller
        # instead of letting Flask return an opaque 500 page.
        status_payload = {"status": "failure", "message": str(e)}
        status_code = 500
        logger.warning(f'ERROR IN update_news(): {e}')

    logger.warning('Exiting update_news()')
    # The pipeline creates large DataFrames/embeddings; reclaim eagerly so
    # the long-lived worker's memory footprint stays bounded.
    gc.collect()
    # BUG FIX: the payload was previously a single-quoted pseudo-JSON
    # string (invalid JSON despite the application/json mimetype), and the
    # error message was concatenated in unescaped. json.dumps produces a
    # valid, properly escaped body.
    return Response(json.dumps(status_payload), status=status_code, mimetype='application/json')
|
|
|
logger.warning('Exiting application')


if __name__ == "__main__":
    # BUG FIX: Flask's app.run() forwards unknown keyword arguments to
    # werkzeug's run_simple(), which accepts no `timeout`, `workers` or
    # `threads` parameters — passing them raises TypeError at startup.
    # Those look like gunicorn/waitress options; configure them on the
    # production WSGI server's command line instead.
    app.run(host="0.0.0.0", port=7860)
|
|