"""Sentence-embedding helpers backed by TensorRT engines.

At import time this module loads two tokenizer/engine pairs: an English
paraphrase model and a multilingual paraphrase model. ``encode`` dispatches
to one of them based on a language code.
"""
from .exec_backends.trt_loader import TrtModel, encode as encode_trt
from transformers import AutoTokenizer
import math

# Upper bound on the batch size passed to ``encode_trt``; larger requests are
# clamped down to this value.
# NOTE(review): presumably the max batch size the .engine files were built
# with — confirm against the engine build configuration.
MAX_BATCH_SIZE = 8

tokenizer_en = AutoTokenizer.from_pretrained("tensorRT/models/paraphrase-mpnet-base-v2")
model_en = TrtModel("tensorRT/models/paraphrase-mpnet-base-v2.engine")
tokenizer_multilingual = AutoTokenizer.from_pretrained(
    "tensorRT/models/paraphrase-multilingual-MiniLM-L12-v2"
)
model_multilingual = TrtModel("tensorRT/models/paraphrase-multilingual-MiniLM-L12-v2.engine")


def encode(sentences, lang, batch_size=8):
    """Encode *sentences* in batches with the TensorRT engine chosen by *lang*.

    Args:
        sentences: Sliceable sequence of input sentences. It is processed in
            consecutive chunks of at most ``batch_size`` items.
        lang: ``'en'`` selects the English model; any other value selects the
            multilingual model.
        batch_size: Requested batch size. Values above ``MAX_BATCH_SIZE`` are
            clamped down to ``MAX_BATCH_SIZE``.

    Returns:
        A flat list formed by concatenating the items returned by
        ``encode_trt`` for each batch, in input order (presumably one
        embedding per sentence — confirm against ``encode_trt``). An empty
        input yields an empty list.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    # Previously 0 raised ZeroDivisionError and negatives silently returned [].
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")
    batch_size = min(batch_size, MAX_BATCH_SIZE)

    # The backend depends only on lang, so pick it once rather than per batch.
    if lang == 'en':
        tokenizer, trt_model = tokenizer_en, model_en
    else:
        tokenizer, trt_model = tokenizer_multilingual, model_multilingual

    all_embs = []
    for start in range(0, len(sentences), batch_size):
        batch = sentences[start:start + batch_size]
        embs = encode_trt(
            batch,
            tokenizer=tokenizer,
            trt_model=trt_model,
            use_token_type_ids=False,
        )
        all_embs.extend(embs)
    return all_embs