news_classifier_api / news_classifier.py
ksvmuralidhar's picture
Upload 10 files
83d8595 verified
raw
history blame
2.18 kB
import numpy as np
import tensorflow as tf
import logging
def find_path(url):
    """Reduce a news-article URL to a short path string.

    The result is used as a prefix for the article text (see
    ``predict_news_classes``).

    Processing steps:

    1. ``"-/-"`` is collapsed to ``"-"``.
    2. A ``"https://"`` prefix and every ``"www."`` occurrence are removed,
       and surrounding whitespace is stripped.
    3. The remainder is split on ``"/"``.
    4. Segments are dropped when they are:

       - empty,
       - exactly ``"articleshow"``,
       - contain ``".cms"``, ``".ece"`` or ``".htm"``,
       - have more than five ``-``-separated parts,
       - or contain a space.

    Args:
        url: The article URL. ``None`` is treated like an empty string.

    Returns:
        - ``''`` for an empty or ``None`` URL.
        - ``'-'`` when no segment survives filtering.
        - The last surviving segment when one or two segments survive.
        - Otherwise, all surviving segments except the first (usually the
          domain), joined with ``"/"``.
    """
    if not url:
        return ''

    # "a-/-b" is collapsed first so it stays one segment "a-b".
    cleaned = url.replace("-/-", "-")
    # NOTE(review): only "https://" is stripped. For "http://" URLs the first
    # kept segment is "http:", so the domain stays in the output. This is
    # left as-is to match the inputs the model was (presumably) trained on.
    cleaned = cleaned.replace("https://", "").replace("www.", "").strip()

    segments = [
        seg
        for seg in cleaned.split("/")
        if seg
        and seg != "articleshow"
        and not any(marker in seg for marker in (".cms", ".ece", ".htm"))
        # Long hyphenated segments (presumably headline slugs) are dropped.
        and len(seg.split('-')) <= 5
        and " " not in seg
    ]

    if len(segments) > 2:
        # Skip the leading segment, which is usually the host name.
        return "/".join(segments[1:])
    if segments:
        return segments[-1]
    return '-'
async def parse_prediction(tflite_pred, label_encoder):
    """Convert a batch of per-class score rows into labels and confidences.

    Args:
        tflite_pred: 2-D array-like of scores with one row per sample.
            The winning class is taken along axis 1.
        label_encoder: Object whose ``inverse_transform`` maps class
            indices back to label values.

    Returns:
        A ``(labels, confidences)`` tuple:

        - ``labels``: ``label_encoder.inverse_transform`` applied to each
          row's argmax.
        - ``confidences``: each row's maximum score.

    Note:
        This coroutine awaits nothing. It stays ``async`` only because
        callers ``await`` it.
    """
    winning_idx = np.argmax(tflite_pred, axis=1)
    confidences = np.max(tflite_pred, axis=1)
    labels = label_encoder.inverse_transform(winning_idx)
    return labels, confidences
async def model_inference(text: list, calibrated_model, label_encoder):
    """Score a batch of documents and return predicted labels and confidences.

    Args:
        text: Sequence of input documents, one per sample.
        calibrated_model: Object exposing ``predict_proba(text)``, which
            returns one row of class scores per sample. Presumably a
            calibrated scikit-learn classifier; confirm against the caller.
        label_encoder: Object whose ``inverse_transform`` maps class
            indices back to label values.

    Returns:
        A ``(labels, confidences)`` tuple, as produced by
        ``parse_prediction``.

        For an empty batch the model is not called. The result is then two
        empty arrays: labels with ``object`` dtype and confidences with
        ``float`` dtype.
    """
    logging.info('Entering news_classifier.model_inference()')
    n_samples = len(text)
    logging.info('Samples to predict: %d', n_samples)

    # The old guard `text != ""` compared a list to a str, so it was always
    # true. Empty input must be caught here: otherwise the model receives
    # zero samples, and callers that unpack the result break.
    if n_samples == 0:
        result = (np.empty(0, dtype=object), np.empty(0, dtype=float))
    else:
        class_scores = calibrated_model.predict_proba(text)
        result = await parse_prediction(class_scores, label_encoder)

    logging.info('Exiting news_classifier.model_inference()')
    return result
async def predict_news_classes(urls: list, texts: list, calibrated_model, label_encoder):
    """Classify news articles from their URLs and body texts.

    Each model input is built as ``f"{find_path(url)}. {text}"``, pairing
    ``urls[i]`` with ``texts[i]``. The batch is then scored with
    ``model_inference``.

    Args:
        urls: Article URLs, one per article.
        texts: Article texts, aligned index-by-index with ``urls``.
        calibrated_model: Model passed through to ``model_inference``.
        label_encoder: Label encoder passed through to ``model_inference``.

    Returns:
        The ``(label, prob)`` tuple returned by ``model_inference``.

    Raises:
        ValueError: If ``urls`` and ``texts`` differ in length. Without this
            check, ``zip`` would silently drop the unmatched tail and return
            fewer predictions than inputs.
    """
    urls = list(urls)
    texts = list(texts)
    if len(urls) != len(texts):
        raise ValueError(
            "urls and texts must have the same length "
            f"(got {len(urls)} urls and {len(texts)} texts)"
        )

    model_inputs = [f"{find_path(url)}. {text}" for url, text in zip(urls, texts)]
    label, prob = await model_inference(model_inputs, calibrated_model, label_encoder)
    return label, prob