|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math
|
|
import torch
|
|
import numpy as np
|
|
from transformers import AutoTokenizer, is_torch_npu_available
|
|
from typing import Union, List
|
|
from .modeling import CrossEncoder
|
|
|
|
import os
|
|
# Enable parallelism in the tokenizer backend. NOTE(review): this variable is
# presumably read by the Hugging Face `tokenizers` library behind
# AutoTokenizer — confirm. It is assigned unconditionally at import time, so it
# overrides any value the user exported before importing this module.
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
|
|
|
|
|
|
def sigmoid(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))`` of ``x``.

    Evaluated as ``exp(-logaddexp(0, -x))``, which equals the textbook formula
    up to floating-point rounding but never overflows: the naive form calls
    ``np.exp(-x)`` and overflows (emitting a RuntimeWarning) once ``x`` is
    below roughly -709. Works on scalars and element-wise on NumPy arrays.
    """
    # logaddexp(0, -x) == log(1 + exp(-x)), computed without overflow.
    return np.exp(-np.logaddexp(0.0, -x))
|
|
|
|
|
|
class ListConRanker:
    """Listwise reranker built on a ``CrossEncoder`` model and its tokenizer.

    Inputs are groups of the form ``[query, passage_1, ..., passage_n]``; all
    sentences of a batch are tokenized together and sent to the model in a
    single call, which yields one score per passage.
    """

    def __init__(
        self,
        model_name_or_path: str = None,
        use_fp16: bool = False,
        cache_dir: str = None,
        device: Union[str, int] = None,
        list_transformer_layer=None,
    ) -> None:
        """Load the tokenizer and model, pick a device and move the model there.

        Args:
            model_name_or_path: Checkpoint name or path, passed to both
                ``AutoTokenizer.from_pretrained`` and
                ``CrossEncoder.from_pretrained_for_eval``.
            use_fp16: Cast the model with ``.half()``. Always disabled when the
                resolved device is a CPU.
            cache_dir: Cache directory; only forwarded to the tokenizer loader.
            device: A non-empty device string (e.g. ``"cpu"``, ``"cuda:1"``),
                an int CUDA index, or ``None``. Otherwise the first available of
                CUDA, MPS, NPU, CPU is used (an int index is only honoured when
                CUDA is available). With ``None`` and more than one visible
                CUDA device the model is wrapped in ``torch.nn.DataParallel``.
            list_transformer_layer: Forwarded unchanged to
                ``CrossEncoder.from_pretrained_for_eval``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir)
        self.model = CrossEncoder.from_pretrained_for_eval(model_name_or_path, list_transformer_layer)

        if device and isinstance(device, str):
            self.device = torch.device(device)
        elif torch.cuda.is_available():
            # An int index (including 0, which is falsy above) selects a GPU.
            self.device = torch.device(f"cuda:{device}" if device is not None else "cuda")
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
        elif is_torch_npu_available():
            self.device = torch.device("npu")
        else:
            self.device = torch.device("cpu")

        # fp16 is never used on a CPU device (covers "cpu" as well as "cpu:0").
        if self.device.type == "cpu":
            use_fp16 = False
        if use_fp16:
            self.model.half()

        self.model = self.model.to(self.device)
        self.model.eval()

        if device is None:
            # Only auto-selection fans out across GPUs. On non-CUDA hosts
            # device_count() is 0, so num_gpus is 0 there.
            self.num_gpus = torch.cuda.device_count()
            if self.num_gpus > 1:
                print(f"----------using {self.num_gpus}*GPUs----------")
                self.model = torch.nn.DataParallel(self.model)
        else:
            self.num_gpus = 1

    @torch.no_grad()
    def compute_score(self, sentence_pairs: List[List[str]], max_length: int = 512) -> List[List[float]]:
        """Score every passage of every group in one forward pass.

        Args:
            sentence_pairs: Groups of the form ``[query, passage_1, ..., passage_n]``.
            max_length: Maximum token length per sentence; longer ones are truncated.

        Returns:
            One list per group holding one score per passage, in input order;
            ``[]`` for empty input. If the model output converts to a single
            Python float, that float is returned wrapped as ``[score]`` (a flat
            list, kept for backward compatibility).
        """
        if not sentence_pairs:
            return []

        # Passages per group: every entry except the leading query.
        pair_nums = [len(group) - 1 for group in sentence_pairs]
        # Flatten in a single pass (``sum(lists, [])`` is quadratic in the group count).
        sentences_batch = [sentence for group in sentence_pairs for sentence in group]

        inputs = self.tokenizer(
            sentences_batch,
            padding=True,
            truncation=True,
            return_tensors='pt',
            max_length=max_length,
        ).to(self.device)
        # Group sizes travel with the batch (left on CPU, as before);
        # presumably CrossEncoder uses them to regroup the flat batch — confirm.
        inputs['pair_num'] = torch.LongTensor(pair_nums)

        scores = self.model(inputs).float()
        all_scores = scores.cpu().numpy().tolist()
        if isinstance(all_scores, float):
            return [all_scores]

        # Cut the flat score list back into per-group slices.
        result = []
        offset = 0
        for count in pair_nums:
            result.append(all_scores[offset:offset + count])
            offset += count
        return result

    def _score_list(self, query: str, passages: List[str], max_length: int) -> List[float]:
        """Score one ``[query] + passages`` group and always return a list of floats."""
        scores = self.compute_score([[query] + passages], max_length)[0]
        # compute_score yields a bare float (not a list) when the model output
        # collapses to a scalar; normalize so it can be indexed like a list.
        return np.atleast_1d(scores).tolist()

    @torch.no_grad()
    def iterative_inference(self, sentence_pairs: List[str], max_length: int = 512) -> List[float]:
        """Rerank a long passage list by repeatedly dropping the weakest passages.

        While more than 20 passages remain, they are scored together and the
        lowest ``max(ceil(20% of remaining), 10)`` are dropped, together with
        any passages tied with the last one dropped. A passage dropped in
        round ``k`` (0-based) gets ``raw_score + k``. The survivors are scored
        once more and get ``raw_score + number_of_rounds``. NOTE(review): the
        offset is presumably meant to rank later survivors above earlier
        drops, which holds when raw scores span less than 1 — confirm.

        Args:
            sentence_pairs: ``[query, passage_1, ..., passage_n]``.
            max_length: Maximum token length per sentence.

        Returns:
            One score per passage, aligned with ``sentence_pairs[1:]``.
            Duplicate passages each receive their own score.
        """
        query = sentence_pairs[0]
        passages = sentence_pairs[1:]
        final_scores = [0.0] * len(passages)
        # Indices into ``passages`` still in play. Tracking positions rather
        # than text keeps duplicate passages from sharing or multiplying scores.
        remaining = list(range(len(passages)))

        filter_times = 0
        while len(remaining) > 20:
            pred_scores = self._score_list(query, [passages[i] for i in remaining], max_length)
            order = np.argsort(pred_scores).tolist()

            cut = max(math.ceil(len(remaining) * 0.2), 10)
            # Extend the cut across ties with the last dropped passage; the
            # bound check stops an all-tied list from indexing past the end.
            while cut < len(order) and pred_scores[order[cut - 1]] == pred_scores[order[cut]]:
                cut += 1

            for pos in order[:cut]:
                final_scores[remaining[pos]] = pred_scores[pos] + filter_times
            # Survivors go to the next round in ascending score order.
            remaining = [remaining[pos] for pos in order[cut:]]
            filter_times += 1

        # Skip the final call when nothing is left (no passages, or all tied away).
        if remaining:
            pred_scores = self._score_list(query, [passages[i] for i in remaining], max_length)
            for pos, idx in enumerate(remaining):
                final_scores[idx] = pred_scores[pos] + filter_times

        return final_scores
|
|
|