File size: 3,774 Bytes
83d8595
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from sklearn.dummy import DummyClassifier
from tqdm import tqdm
import multiprocessing
import numpy as np
import tensorflow as tf
from transformers import DistilBertTokenizerFast



class PredictProba(DummyClassifier):
    """Expose a DistilBERT TFLite model through a scikit-learn style ``predict_proba``.

    The class subclasses ``DummyClassifier`` only so that it satisfies the
    estimator interface a calibration classifier expects (presumably
    scikit-learn's ``CalibratedClassifierCV`` with an already-fitted
    estimator -- TODO confirm): it provides ``classes_``, a no-op ``fit`` and
    ``predict_proba``.  Texts are tokenized with the DistilBERT tokenizer,
    split into batches, and each batch is scored by its own TFLite
    interpreter in a worker process.
    """

    # Hugging Face checkpoint used to build the tokenizer; presumably the one
    # the TFLite model was exported from -- keep them in sync.
    tokenizer_checkpoint = "distilbert-base-uncased"

    def __init__(self, tflite_model_path: str, classes_: list, n_tokens: int):
        # Class labels; required attribute for an estimator to be used in a
        # calibration classifier.  len(classes_) is the number of output columns.
        self.classes_ = classes_
        # Every text is padded / truncated to exactly this many tokens.
        self.n_tokens = n_tokens
        # Path to the .tflite file; each worker opens its own interpreter on it.
        self.tflite_model_path = tflite_model_path

    def fit(self, x, y):
        """Do nothing and return ``self``; the TFLite model is already trained.

        A ``fit`` method is required for an estimator to be used in a
        calibration classifier.
        """
        print('called fit')
        return self

    @staticmethod
    def get_token_batches(attention_mask, input_ids, batch_size: int = 8):
        """Split token arrays into consecutive row batches of ``batch_size``.

        Args:
            attention_mask: Sliceable 2-D array-like, one row per text.
            input_ids: Sliceable array-like with the same number of rows.
            batch_size: Maximum rows per batch; the last batch may be smaller.

        Returns:
            ``(attention_mask_batches, input_ids_batches)``: two lists of slices,
            aligned element-wise.  Both are empty when there are no rows.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        starts = range(0, len(attention_mask), batch_size)
        # Slicing past the end is clipped, so the final short batch needs no
        # special case.
        attention_mask_batches = [attention_mask[s:s + batch_size] for s in starts]
        input_ids_batches = [input_ids[s:s + batch_size] for s in starts]
        return attention_mask_batches, input_ids_batches

    def get_batch_inference(self, batch_size, attention_mask, input_ids):
        """Run the TFLite model on one batch and return its first output tensor.

        Args:
            batch_size: Number of rows in ``attention_mask`` / ``input_ids``.
            attention_mask: Array-like of shape ``(batch_size, n_tokens)``.
            input_ids: Array-like of shape ``(batch_size, n_tokens)``.

        Returns:
            numpy array read from the model's first output, expected to be
            ``(batch_size, len(classes_))``.
        """
        interpreter = tf.lite.Interpreter(model_path=self.tflite_model_path)
        input_details = interpreter.get_input_details()
        output_index = interpreter.get_output_details()[0]['index']
        # NOTE(review): relies on the converter ordering inputs as
        # [attention_mask, input_ids]; verify against input_details[i]['name']
        # if the model is re-exported.
        mask_detail, ids_detail = input_details[0], input_details[1]
        shape = [batch_size, self.n_tokens]
        interpreter.resize_tensor_input(mask_detail['index'], shape)
        interpreter.resize_tensor_input(ids_detail['index'], shape)
        # Output shapes are recomputed from the resized inputs here, so the
        # output tensor does not need to be resized explicitly.
        interpreter.allocate_tensors()
        # set_tensor requires the exact dtype the model declares (the tokenizer
        # may yield int64 while exported models often expect int32).
        interpreter.set_tensor(
            mask_detail['index'], np.asarray(attention_mask, dtype=mask_detail['dtype']))
        interpreter.set_tensor(
            ids_detail['index'], np.asarray(input_ids, dtype=ids_detail['dtype']))
        interpreter.invoke()
        return interpreter.get_tensor(output_index)

    def inference(self, texts, batch_size: int = 8):
        """Tokenize ``texts`` and score them in parallel with the TFLite model.

        Args:
            texts: A string, or a list/tuple/numpy array/other iterable of strings.
            batch_size: Texts per interpreter call (default 8).

        Returns:
            numpy array with one row per text, stacked in input order.
        """
        if isinstance(texts, np.ndarray):
            texts = texts.tolist()
        elif not isinstance(texts, (str, list, tuple)):
            # The tokenizer only accepts str / list / tuple batches.
            texts = list(texts)
        if not isinstance(texts, str) and len(texts) == 0:
            # NOTE(review): float32 assumed to match the model's output dtype.
            return np.empty((0, len(self.classes_)), dtype=np.float32)

        tokenizer = DistilBertTokenizerFast.from_pretrained(self.tokenizer_checkpoint)
        # numpy tensors pickle cheaply to worker processes and avoid starting
        # the TF eager runtime in the parent before the pool forks.
        tokens = tokenizer(texts, max_length=self.n_tokens, padding="max_length",
                           truncation=True, return_tensors="np")
        attention_mask_batches, input_ids_batches = self.get_token_batches(
            tokens['attention_mask'], tokens['input_ids'], batch_size=batch_size)

        n_processes = max(1, min(multiprocessing.cpu_count(), len(attention_mask_batches)))
        # The context manager terminates the workers on exit (including on a
        # timeout error), so results must be collected inside the block.
        with multiprocessing.Pool(processes=n_processes) as pool:
            pending = [
                pool.apply_async(self.get_batch_inference, args=(len(mask), mask, ids))
                for mask, ids in zip(attention_mask_batches, input_ids_batches)
            ]
            batch_predictions = [job.get(timeout=360) for job in tqdm(pending)]
        return np.concatenate(batch_predictions, axis=0)

    def predict_proba(self, X, y=None):
        """Return the model output for each text in ``X``; ``y`` is ignored.

        The output is presumably class probabilities, one column per entry of
        ``classes_``; whether the TFLite model ends in a softmax is not visible
        here -- TODO confirm.
        """
        return self.inference(X)