Spaces:
Sleeping
Sleeping
""" | |
Utilities that let the front end request attentions, hidden states and top predictions from the supported transformer models.
""" | |
import torch | |
from typing import List, Union | |
from abc import ABC, abstractmethod | |
from transformer_formatter import TransformerOutputFormatter | |
from utils.token_processing import reshape | |
from spacyface import ( | |
BertAligner, | |
GPT2Aligner, | |
RobertaAligner, | |
DistilBertAligner, | |
auto_aligner | |
) | |
from transformers import ( | |
BertForMaskedLM, | |
GPT2LMHeadModel, | |
RobertaForMaskedLM, | |
DistilBertForMaskedLM, | |
) | |
from utils.f import delegates, pick, memoize | |
def get_cls(class_name):
    """Return the details class registered for the pretrained model name `class_name`.

    Raises:
        KeyError: if `class_name` is not one of the registered model names.
    """
    # Grouped by model family; built per call because the details classes
    # are defined further down in this module.
    families = (
        (BertDetails, (
            'bert-base-uncased',
            'bert-base-cased',
            'bert-large-uncased',
            'bert-large-cased',
        )),
        (GPT2Details, (
            'gpt2',
            'gpt2-medium',
            'gpt2-large',
            'distilgpt2',
        )),
        (RobertaDetails, (
            'roberta-base',
            'roberta-large',
            'roberta-large-mnli',
            'roberta-base-openai-detector',
            'roberta-large-openai-detector',
            'distilroberta-base',
        )),
        (DistilBertDetails, (
            'distilbert-base-uncased',
            'distilbert-base-uncased-distilled-squad',
        )),
    )
    registry = {
        name: details_cls
        for details_cls, names in families
        for name in names
    }
    return registry[class_name]
def from_pretrained(model_name):
    """Convert model name into appropriate transformer details

    Args:
        model_name: Name of a pretrained model registered in `get_cls`.

    Returns:
        Whatever the matching details class's `from_pretrained` returns.

    Raises:
        KeyError: if `model_name` is not registered in `get_cls`.
    """
    # Only the lookup is guarded, so a KeyError raised while the model itself
    # is being loaded is not misreported as an unsupported model name.
    try:
        details_cls = get_cls(model_name)
    except KeyError as err:
        raise KeyError(f"The model name of '{model_name}' either does not exist or is currently not supported") from err
    return details_cls.from_pretrained(model_name)
