ListConRanker / modules /listconranker.py
ianma2024's picture
upload codes and model weights
dd0b4f3
raw
history blame
6.57 kB
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software
# and associated documentation files (the “Software”), to deal in the Software without
# restriction, including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
# OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
# OTHER DEALINGS IN THE SOFTWARE.
import math
import torch
import numpy as np
from transformers import AutoTokenizer, is_torch_npu_available
from typing import Union, List
from .modeling import CrossEncoder
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
def sigmoid(x):
return 1 / (1 + np.exp(-x))
class ListConRanker:
def __init__(
self,
model_name_or_path: str = None,
use_fp16: bool = False,
cache_dir: str = None,
device: Union[str, int] = None,
list_transformer_layer = None
) -> None:
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir)
self.model = CrossEncoder.from_pretrained_for_eval(model_name_or_path, list_transformer_layer)
if device and isinstance(device, str):
self.device = torch.device(device)
if device == 'cpu':
use_fp16 = False
else:
if torch.cuda.is_available():
if device is not None:
self.device = torch.device(f"cuda:{device}")
else:
self.device = torch.device("cuda")
elif torch.backends.mps.is_available():
self.device = torch.device("mps")
elif is_torch_npu_available():
self.device = torch.device("npu")
else:
self.device = torch.device("cpu")
use_fp16 = False
if use_fp16:
self.model.half()
self.model = self.model.to(self.device)
self.model.eval()
if device is None:
self.num_gpus = torch.cuda.device_count()
if self.num_gpus > 1:
print(f"----------using {self.num_gpus}*GPUs----------")
self.model = torch.nn.DataParallel(self.model)
else:
self.num_gpus = 1
@torch.no_grad()
def compute_score(self, sentence_pairs: List[List[str]], max_length: int = 512) -> List[List[float]]:
pair_nums = [len(pairs) - 1 for pairs in sentence_pairs]
sentences_batch = sum(sentence_pairs, [])
inputs = self.tokenizer(
sentences_batch,
padding=True,
truncation=True,
return_tensors='pt',
max_length=max_length,
).to(self.device)
inputs['pair_num'] = torch.LongTensor(pair_nums)
scores = self.model(inputs).float()
all_scores = scores.cpu().numpy().tolist()
if isinstance(all_scores, float):
return [all_scores]
result = []
curr_idx = 0
for i in range(len(pair_nums)):
result.append(all_scores[curr_idx: curr_idx + pair_nums[i]])
curr_idx += pair_nums[i]
# return all_scores
return result
@torch.no_grad()
def iterative_inference(self, sentence_pairs: List[str], max_length: int = 512) -> List[float]:
query = sentence_pairs[0]
passage = sentence_pairs[1:]
filter_times = 0
passage2score = {}
while len(passage) > 20:
batch = [[query] + passage]
pred_scores = self.compute_score(batch, max_length)[0]
# Sort in increasing order
pred_scores_argsort = np.argsort(pred_scores).tolist()
passage_len = len(passage)
to_filter_num = math.ceil(passage_len * 0.2)
if to_filter_num < 10:
to_filter_num = 10
have_filter_num = 0
while have_filter_num < to_filter_num:
idx = pred_scores_argsort[have_filter_num]
if passage[idx] in passage2score:
passage2score[passage[idx]].append(pred_scores[idx] + filter_times)
else:
passage2score[passage[idx]] = [pred_scores[idx] + filter_times]
have_filter_num += 1
while pred_scores[pred_scores_argsort[have_filter_num - 1]] == pred_scores[pred_scores_argsort[have_filter_num]]:
idx = pred_scores_argsort[have_filter_num]
if passage[idx] in passage2score:
passage2score[passage[idx]].append(pred_scores[idx] + filter_times)
else:
passage2score[passage[idx]] = [pred_scores[idx] + filter_times]
have_filter_num += 1
next_passage = []
next_passage_idx = have_filter_num
while next_passage_idx < len(passage):
idx = pred_scores_argsort[next_passage_idx]
next_passage.append(passage[idx])
next_passage_idx += 1
passage = next_passage
filter_times += 1
batch = [[query] + passage]
pred_scores = self.compute_score(batch, max_length)[0]
cnt = 0
while cnt < len(passage):
if passage[cnt] in passage2score:
passage2score[passage[cnt]].append(pred_scores[cnt] + filter_times)
else:
passage2score[passage[cnt]] = [pred_scores[cnt] + filter_times]
cnt += 1
passage = sentence_pairs[1:]
final_score = []
for i in range(len(passage)):
p = passage[i]
final_score += passage2score[p]
return final_score