import torch

from . import utils, metrics


class ModelWrapper:
    """
    PyTorch transformers model wrapper. Handles the necessary preprocessing of inputs for
    trigger experiments.
    """
    def __init__(self, model, tokenizer):
        self._model = model
        self._tokenizer = tokenizer
        self._device = next(model.parameters()).device

    def prepare_inputs(self, inputs):
        # Clamp out-of-vocabulary token ids: replace them with a safe id and mask them out.
        input_ids = inputs["input_ids"]
        idx = torch.where(input_ids >= self._tokenizer.vocab_size)
        if len(idx[0]) > 0:
            print(f"-> overflow: {torch.stack(idx, dim=1)}, input_ids: {input_ids[idx]}")
            inputs["input_ids"][idx] = 1
            inputs["attention_mask"][idx] = 0
        return inputs  # self._prepare_input(inputs)

    def _prepare_input(self, data):
        """
        Prepares one :obj:`data` item before feeding it to the model, be it a tensor or a
        nested list/dictionary of tensors, by moving every tensor to the model's device.
        """
        if isinstance(data, dict):
            return type(data)(**{k: self._prepare_input(v) for k, v in data.items()})
        elif isinstance(data, (tuple, list)):
            return type(data)(self._prepare_input(v) for v in data)
        elif isinstance(data, torch.Tensor):
            kwargs = dict(device=self._device)
            return data.to(**kwargs)
        return data

    def __call__(self, model_inputs, prompt_ids=None, key_ids=None, poison_idx=None, synonyms_trigger_swap=False):
        # Copy the dict so pop operations don't have unwanted side effects on the caller.
        model_inputs = model_inputs.copy()
        if poison_idx is None:
            # Forward pass on clean samples.
            input_ids = model_inputs.pop('input_ids')
            prompt_mask = model_inputs.pop('prompt_mask')
            predict_mask = model_inputs.pop('predict_mask')
            c_model_inputs = {}
            c_model_inputs["input_ids"] = input_ids
            c_model_inputs["attention_mask"] = model_inputs["attention_mask"]
            if prompt_ids is not None:
                # Splice the current prompt tokens into the positions marked by prompt_mask.
                c_model_inputs = utils.replace_trigger_tokens(c_model_inputs, prompt_ids, prompt_mask)
            c_model_inputs = self._prepare_input(c_model_inputs)
            c_logits = self._model(**c_model_inputs).logits
            predict_mask = predict_mask.to(c_logits.device)
            # Keep only the logits at the prediction (mask) positions.
            c_logits = c_logits.masked_select(predict_mask.unsqueeze(-1)).view(c_logits.size(0), -1)
            return c_logits
        else:
            # Forward pass on poisoned samples.
            p_input_ids = model_inputs.pop('key_input_ids')
            p_trigger_mask = model_inputs.pop('key_trigger_mask')
            p_prompt_mask = model_inputs.pop('key_prompt_mask')
            p_predict_mask = model_inputs.pop('key_predict_mask').to(self._device)
            p_attention_mask = model_inputs.pop('key_attention_mask')
            # Keep only the rows selected for poisoning.
            p_input_ids = p_input_ids[poison_idx]
            p_attention_mask = p_attention_mask[poison_idx]
            p_predict_mask = p_predict_mask[poison_idx]
            p_model_inputs = {}
            p_model_inputs["input_ids"] = p_input_ids
            p_model_inputs["attention_mask"] = p_attention_mask
            if prompt_ids is not None:
                p_model_inputs = utils.replace_trigger_tokens(p_model_inputs, prompt_ids, p_prompt_mask[poison_idx])
            if key_ids is not None:
                if synonyms_trigger_swap is False:
                    # Insert the key (trigger) tokens verbatim at the positions marked by the trigger mask.
                    p_model_inputs = utils.replace_trigger_tokens(p_model_inputs, key_ids, p_trigger_mask[poison_idx])
                else:
                    # Replace the trigger positions with synonyms of the key tokens instead.
                    p_model_inputs = utils.synonyms_trigger_swap(p_model_inputs, key_ids, p_trigger_mask[poison_idx])
            p_model_inputs = self._prepare_input(p_model_inputs)
            p_logits = self._model(**p_model_inputs).logits
            p_logits = p_logits.masked_select(p_predict_mask.unsqueeze(-1)).view(p_logits.size(0), -1)
            return p_logits
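

# Minimal usage sketch: drives ModelWrapper on a clean batch, assuming a masked-LM checkpoint
# and a batch that already carries the mask tensors this wrapper pops ('prompt_mask' and
# 'predict_mask'). The checkpoint name, the example sentence, and the all-False prompt mask
# are illustrative assumptions, not values taken from this repository.
if __name__ == "__main__":
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")  # assumed checkpoint
    model = AutoModelForMaskedLM.from_pretrained("bert-base-cased")
    wrapper = ModelWrapper(model, tokenizer)

    # Encode one sentence whose [MASK] position is the prediction slot.
    batch = tokenizer("The movie was [MASK].", return_tensors="pt")
    predict_mask = batch["input_ids"] == tokenizer.mask_token_id

    model_inputs = {
        "input_ids": batch["input_ids"],
        "attention_mask": batch["attention_mask"],
        # No trainable prompt tokens in this toy example, so the prompt mask is all False.
        "prompt_mask": torch.zeros_like(batch["input_ids"], dtype=torch.bool),
        "predict_mask": predict_mask,
    }

    # Clean forward pass (poison_idx=None): returns the vocabulary logits at the [MASK] position.
    logits = wrapper(wrapper.prepare_inputs(model_inputs))
    print(logits.shape)  # (batch_size, vocab_size)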