|
from news_extractor import get_news |
|
from db_operations.db_write import DBWrite |
|
from db_operations.db_read import DBRead |
|
from news_category_prediction import predict_news_category |
|
import json |
|
from flask import Flask, Response |
|
from flask_cors import cross_origin, CORS |
|
import logging |
|
import tensorflow as tf |
|
import cloudpickle |
|
from transformers import DistilBertTokenizerFast |
|
import os |
|
|
|
# Flask application object; the route decorators in this module register on it.
app = Flask(__name__)

# Enable Cross-Origin Resource Sharing for all routes of the application.
CORS(app=app)

# Startup marker. WARNING is the root logger's default threshold, so this is
# visible without any extra logging configuration.
logging.warning('Initiated')
|
|
|
|
|
def load_model(model_dir="models", model_checkpoint="distilbert-base-uncased"):
    """Load the artifacts needed to categorize news articles.

    Args:
        model_dir: Directory that holds ``news_classification_hf_distilbert.tflite``
            and ``news_classification_labelencoder.bin``. A relative path is
            resolved against the current working directory.
        model_checkpoint: Checkpoint name or path passed to
            ``DistilBertTokenizerFast.from_pretrained``.

    Returns:
        A ``(interpreter, label_encoder, tokenizer)`` tuple: the TFLite
        interpreter for the classifier model, the unpickled label encoder, and
        the DistilBERT fast tokenizer.
    """
    # NOTE(review): allocate_tensors() is not called here; presumably the
    # consumer of the interpreter does it before inference — verify.
    interpreter = tf.lite.Interpreter(
        model_path=os.path.join(model_dir, "news_classification_hf_distilbert.tflite")
    )

    # SECURITY: cloudpickle can execute arbitrary code while loading, so
    # model_dir must only ever contain trusted, locally produced artifacts.
    label_encoder_path = os.path.join(model_dir, "news_classification_labelencoder.bin")
    with open(label_encoder_path, "rb") as model_file_obj:
        label_encoder = cloudpickle.load(model_file_obj)

    tokenizer = DistilBertTokenizerFast.from_pretrained(model_checkpoint)

    return interpreter, label_encoder, tokenizer
|
|
|
# Load the model artifacts once, at import time; update_news() reads these
# module-level names on every request instead of reloading them.
(interpreter, label_encoder, tokenizer) = load_model()
|
|
|
|
|
@app.route("/")
@cross_origin()
def update_news():
    """Refresh stored news with newly fetched, categorized articles.

    Reads the existing news from the database, fetches new articles with
    ``get_news()``, categorizes them via ``predict_news_category`` using the
    module-level interpreter, label encoder and tokenizer, and inserts the
    resulting rows into the database.

    Returns:
        A JSON ``Response``: ``{"status": "success"}`` with HTTP 200, or
        ``{"status": "failure"}`` with HTTP 500 if any step raised.
    """
    try:
        db_read = DBRead()
        db_write = DBWrite()

        old_news = db_read.read_news_from_db()
        new_news = get_news()

        # news_df is presumably a pandas DataFrame (it is serialized via to_json).
        news_df = predict_news_category(old_news, new_news, interpreter, label_encoder, tokenizer)

        # One plain dict per row, in row order.
        news_json = json.loads(news_df.to_json(orient="records"))

        db_write.insert_news_into_db(news_json)
    except Exception:
        # Request boundary: keep the traceback in the logs rather than
        # discarding it, then report failure to the caller.
        logging.exception("Failed to update news")
        status_payload, status_code = {"status": "failure"}, 500
    else:
        status_payload, status_code = {"status": "success"}, 200

    # json.dumps yields valid JSON (double-quoted keys/strings), matching the mimetype.
    return Response(json.dumps(status_payload), status=status_code, mimetype='application/json')
|
|
|
|
|
if __name__ == "__main__":
    # Development server only. Flask forwards extra keyword arguments to
    # werkzeug.serving.run_simple, which does not accept gunicorn-style
    # `timeout`, `workers` or `threads` and raises TypeError for them.
    # Single-process, single-threaded serving is expressed with
    # `processes=1, threaded=False`; Flask defaults to threaded=True, so it is
    # set explicitly. A request timeout (previously 120 s) has to be configured
    # on the production WSGI server instead (e.g. gunicorn --timeout 120).
    app.run(host="0.0.0.0", port=7860, threaded=False, processes=1)
|
|