"""Helpers for classifying news items from their URL path and text.

``find_path`` reduces an article URL to a short path string,
``predict_news_classes`` prefixes each text with that path and hands the
batch to ``model_inference``, which returns the predicted label and its
probability for every sample.
"""

import logging

import numpy as np

try:
    # Nothing in this module references TensorFlow directly.  The import is
    # kept so behaviour is unchanged wherever TensorFlow is installed, but a
    # missing installation no longer prevents the pure string/NumPy helpers
    # below from being imported.
    import tensorflow as tf  # noqa: F401
except ImportError:
    tf = None


# Substrings that mark a URL segment as an article file name, not a section.
_FILE_MARKERS = (".cms", ".ece", ".htm")


def _keep_segment(segment):
    """Return True when a URL segment should be kept by ``find_path``."""
    if segment in ("", "articleshow") or " " in segment:
        return False
    if any(marker in segment for marker in _FILE_MARKERS):
        return False
    # Segments made of more than five hyphen-separated words are treated as
    # article slugs and dropped.
    return len(segment.split("-")) <= 5


def find_path(url):
    """Reduce an article URL to a short path string.

    The literal ``"-/-"`` is collapsed to ``"-"``, ``"https://"`` and
    ``"www."`` are removed, surrounding whitespace is stripped and the
    remainder is split on ``"/"``.  Empty segments, ``"articleshow"``,
    segments containing ``.cms`` / ``.ece`` / ``.htm`` or a space, and
    segments with more than five hyphen-separated words are discarded.

    Args:
        url: The URL as a string.

    Returns:
        ``''`` for an empty ``url``.  Otherwise, if more than two segments
        survive, every segment after the first joined with ``"/"``; if one
        or two survive, the last one; if none survive, ``'-'``.

    Note:
        Only ``"https://"`` is stripped; an ``"http://"`` prefix is left in
        place and its ``"http:"`` piece counts as the first segment.
        NOTE(review): the output presumably has to match what the model was
        trained on, so the rules are preserved exactly -- confirm before
        changing them.
    """
    if url == '':
        return ''

    cleaned = (
        url.replace("-/-", "-")
        .replace("https://", "")
        .replace("www.", "")
        .strip()
    )
    segments = [seg for seg in cleaned.split("/") if _keep_segment(seg)]

    if len(segments) > 2:
        return "/".join(segments[1:])
    if segments:
        return segments[-1]
    return '-'


async def parse_prediction(tflite_pred, label_encoder):
    """Convert a per-class probability matrix into labels and confidences.

    Args:
        tflite_pred: Array-like indexed as (sample, class); reductions are
            taken along axis 1.
        label_encoder: Object whose ``inverse_transform`` maps class indices
            back to labels.

    Returns:
        A ``(labels, probabilities)`` tuple holding, for every sample, the
        label of the highest-scoring class and that class's score.
    """
    best_class = np.argmax(tflite_pred, axis=1)
    best_prob = np.max(tflite_pred, axis=1)
    labels = label_encoder.inverse_transform(best_class)
    return labels, best_prob


async def model_inference(text: list, calibrated_model, label_encoder):
    """Predict a label and probability for every entry of ``text``.

    Args:
        text: Batch of input strings.
        calibrated_model: Object exposing ``predict_proba(text)``.
        label_encoder: Object exposing ``inverse_transform``; forwarded to
            ``parse_prediction``.

    Returns:
        A ``(labels, probabilities)`` tuple.  For an empty batch the model
        is not called and two empty arrays are returned.
    """
    logging.info('Entering news_classifier.model_inference()')
    logging.info('Samples to predict: %d', len(text))

    # The previous guard compared the batch against "" and was therefore
    # always true for a list; had it ever been false, the return statement
    # would have hit an unbound local.  ``len`` is used instead of
    # truthiness so array-like batches work too.
    if len(text) == 0:
        result = (np.empty(0, dtype=object), np.empty(0, dtype=float))
    else:
        # NOTE(review): predict_proba is called synchronously inside a
        # coroutine, so it occupies the event loop while it runs.
        probabilities = calibrated_model.predict_proba(text)
        result = await parse_prediction(probabilities, label_encoder)

    logging.info('Exiting news_classifier.model_inference()')
    return result


async def predict_news_classes(urls: list, texts: list, calibrated_model,
                               label_encoder):
    """Classify news items from their URLs and texts.

    Each model input is built as ``"<find_path(url)>. <text>"``.

    Args:
        urls: Article URLs.
        texts: Article texts, paired with ``urls`` by position.  If the two
            lists differ in length, the extra items of the longer one are
            ignored (``zip`` semantics).
        calibrated_model: Forwarded to ``model_inference``.
        label_encoder: Forwarded to ``model_inference``.

    Returns:
        The ``(labels, probabilities)`` tuple from ``model_inference``.
    """
    model_inputs = [
        f"{find_path(url)}. {body}" for url, body in zip(urls, texts)
    ]
    label, prob = await model_inference(
        model_inputs, calibrated_model, label_encoder
    )
    return label, prob