File size: 806 Bytes
02a3457
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from typing import List
from sentence_transformers import SentenceTransformer
import os

class PreTrainedPipeline:
    """Feature-extraction pipeline backed by a ``SentenceTransformer`` model.

    Calling an instance with a string runs ``model.encode`` on it and returns
    the result as a plain Python list.
    """

    def __init__(self, path=""):
        """
        Load the sentence-transformers model found at ``path``.

        Args:
            path (:obj:`str` or path-like):
                Model location handed to ``SentenceTransformer``. Path-like
                objects (e.g. ``pathlib.Path``) are converted to ``str`` first.
                NOTE(review): the default is an empty string; what
                ``SentenceTransformer("")`` does depends on the library
                version -- confirm callers always pass a real path.
        """
        # os.fspath turns an os.PathLike into a str and leaves str unchanged,
        # which is exactly what the previous single-argument
        # os.path.join(path) did, but without pretending to join anything.
        self.model = SentenceTransformer(os.fspath(path))

    def __call__(self, inputs: str) -> List[float]:
        """
        Compute the features of a single string.

        Args:
            inputs (:obj:`str`):
                a string to get the features of.
        Return:
            A :obj:`list` of floats: The features computed by the model.
        """
        # encode() returns an array-like result (presumably a numpy array --
        # not shown here); tolist() converts it to a plain Python list.
        return self.model.encode(inputs).tolist()

if __name__ == "__main__":
    # Manual smoke test: embed a short sample string and print the result.
    # An optional first command-line argument supplies the model path; with
    # no argument the pipeline is built with its default (empty) path, as
    # before.
    import sys

    model_path = sys.argv[1] if len(sys.argv) > 1 else ""
    pipeline = PreTrainedPipeline(model_path)

    print(pipeline("hei"))