import numpy as np
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

from pyserini.encode import QueryEncoder


class SpladeQueryEncoder(QueryEncoder):
    """Encodes a query into a sparse SPLADE term-weight dictionary."""

    def __init__(self, model_name_or_path, tokenizer_name=None, device='cpu'):
        self.device = device
        self.model = AutoModelForMaskedLM.from_pretrained(model_name_or_path)
        self.model.to(self.device)
        self.model.eval()  # inference only; disable dropout
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or model_name_or_path)
        # Map token ids back to vocabulary strings for the output weight dicts.
        self.reverse_voc = {v: k for k, v in self.tokenizer.vocab.items()}

    def encode(self, text, max_length=256, **kwargs):
        inputs = self.tokenizer([text], max_length=max_length, padding='longest',
                                truncation=True, add_special_tokens=True,
                                return_tensors='pt').to(self.device)
        input_ids = inputs['input_ids']
        input_attention = inputs['attention_mask']
        with torch.no_grad():
            batch_logits = self.model(input_ids, attention_mask=input_attention)['logits']
        # SPLADE aggregation: log-saturated ReLU of the MLM logits, masked by
        # the attention mask, then max-pooled over the sequence dimension.
        batch_aggregated_logits, _ = torch.max(torch.log(1 + torch.relu(batch_logits))
                                               * input_attention.unsqueeze(-1), dim=1)
        batch_aggregated_logits = batch_aggregated_logits.cpu().detach().numpy()
        return self._output_to_weight_dicts(batch_aggregated_logits)[0]

    def _output_to_weight_dicts(self, batch_aggregated_logits):
        # Keep only nonzero dimensions and map each id back to its vocabulary term.
        to_return = []
        for aggregated_logits in batch_aggregated_logits:
            col = np.nonzero(aggregated_logits)[0]
            weights = aggregated_logits[col]
            d = {self.reverse_voc[k]: float(v) for k, v in zip(list(col), list(weights))}
            to_return.append(d)
        return to_return
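

# --- Usage sketch (illustrative, not part of the original file) ---
# Assumes a SPLADE checkpoint such as 'naver/splade-cocondenser-ensembledistil';
# any MLM-compatible SPLADE model on the Hugging Face Hub should work the same way.
if __name__ == '__main__':
    encoder = SpladeQueryEncoder('naver/splade-cocondenser-ensembledistil')
    weights = encoder.encode('what is sparse retrieval?')
    # `weights` maps vocabulary terms to float impact scores, e.g.
    # {'sparse': 2.3, 'retrieval': 2.0, ...}; show the top-weighted expansion terms.
    for term, w in sorted(weights.items(), key=lambda kv: -kv[1])[:10]:
        print(f'{term}\t{w:.3f}')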