"""Caching wrapper around a SentenceTransformer embedding model."""

from functools import lru_cache

import numpy as np
import torch
from sentence_transformers import SentenceTransformer

# Use the GPU when torch reports CUDA as available, otherwise the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Model identifiers. Nothing in this file reads the list; presumably callers
# pick one of these and pass it to SBert as `path` -- TODO confirm.
list_models = [
    'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
    'sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
    'cyclone/simcse-chinese-roberta-wwm-ext'
]


class SBert:
    """Callable that encodes its input with a SentenceTransformer model.

    Results for hashable inputs are memoised in a bounded LRU cache that
    belongs to the instance, so dropping the instance also releases its
    cache and its model.
    """

    # Maximum number of distinct inputs remembered per instance.
    _CACHE_SIZE = 10000

    def __init__(self, path):
        """Load the model named by `path` onto DEVICE.

        Args:
            path: Passed unchanged as the first argument of
                SentenceTransformer.
        """
        print(f'Loading model from {path} ...')
        self.model = SentenceTransformer(path, device=DEVICE)

    def __getstate__(self):
        """Return the instance state without the per-instance cache.

        The cache wrapper is bound to this particular instance, so it must
        not travel into a copy or a pickle; the copy rebuilds its own cache
        lazily on first use.
        """
        state = self.__dict__.copy()
        state.pop('_cached_encode', None)
        return state

    def _encode(self, x) -> np.ndarray:
        """Encode `x` with the underlying model, bypassing the cache."""
        return self.model.encode(x)

    def _get_cached_encode(self):
        """Return this instance's LRU-cached encoder, creating it on demand."""
        try:
            return self._cached_encode
        except AttributeError:
            # Wrapping the bound method (rather than decorating the method
            # on the class) keeps the cache per instance instead of holding
            # every instance alive from a class-level cache.
            cached = lru_cache(maxsize=self._CACHE_SIZE)(self._encode)
            self._cached_encode = cached
            return cached

    def __call__(self, x) -> np.ndarray:
        """Encode `x`, reusing a cached result when `x` was seen before.

        Args:
            x: Input handed to the model's `encode` method. Hashable inputs
                (e.g. a `str` or a `tuple`) are cached; unhashable inputs
                (e.g. a `list`) are encoded directly without caching.

        Returns:
            Whatever the model's `encode` returns for `x` (annotated as
            `np.ndarray`). On a cache hit the same object is returned
            again, so callers should not modify it in place.
        """
        try:
            hash(x)
        except TypeError:
            # lru_cache cannot key on unhashable arguments; skip the cache
            # instead of failing before the model is even called.
            return self._encode(x)
        return self._get_cached_encode()(x)