File size: 594 Bytes
05da059
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from pathlib import Path
from sentence_transformers.cross_encoder import CrossEncoder
from more_itertools import windowed
# Cross-encoder used by rerank() to score (query, candidate) pairs. It is
# built once, at import time, and pinned to the CPU; max_length=512 is the
# token limit passed to CrossEncoder for each pair.
model = CrossEncoder(
  'cross-encoder/mmarco-mMiniLMv2-L12-H384-v1',
  device='cpu',
  max_length=512,
)

def rerank(sentence_combinations):
  """Score (query, candidate) pairs with the cross-encoder and rank them.

  Args:
    sentence_combinations: sequence of ``[query, candidate]`` pairs (as
      built by ``search``). Every pair is scored by ``model.predict``.

  Returns:
    A list of ``(score, idx)`` tuples sorted by score, highest first.
    ``idx`` is the position of the pair in ``sentence_combinations``, and
    ``score`` is the value ``model.predict`` produced for it. Pairs with
    equal scores keep their input order. An empty input yields ``[]``.
  """
  # CrossEncoder.predict inspects the first element of its input, so an
  # empty batch may fail there with IndexError instead of scoring nothing.
  # len() is used rather than truthiness so array-like inputs also work.
  if len(sentence_combinations) == 0:
    return []
  similarity_scores = model.predict(sentence_combinations)
  # Keep each score's original index so callers can map a ranked entry
  # back to the candidate it was computed for.
  scores = [(score, idx) for idx, score in enumerate(similarity_scores)]
  # sorted() is stable even with reverse=True, so ties stay in index order.
  return sorted(scores, key=lambda pair: pair[0], reverse=True)

def search(query, sentences):
  """Return the top-ranked candidate for ``query``.

  Each element of ``sentences`` is paired with ``query`` as
  ``[query, candidate]``, the pairs are ranked by ``rerank``, and the first
  entry of that ranking is returned.

  Args:
    query: the query text, placed first in every pair.
    sentences: iterable of candidate texts.

  Returns:
    The first ``(score, idx)`` tuple from ``rerank``: the highest score and
    the index of the sentence it belongs to. An empty ``sentences`` leaves
    nothing to return and results in an error.
  """
  query_candidate_pairs = [[query, candidate] for candidate in sentences]
  ranking = rerank(query_candidate_pairs)
  top_entry = ranking[0]
  return top_entry