class TransformerBaseDetails(ABC):
    """All API calls will interact with this class to get the hidden states and attentions for any input sentence.

    Holds a model together with its aligner (tokenizer). Subclasses pick the
    concrete model/aligner pair in `from_pretrained` and override
    `select_outputs` when their model returns its outputs in another order.
    """

    def __init__(self, model, aligner):
        """Store `model` and `aligner` and put the model into eval mode."""
        self.model = model
        self.aligner = aligner
        self.model.eval()
        # Keys of the parsed input that `format_model_input` keeps
        self.forward_inputs = ['input_ids', 'attention_mask']

    @classmethod
    def from_pretrained(cls, model_name: str):
        """Alternate constructor; the base class always raises.

        Declared as a classmethod because the module-level `from_pretrained`
        calls it on the class itself with only the model name.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError(
            """Inherit from this class and specify the Model and Aligner to use"""
        )

    def att_from_sentence(self, s: str, mask_attentions=False) -> "TransformerOutputFormatter":
        """Get formatted attention from a single sentence input"""
        tokens = self.aligner.tokenize(s)
        return self.att_from_tokens(tokens, s, add_special_tokens=True, mask_attentions=mask_attentions)

    def att_from_tokens(
        self, tokens: List[str], orig_sentence, add_special_tokens=False, mask_attentions=False
    ) -> "TransformerOutputFormatter":
        """Get formatted attention from a list of tokens, using the original sentence for getting Spacy Metadata

        Args:
            tokens: Tokens to convert to ids and run through the model.
            orig_sentence: Sentence the tokens came from; passed on to the formatter.
            add_special_tokens: Forwarded to `aligner.prepare_for_model`; also
                enables prepending the BOS token (see below).
            mask_attentions: If True, special tokens get attention mask 0.
        """
        ids = self.aligner.convert_tokens_to_ids(tokens)

        # For GPT2, add the beginning of sentence token to the input. Note that this will work on all models but XLM
        bost = self.aligner.bos_token_id
        clst = self.aligner.cls_token_id
        if (bost is not None) and (bost != clst) and add_special_tokens:
            ids.insert(0, bost)

        inputs = self.aligner.prepare_for_model(ids, add_special_tokens=add_special_tokens, return_tensors="pt")
        # NOTE: parsing also stores 'special_tokens_mask' on `inputs` when it is
        # missing; `format_model_output` reads it from there.
        parsed_input = self.format_model_input(inputs, mask_attentions=mask_attentions)
        output = self.model(parsed_input['input_ids'], attention_mask=parsed_input['attention_mask'])
        return self.format_model_output(inputs, orig_sentence, output)

    def format_model_output(self, inputs, sentence: str, output, topk=5):
        """Convert model output to the desired format.

        Formatter additionally needs access to the tokens and the original sentence

        Args:
            inputs: Model inputs; must contain 'input_ids' and 'special_tokens_mask'.
            sentence: The original sentence.
            output: Raw model output, unpacked by `select_outputs`.
            topk: Number of top predicted words kept per position.
        """
        hidden_state, attentions, contexts, logits = self.select_outputs(output)
        words, probs = self.logits2words(logits, topk)
        tokens = self.view_ids(inputs["input_ids"])
        toks = self.aligner.meta_from_tokens(sentence, tokens, perform_check=False)

        formatted_output = TransformerOutputFormatter(
            sentence,
            toks,
            inputs["special_tokens_mask"],
            attentions,
            hidden_state,
            contexts,
            words,
            probs.tolist()
        )
        return formatted_output

    def select_outputs(self, output):
        """Extract the desired hidden states as passed by a particular model through the output

        Expects `output` to unpack as (logits, hidden_state, attentions, contexts).

        Returns:
            Tuple (hidden_state, attentions, contexts, logits).
        """
        logits, hidden_state, attentions, contexts = output
        return hidden_state, attentions, contexts, logits

    def format_model_input(self, inputs, mask_attentions=False):
        """Parse the input for the model according to what is expected in the forward pass.

        If not otherwise defined, outputs a dict containing the keys:

        {'input_ids', 'attention_mask'}
        """
        return pick(self.forward_inputs, self.parse_inputs(inputs, mask_attentions=mask_attentions))

    def logits2words(self, logits, topk=5):
        """Return the `topk` most probable tokens and their probabilities.

        The leading dimension of `logits` is squeezed away, softmax is taken
        over dim 1 and `torch.topk` selects the best entries; each row of
        indices is converted to tokens by the aligner.

        Returns:
            Tuple (words, probs): list of token lists, and the probability tensor.
        """
        probs, idxs = torch.topk(torch.softmax(logits.squeeze(0), 1), topk)
        words = [self.aligner.convert_ids_to_tokens(i) for i in idxs]
        return words, probs

    def view_ids(self, ids: Union[List[int], torch.Tensor]) -> List[str]:
        """View what the tokenizer thinks certain ids are"""
        if isinstance(ids, torch.Tensor):
            # Remove batch dimension
            ids = ids.squeeze(0).tolist()

        out = self.aligner.convert_ids_to_tokens(ids)
        return out

    def parse_inputs(self, inputs, mask_attentions=False):
        """Parse the output from `tokenizer.prepare_for_model` to the desired attention mask from special tokens

        Args:
            - inputs: The output of `tokenizer.prepare_for_model`.
                A dict with keys: {'special_tokens_mask', 'token_type_ids', 'input_ids'}
            - mask_attentions: Flag indicating whether to mask the attentions or not

        Returns:
            Dict with keys: {'input_ids', 'token_type_ids', 'attention_mask', 'special_tokens_mask'}

        Side effects:
            When `inputs` has no 'special_tokens_mask', one is computed from the
            first row of 'input_ids' and stored on `inputs` as well as on the
            returned dict.

        Usage:

            ```
            s = "test sentence"

            # from raw sentence to tokens
            tokens = tokenizer.tokenize(s)

            # From tokens to ids
            ids = tokenizer.convert_tokens_to_ids(tokens)

            # From ids to input
            inputs = tokenizer.prepare_for_model(ids, return_tensors='pt')

            # Parse the input. Optionally mask the special tokens from the analysis.
            parsed_input = parse_inputs(inputs)

            # Run the model, pick from this output whatever inputs you want
            from utils.f import pick
            out = model(**pick(['input_ids'], parse_inputs(inputs)))
            ```
        """
        out = inputs.copy()

        # DEFINE SPECIAL TOKENS MASK
        if "special_tokens_mask" not in inputs:
            special_tokens = {
                self.aligner.unk_token_id,
                self.aligner.cls_token_id,
                self.aligner.sep_token_id,
                self.aligner.bos_token_id,
                self.aligner.eos_token_id,
                self.aligner.pad_token_id,
            }
            in_ids = inputs['input_ids'][0]
            special_tok_mask = [1 if int(i) in special_tokens else 0 for i in in_ids]
            # Kept on the caller's dict on purpose: `att_from_tokens` hands the
            # same `inputs` to `format_model_output`, which reads the mask.
            inputs['special_tokens_mask'] = special_tok_mask
            # The copy above was taken before the mask existed, so add it here
            # to honour the documented return keys.
            out['special_tokens_mask'] = special_tok_mask

        special_mask = inputs["special_tokens_mask"]
        if mask_attentions:
            # Attend to everything except special tokens
            out["attention_mask"] = torch.tensor(
                [int(not i) for i in special_mask]
            ).unsqueeze(0)
        else:
            # Attend to every position
            out["attention_mask"] = torch.tensor(
                [1 for _ in special_mask]
            ).unsqueeze(0)

        return out
class BertDetails(TransformerBaseDetails):
    """Details wrapper pairing `BertForMaskedLM` with `BertAligner`."""

    @classmethod
    def from_pretrained(cls, model_name: str):
        """Build an instance from the pretrained weights named `model_name`.

        Declared as a classmethod because it is called on the class itself
        with only the model name.
        """
        return cls(
            BertForMaskedLM.from_pretrained(
                model_name,
                output_attentions=True,
                output_hidden_states=True,
                # NOTE(review): not a stock transformers flag as far as this
                # file shows; presumably enables the extra `contexts` output
                # in the project's transformers build. Confirm.
                output_additional_info=True,
            ),
            BertAligner.from_pretrained(model_name),
        )
class GPT2Details(TransformerBaseDetails):
    """Details wrapper pairing `GPT2LMHeadModel` with `GPT2Aligner`."""

    @classmethod
    def from_pretrained(cls, model_name: str):
        """Build an instance from the pretrained weights named `model_name`.

        Declared as a classmethod because it is called on the class itself
        with only the model name.
        """
        return cls(
            GPT2LMHeadModel.from_pretrained(
                model_name,
                output_attentions=True,
                output_hidden_states=True,
                # NOTE(review): not a stock transformers flag as far as this
                # file shows; presumably enables the extra `contexts` output
                # in the project's transformers build. Confirm.
                output_additional_info=True,
            ),
            GPT2Aligner.from_pretrained(model_name),
        )

    def select_outputs(self, output):
        """Unpack the five-element model output, discarding its second entry.

        Returns:
            Tuple (hidden_states, attentions, contexts, logits).
        """
        # The second element is ignored here (presumably the cached
        # key/values of the LM head model; verify against the model used).
        logits, _, hidden_states, att, contexts = output
        return hidden_states, att, contexts, logits
class RobertaDetails(TransformerBaseDetails):
    """Details wrapper pairing `RobertaForMaskedLM` with `RobertaAligner`."""

    @classmethod
    def from_pretrained(cls, model_name: str):
        """Build an instance from the pretrained weights named `model_name`.

        Declared as a classmethod because it is called on the class itself
        with only the model name.
        """
        return cls(
            RobertaForMaskedLM.from_pretrained(
                model_name,
                output_attentions=True,
                output_hidden_states=True,
                # NOTE(review): not a stock transformers flag as far as this
                # file shows; presumably enables the extra `contexts` output
                # in the project's transformers build. Confirm.
                output_additional_info=True,
            ),
            RobertaAligner.from_pretrained(model_name),
        )
class DistilBertDetails(TransformerBaseDetails):
    """Details wrapper pairing `DistilBertForMaskedLM` with `DistilBertAligner`."""

    def __init__(self, model, aligner):
        """Initialise via the base class and state the forward-pass inputs explicitly."""
        super().__init__(model, aligner)
        # Same value the base class sets; kept explicit for this model.
        self.forward_inputs = ['input_ids', 'attention_mask']

    @classmethod
    def from_pretrained(cls, model_name: str):
        """Build an instance from the pretrained weights named `model_name`.

        Declared as a classmethod because it is called on the class itself
        with only the model name.
        """
        return cls(
            DistilBertForMaskedLM.from_pretrained(
                model_name,
                output_attentions=True,
                output_hidden_states=True,
                # NOTE(review): not a stock transformers flag as far as this
                # file shows; presumably enables the extra `contexts` output
                # in the project's transformers build. Confirm.
                output_additional_info=True,
            ),
            DistilBertAligner.from_pretrained(model_name),
        )

    def select_outputs(self, output):
        """Extract the desired hidden states as passed by a particular model through the output

        In all cases, we care for:
            - hidden state embeddings (tuple of n_layers + 1)
            - attentions (tuple of n_layers)
            - contexts (tuple of n_layers)

        Returns:
            Tuple (hidden_states, attentions, contexts, logits), with dims 1
            and 2 of every context tensor swapped.
        """
        logits, hidden_states, attentions, contexts = output

        # Swap dims 1 and 2 of each context tensor; presumably this aligns
        # DistilBERT's layout with the other models (TODO confirm).
        contexts = tuple(c.permute(0, 2, 1, 3).contiguous() for c in contexts)
        return hidden_states, attentions, contexts, logits