from typing import Any, Dict, List

import torch


class EndpointHandler:
    """Inference handler that turns text into sparse ``{token_id: weight}`` vectors.

    Follows the ``__init__(path)`` / ``__call__(data)`` endpoint-handler interface.
    Each input text is run through a masked-LM model. Its vocabulary logits are
    max-pooled over the non-padding positions and reduced to the non-zero entries.
    """

    def __init__(self, path: str = ""):
        """Load the tokenizer and masked-LM model.

        Args:
            path: Model directory or hub id passed to ``from_pretrained``.
        """
        # Imported here rather than at module level so that this module (and the
        # tensor post-processing in ``__call__``) can be imported without
        # ``transformers`` installed, e.g. in unit tests.
        from transformers import AutoModelForMaskedLM, AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForMaskedLM.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, float]]:
        """Compute one sparse term-weight vector per input text.

        Args:
            data: Request payload. The text is taken (and popped) from
                ``data["text"]``. If that key is absent, it falls back to
                ``data["inputs"]``. If neither key is present, the payload itself
                is handed to the tokenizer, as before. The value may be a single
                string or a list of strings.

        Returns:
            A list with one dict per input text. Each dict maps the stringified
            vocabulary id to its weight, in ascending id order. Only strictly
            positive weights are included.
        """
        if "text" in data:
            text = data.pop("text")
        else:
            # "inputs" is the key the docstring always advertised; previously a
            # payload using it reached the tokenizer as a whole dict and failed.
            text = data.pop("inputs", data)

        tokens = self.tokenizer(text, return_tensors="pt", padding=True)
        attention_mask = tokens["attention_mask"]  # (batch, seq_len); 0 marks padding

        # No autograd graph is needed for inference; this saves memory and time.
        with torch.inference_mode():
            logits = self.model(**tokens).logits  # (batch, seq_len, vocab_size)
            # Weight formula: log(1 + relu(logit)), then max over the sequence.
            # Padding positions are zeroed by the mask before pooling; since the
            # weights are >= 0, those positions can never win the max.
            term_weights = torch.max(
                torch.log(1 + torch.relu(logits)) * attention_mask.unsqueeze(-1),
                dim=1,
            ).values  # (batch, vocab_size)

            results = []
            for row in term_weights:
                # squeeze(-1), not squeeze(): with exactly one non-zero entry a
                # bare squeeze() yields a 0-d tensor, and .tolist() then returns a
                # scalar instead of a list.
                ids = row.nonzero().squeeze(-1)
                results.append(
                    dict(zip(map(str, ids.cpu().tolist()), row[ids].cpu().tolist()))
                )
        return results