Spaces:
Runtime error
Runtime error
from functools import lru_cache | |
import torch | |
from loguru import logger | |
from sentence_transformers import SentenceTransformer | |
from transformers import AutoTokenizer, AutoModel | |
# Device string handed to model constructors: GPU when torch can see one,
# CPU otherwise.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Hugging Face Hub identifiers of the checkpoints this module lists.
list_models = [
    'sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
    'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
    'sentence-transformers/all-mpnet-base-v2',
    'sentence-transformers/all-MiniLM-L12-v2',
    'cyclone/simcse-chinese-roberta-wwm-ext',
    'bert-base-chinese',
    'IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese',
]
class SBert:
    """Sentence embedder backed by a ``SentenceTransformer`` checkpoint."""

    def __init__(self, path):
        """Load the checkpoint at *path* onto ``DEVICE``, logging before and after."""
        cls = self.__class__
        logger.info(f'Start loading {cls} from {path} ...')
        self.model = SentenceTransformer(path, device=DEVICE)
        logger.info(f'Load {cls} from {path} ...')

    def __call__(self, x) -> torch.Tensor:
        """Encode *x* with the wrapped model and return the result as a tensor."""
        return self.model.encode(x, convert_to_tensor=True)
class ModelWithPooling:
    """Hugging Face encoder plus a selectable pooling step that yields sentence vectors."""

    # Pooling strategies accepted by ``__call__``.
    _POOLINGS = ('cls', 'pooler', 'mean', 'last-avg', 'first-last-avg')

    def __init__(self, path):
        """Load the tokenizer and model stored at *path*, logging before and after."""
        logger.info(f'Start loading {self.__class__} from {path} ...')
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModel.from_pretrained(path)
        logger.info(f'Load {self.__class__} from {path} ...')

    @staticmethod
    def _masked_mean(hidden, mask):
        """Average ``hidden`` [b, s, h] over the sequence axis, skipping masked-out positions.

        ``mask`` is the tokenizer's attention mask [b, s]; when it is ``None``
        every position is averaged.
        """
        if mask is None:
            return hidden.mean(dim=1)
        weights = mask.unsqueeze(-1).to(hidden.dtype)  # [b, s, 1]
        # clamp guards against division by zero for a row with no real tokens
        return (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1)

    def __call__(self, text: str, pooling='mean'):
        """Embed *text* and pool the token states into one vector per input.

        Args:
            text: Input handed to the tokenizer unchanged (called with
                ``padding=True, truncation=True``).
            pooling: One of ``'cls'`` (first token of the last layer),
                ``'pooler'`` (the model's ``pooler_output``), ``'mean'`` /
                ``'last-avg'`` (mean of the last layer) or
                ``'first-last-avg'`` (mean of the averaged
                ``hidden_states[1]`` and ``hidden_states[-1]``).

        Returns:
            A ``[b, h]`` tensor, with the leading axis squeezed away when it
            has size 1.

        Raises:
            ValueError: If *pooling* is not a known strategy, or ``'pooler'``
                is requested but the model output has no pooler output.
        """
        # Reject bad strategy names before paying for a forward pass.
        if pooling not in self._POOLINGS:
            raise ValueError(f'Unknown pooling {pooling}')

        inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt")
        # Only 'first-last-avg' reads intermediate layers.
        need_hidden = pooling == 'first-last-avg'
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=need_hidden)

        # Padding positions must not contribute to the averaged embeddings.
        mask = inputs.get('attention_mask')

        if pooling == 'cls':
            o = outputs.last_hidden_state[:, 0]  # [b, h]
        elif pooling == 'pooler':
            o = getattr(outputs, 'pooler_output', None)  # [b, h]
            if o is None:
                raise ValueError('Model output has no pooler_output; choose another pooling')
        elif pooling == 'first-last-avg':
            first_avg = self._masked_mean(outputs.hidden_states[1], mask)  # [b, h]
            last_avg = self._masked_mean(outputs.hidden_states[-1], mask)  # [b, h]
            o = (first_avg + last_avg) / 2  # [b, h]
        else:  # 'mean' / 'last-avg'
            o = self._masked_mean(outputs.last_hidden_state, mask)  # [b, h]

        # Drop the batch axis when there is exactly one input.
        return o.squeeze(0)
def test_sbert():
    """Smoke test: ``SBert`` on 'bert-base-chinese' embeds one string into a 768-dim vector."""
    encoder = SBert('bert-base-chinese')
    dims = encoder('hello').size()
    print(dims)
    assert dims == (768,)
def test_hf_model():
    """Smoke test: CLS pooling on the Erlangshen SimCSE checkpoint gives a 768-dim vector."""
    encoder = ModelWithPooling('IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese')
    dims = encoder('hello', pooling='cls').size()
    print(dims)
    assert dims == (768,)