myrag / backend /cross_encoder.py
Adir Gozlan
late commit
6d5ec26
import os
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Module-level cache for the cross-encoder model and its tokenizer. Both start
# empty and are populated by rerank_with_cross_encoder on first use.
cross_encoder = None
cross_enc_tokenizer = None

# Maximum number of documents kept after reranking. os.getenv returns a string
# whenever the variable is set, so parse it here; otherwise the value cannot be
# used as a slice bound. A non-numeric value raises ValueError at import time.
TOP_K_RERANK = int(os.getenv("TOP_K_RERANK", "40"))
@torch.no_grad()
def rerank_with_cross_encoder(cross_enc_name, documents, query):
    """Rerank ``documents`` against ``query`` with a cross-encoder model.

    The model and tokenizer are cached in the module globals ``cross_encoder``
    and ``cross_enc_tokenizer`` and are (re)loaded only when nothing is cached
    yet or the cached model's ``name_or_path`` differs from ``cross_enc_name``.

    Args:
        cross_enc_name: Identifier passed to ``from_pretrained`` for both the
            model and the tokenizer. ``None`` disables reranking.
        documents: List of document texts; each one is paired with ``query``
            when tokenizing (presumably plain strings -- confirm with callers).
        query: Query text scored against every document.

    Returns:
        ``documents`` itself, untouched, when ``cross_enc_name`` is ``None`` or
        there is at most one document. Otherwise a new list of documents sorted
        by descending model score, truncated to ``TOP_K_RERANK`` entries.
    """
    global cross_encoder, cross_enc_tokenizer

    if cross_enc_name is None or len(documents) <= 1:
        return documents

    if cross_encoder is None or cross_encoder.name_or_path != cross_enc_name:
        # Load into locals and publish both together: if the tokenizer load
        # fails, the globals must not end up holding the new model next to a
        # stale (or missing) tokenizer that later calls would silently reuse.
        model = AutoModelForSequenceClassification.from_pretrained(cross_enc_name)
        model.eval()
        tokenizer = AutoTokenizer.from_pretrained(cross_enc_name)
        cross_encoder, cross_enc_tokenizer = model, tokenizer

    features = cross_enc_tokenizer(
        [query] * len(documents),
        documents,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )

    # Drop only the trailing label dimension so the batch dimension is never
    # collapsed. NOTE(review): assumes a single-logit head, i.e. logits of
    # shape (len(documents), 1) -- confirm for the models used here.
    scores = cross_encoder(**features).logits.squeeze(-1)
    ranks = torch.argsort(scores, descending=True)

    # TOP_K_RERANK may be a string when it comes from the environment; a slice
    # bound has to be an integer.
    top_k = int(TOP_K_RERANK)
    return [documents[i] for i in ranks[:top_k].tolist()]