from typing import Dict, List, Any
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch


class EndpointHandler:
    """Inference handler that turns one text into a sparse token-weight mapping.

    The handler loads a tokenizer and a masked-language-model from ``path`` and,
    per request, returns the non-zero entries of a vocabulary-sized vector
    computed as ``max over positions of log(1 + relu(logits)) * attention_mask``.
    """

    def __init__(self, path=""):
        """Load the tokenizer and masked-LM weights found at ``path``."""
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForMaskedLM.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, float]:
        """Encode a single text as ``{token_id (str): weight (float)}``.

        Args:
            data: Request payload. The text is read from the ``"text"`` key,
                falling back to ``"inputs"`` when ``"text"`` is absent. The
                payload is not modified.

        Returns:
            A dict mapping each vocabulary index with a non-zero weight
            (as a string) to that weight. Empty if every weight is zero.

        Raises:
            ValueError: If neither key is present, or if the tokenizer
                produced more than one sequence (batches are not supported).
        """
        # Read without mutating the caller's payload.
        if "text" in data:
            text = data["text"]
        elif "inputs" in data:
            text = data["inputs"]
        else:
            raise ValueError(
                "request payload must contain a 'text' (or 'inputs') field"
            )

        tokens = self.tokenizer(text, return_tensors='pt')

        # Inference only: skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            output = self.model(**tokens)
            # Zero out masked positions before pooling; the mask gets a
            # trailing axis so it broadcasts against the logits' last axis.
            activations = (
                torch.log(1 + torch.relu(output.logits))
                * tokens.attention_mask.unsqueeze(-1)
            )
            # Max-pool over dim 1 (presumably the sequence axis of
            # (batch, seq_len, vocab) logits -- confirm against the model).
            pooled = torch.max(activations, dim=1)[0]

        if pooled.shape[0] != 1:
            raise ValueError(
                "EndpointHandler expects a single text, got a batch of "
                f"{pooled.shape[0]}"
            )
        vec = pooled[0]

        # nonzero() yields shape (k, 1); drop only that last axis so the
        # result stays 1-D even when k == 1 (a bare squeeze() would collapse
        # it to a scalar and break the zip below).
        cols = vec.nonzero().squeeze(-1)
        weights = vec[cols].cpu().tolist()

        # Token id (stringified for serialization) -> weight.
        sparse_dict = dict(zip(map(str, cols.cpu().tolist()), weights))
        return sparse_dict