File size: 2,175 Bytes
83d8595
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import numpy as np
import tensorflow as tf 
import logging


def _is_path_segment(segment):
    """Return True when a URL segment should be kept by ``find_path``."""
    if segment == "" or segment == "articleshow":
        return False
    # Drop segments containing a page/file marker (".htm" also matches ".html").
    if any(marker in segment for marker in (".cms", ".ece", ".htm")):
        return False
    # Drop segments with more than five hyphen-separated parts.
    if len(segment.split("-")) > 5:
        return False
    return " " not in segment


def find_path(url):
    """Reduce a URL string to a short path built from its kept segments.

    ``"-/-"`` is merged into ``"-"``, the literal ``"https://"`` and every
    ``"www."`` occurrence are removed, surrounding whitespace is stripped and
    the remainder is split on ``"/"``. Segments rejected by
    ``_is_path_segment`` are discarded.

    Args:
        url: URL as a string.

    Returns:
        ``''`` when ``url`` is the empty string. Otherwise:
        the kept segments after the first one joined with ``"/"`` when more
        than two segments are kept; the last kept segment when one or two are
        kept; ``'-'`` when none are kept.
    """
    if url == '':
        return ''

    # Merge "-/-" before splitting so it does not introduce a segment boundary.
    cleaned = url.replace("-/-", "-")
    cleaned = cleaned.replace("https://", "").replace("www.", "").strip()

    segments = [seg for seg in cleaned.split("/") if _is_path_segment(seg)]

    if len(segments) > 2:
        # The first kept segment is skipped (for a typical URL it is the host).
        return "/".join(segments[1:])
    if segments:
        return segments[-1]
    return '-'


async def parse_prediction(tflite_pred, label_encoder):
    """Pick the highest-scoring class per row and decode it to a label.

    Args:
        tflite_pred: Score matrix reduced along ``axis=1``
            (presumably one row per sample, one column per class -- confirm
            against the model that produces it).
        label_encoder: Object exposing ``inverse_transform``; it is called
            with the per-row argmax indices.

    Returns:
        A ``(labels, probabilities)`` tuple: the decoded labels and the
        per-row maximum score.
    """
    winning_index = np.argmax(tflite_pred, axis=1)
    labels = label_encoder.inverse_transform(winning_index)
    return labels, np.max(tflite_pred, axis=1)
    

async def model_inference(text: list, calibrated_model, label_encoder):
    """Run the classifier on a batch of texts and decode the predictions.

    Args:
        text: Batch of input strings passed to ``predict_proba``.
        calibrated_model: Object exposing ``predict_proba``.
        label_encoder: Object exposing ``inverse_transform``; forwarded to
            ``parse_prediction``.

    Returns:
        A ``(labels, probabilities)`` tuple as produced by
        ``parse_prediction``. For an empty ``text`` the model is not called
        and a pair of empty arrays is returned.
    """
    logging.info('Entering news_classifier.model_inference()')

    logging.info('Samples to predict: %d', len(text))
    if len(text) == 0:
        # Nothing to predict: return an empty pair so callers that unpack
        # two values keep working. (The previous `text != ""` guard was always
        # true for a list and left the result unbound for an empty string.)
        tflite_pred = (np.array([]), np.array([]))
    else:
        # NOTE(review): predict_proba is a synchronous call inside a coroutine.
        tflite_pred = calibrated_model.predict_proba(text)
        tflite_pred = await parse_prediction(tflite_pred, label_encoder)
    logging.info('Exiting news_classifier.model_inference()')
    return tflite_pred


async def predict_news_classes(urls: list, texts: list, calibrated_model, label_encoder):
    """Classify texts, prefixing each with the path derived from its URL.

    Each model input has the form ``"<path>. <text>"`` where ``<path>`` is
    ``find_path(url)``. URLs and texts are paired positionally with ``zip``,
    so pairing stops at the shorter of the two lists.

    Args:
        urls: URL strings, one per text.
        texts: Text strings to classify.
        calibrated_model: Forwarded to ``model_inference``.
        label_encoder: Forwarded to ``model_inference``.

    Returns:
        The ``(label, prob)`` pair returned by ``model_inference``.
    """
    # Derive a path for every URL up front, before pairing with the texts.
    derived_paths = [find_path(link) for link in urls]

    model_inputs = []
    for path, body in zip(derived_paths, texts):
        model_inputs.append(f"{path}. {body}")

    label, prob = await model_inference(model_inputs, calibrated_model, label_encoder)
    return label, prob