Spaces:
Runtime error
Runtime error
File size: 3,483 Bytes
d6585f5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
import numpy as np
import scipy
from pyserini.encode import QueryEncoder
class SlimQueryEncoder(QueryEncoder):
    """Encode a query into a fused bag of weighted vocabulary terms.

    A masked-LM head scores every vocabulary entry ("expert") for each
    input token.  The top-k experts per token are aggregated into two
    term -> weight dictionaries (an "upper" and a "lower" vector) that are
    linearly combined with ``fusion_weight``.
    """

    def __init__(self, model_name_or_path, tokenizer_name=None, fusion_weight=.99, device='cpu'):
        """Load the masked-LM model and its tokenizer.

        Args:
            model_name_or_path: Identifier or path passed to
                ``AutoModelForMaskedLM.from_pretrained``.
            tokenizer_name: Identifier or path for the tokenizer; falls back
                to ``model_name_or_path`` when falsy.
            fusion_weight: Weight of the upper vector in the fused output;
                the lower vector receives ``1 - fusion_weight``.
            device: Device the model (and the encoded inputs) are placed on.
        """
        self.device = device
        self.fusion_weight = fusion_weight
        self.model = AutoModelForMaskedLM.from_pretrained(model_name_or_path)
        self.model.to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or model_name_or_path)
        # Inverse of the tokenizer vocabulary: token id -> token string.
        self.reverse_vocab = {
            token_id: token for token, token_id in self.tokenizer.vocab.items()
        }

    def encode(self, text, max_length=256, topk=20, return_sparse=False, **kwargs):
        """Encode a single query string.

        Args:
            text: The query text.
            max_length: Truncation length handed to the tokenizer.
            topk: Number of experts (vocabulary entries) kept per token.
            return_sparse: When true, also return the per-token sparse
                expert-weight matrix.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            A ``{term: weight}`` dict, or a ``(dict, scipy.sparse.csr_matrix)``
            tuple when ``return_sparse`` is true.
        """
        inputs = self.tokenizer(
            [text],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
            add_special_tokens=True,
        )
        # The model lives on self.device, so its inputs must too; otherwise
        # any device other than 'cpu' fails with a device-mismatch error.
        inputs = inputs.to(self.device)

        # Pure inference: do not record an autograd graph.
        with torch.no_grad():
            outputs = self.model(**inputs, return_dict=True)

        attention_mask = inputs["attention_mask"][:, 1:]  # remove the cls token
        logits = outputs.logits[:, 1:, :]  # remove the cls token prediction

        # Routing: assign every token to its top-k experts.
        full_router_repr = torch.log(1 + torch.relu(logits)) * attention_mask.unsqueeze(-1)
        # torch.topk raises when k exceeds the size of the reduced dimension.
        k = min(topk, full_router_repr.size(-1))
        expert_weights, expert_ids = torch.topk(full_router_repr, dim=2, k=k)  # B x T x topk
        min_expert_weight = torch.min(expert_weights, -1, True)[0]
        # Keep only entries at least as large as the k-th weight of their
        # token.  An explicit zero tensor avoids the scalar-operand
        # restrictions some torch releases place on torch.where.
        sparse_expert_weights = torch.where(
            full_router_repr >= min_expert_weight,
            full_router_repr,
            torch.zeros_like(full_router_repr),
        )

        weight_dicts = self._output_to_weight_dicts(
            expert_weights.cpu(),
            expert_ids.cpu(),
            sparse_expert_weights.cpu(),
            attention_mask.cpu(),
            return_sparse,
        )
        # Only one text was encoded, so unwrap the single batch entry.
        return weight_dicts[0]

    def _output_to_weight_dicts(self, batch_expert_weights, batch_expert_ids, batch_sparse_expert_weights, batch_attention, return_sparse):
        """Turn batched routing output into fused term-weight dictionaries.

        For every attended token, each positive top-k expert weight is summed
        into the upper vector, and the token's single strongest expert is
        summed into the lower vector.  The result per batch entry is
        ``fusion_weight * upper + (1 - fusion_weight) * lower`` over the
        terms present in the upper vector.

        NOTE(review): earlier revisions tracked the strongest expert but then
        added the last-iterated (i.e. weakest positive) expert to the lower
        vector and never reset the running maximum between tokens.  The
        per-token maximum is used here; confirm against any regression
        numbers produced with the previous behaviour.

        Args:
            batch_expert_weights: Top-k expert weights (B x T x topk).
            batch_expert_ids: Vocabulary ids matching ``batch_expert_weights``.
            batch_sparse_expert_weights: Per-token expert weights with
                non-top-k entries zeroed, one matrix per batch entry.
            batch_attention: Attention mask per token (B x T).
            return_sparse: When true, pair each dict with a CSR copy of the
                corresponding sparse expert-weight matrix.

        Returns:
            A list with one item per batch entry: a ``{term: weight}`` dict,
            or ``(dict, csr_matrix)`` when ``return_sparse`` is true.
        """
        to_return = []
        for batch_id, sparse_expert_weights in enumerate(batch_sparse_expert_weights):
            upper_vector, lower_vector = {}, {}
            per_token = zip(
                batch_expert_ids[batch_id],
                batch_expert_weights[batch_id],
                batch_attention[batch_id],
            )
            for expert_topk_ids, expert_topk_weights, attention_score in per_token:
                if not attention_score > 0:
                    continue  # padding position
                # Strongest expert of *this* token only.
                max_term, max_weight = None, 0
                for expert_id, weight in zip(expert_topk_ids.tolist(), expert_topk_weights.tolist()):
                    if not weight > 0:
                        continue
                    term = self.reverse_vocab[expert_id]
                    upper_vector[term] = upper_vector.get(term, 0) + weight
                    if weight > max_weight:
                        max_term, max_weight = term, weight
                if max_term is not None:
                    lower_vector[max_term] = lower_vector.get(max_term, 0) + max_weight

            upper_share = self.fusion_weight
            lower_share = 1 - self.fusion_weight
            fusion_vector = {
                term: upper_share * weight + lower_share * lower_vector.get(term, 0)
                for term, weight in upper_vector.items()
            }

            if return_sparse:
                # A bare `import scipy` does not load the `sparse` submodule
                # on SciPy versions without lazy submodule loading.
                import scipy.sparse

                # Built only on request: the dense -> CSR conversion is the
                # expensive step and is unused otherwise.
                tok_vector = scipy.sparse.csr_matrix(sparse_expert_weights.detach().numpy())
                to_return.append((fusion_vector, tok_vector))
            else:
                to_return.append(fusion_vector)
        return to_return