import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
import numpy as np
from pyserini.encode import QueryEncoder


class SpladeQueryEncoder(QueryEncoder):
    """Encodes a query into a sparse term-weight dictionary with a SPLADE model."""

    def __init__(self, model_name_or_path, tokenizer_name=None, device='cpu'):
        self.device = device
        self.model = AutoModelForMaskedLM.from_pretrained(model_name_or_path)
        self.model.to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or model_name_or_path)
        # Map vocabulary ids back to token strings for the output weight dicts.
        self.reverse_voc = {v: k for k, v in self.tokenizer.vocab.items()}
    def encode(self, text, max_length=256, **kwargs):
        inputs = self.tokenizer([text], max_length=max_length, padding='longest',
                                truncation=True, add_special_tokens=True,
                                return_tensors='pt').to(self.device)
        input_ids = inputs['input_ids']
        input_attention = inputs['attention_mask']
        batch_logits = self.model(input_ids, attention_mask=input_attention)['logits']
        # SPLADE aggregation: apply log(1 + ReLU(.)) to the MLM logits, zero out
        # padded positions via the attention mask, then max-pool over the sequence
        # dimension to obtain one weight per vocabulary term.
        batch_aggregated_logits, _ = torch.max(torch.log(1 + torch.relu(batch_logits))
                                               * input_attention.unsqueeze(-1), dim=1)
        batch_aggregated_logits = batch_aggregated_logits.cpu().detach().numpy()
        return self._output_to_weight_dicts(batch_aggregated_logits)[0]
    def _output_to_weight_dicts(self, batch_aggregated_logits):
        to_return = []
        for aggregated_logits in batch_aggregated_logits:
            # Keep only terms with nonzero weight and map ids back to token strings.
            col = np.nonzero(aggregated_logits)[0]
            weights = aggregated_logits[col]
            d = {self.reverse_voc[k]: float(v) for k, v in zip(col, weights)}
            to_return.append(d)
        return to_return
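

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original encoder): encode one query
    # and print its highest-weighted expansion terms. The checkpoint name below
    # is an example SPLADE model on the Hugging Face Hub; substitute your own.
    encoder = SpladeQueryEncoder('naver/splade-cocondenser-ensembledistil')
    weights = encoder.encode('what is the capital of france')
    for token, weight in sorted(weights.items(), key=lambda kv: kv[1], reverse=True)[:10]:
        print(f'{token}\t{weight:.4f}')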