File size: 876 Bytes
63775f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
"""
TensorFlow Model Wrapper
--------------------------
"""


import numpy as np

from .model_wrapper import ModelWrapper


class TensorFlowModelWrapper(ModelWrapper):
    """Wraps an already-constructed TensorFlow model.

    TensorFlow models can use many different architectures and
    tokenization strategies. This assumes that the model takes an
    np.array of strings as input and returns a tf.Tensor of outputs, as
    is typical with Keras modules. You may need to subclass this for
    models that have dedicated tokenizers or otherwise take input
    differently.

    Args:
        model: A callable invoked as ``model(np.ndarray)``. It is stored
            as-is on ``self.model``; no loading or tokenizer setup is
            performed here.
    """

    def __init__(self, model):
        self.model = model

    def __call__(self, text_input_list, **kwargs):
        """Run the wrapped model on a batch of inputs.

        Args:
            text_input_list: Sequence of inputs; it is converted with
                ``np.array`` before being handed to the model.
            **kwargs: Accepted but currently ignored (not forwarded to
                the model).

        Returns:
            The model output converted via its ``.numpy()`` method, or the
            output itself when the model already returned an
            ``np.ndarray``.
        """
        text_array = np.array(text_input_list)
        preds = self.model(text_array)
        # Some callables already hand back a NumPy array, which has no
        # ``.numpy()`` method; return it directly instead of crashing.
        if isinstance(preds, np.ndarray):
            return preds
        return preds.numpy()

    def get_grad(self, text_input):
        """Gradient computation is not implemented for this wrapper.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "get_grad is not implemented for TensorFlowModelWrapper."
        )