|
from typing import List |
|
from sentence_transformers import SentenceTransformer |
|
import os |
|
|
|
class PreTrainedPipeline:
    """Feature-extraction pipeline backed by a ``SentenceTransformer`` model.

    Instances are callable: ``pipeline(text)`` returns the embedding the
    model computes for ``text`` as a plain Python ``list`` of floats.
    """

    def __init__(self, path=""):
        """Initialize the model.

        Args:
            path: Location (``str`` or ``os.PathLike``) handed to
                ``SentenceTransformer``. NOTE(review): the default ``""``
                is forwarded unchanged; it is unclear whether an empty path
                loads any model weights — confirm against the installed
                sentence-transformers version.
        """
        # The original ``os.path.join(path)`` had a single argument, so its
        # only effect was converting ``path`` to a path string; ``os.fspath``
        # expresses exactly that.
        self.model = SentenceTransformer(os.fspath(path))

    def __call__(self, inputs: str) -> List[float]:
        """Compute the features of ``inputs``.

        Args:
            inputs: A string to get the features of.

        Returns:
            The features computed by the model, as a ``list`` of floats.
        """
        # ``encode`` returns an array; ``tolist()`` converts it to built-in
        # Python floats (presumably so callers get a serialisable result).
        return self.model.encode(inputs).tolist()
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the pipeline and print the features of a short string.
    # NOTE(review): relies on the default empty ``path``; pass the model
    # directory explicitly if nothing is loaded from "" — confirm.
    pipeline = PreTrainedPipeline()

    # Invoke the instance directly instead of spelling out ``__call__``.
    print(pipeline("hei"))