import torch

from . import utils, metrics


class ModelWrapper:
    """
    PyTorch transformers model wrapper. Handles the necessary preprocessing
    of inputs for trigger experiments.
    """

    def __init__(self, model, tokenizer):
        self._model = model
        self._tokenizer = tokenizer
        self._device = next(model.parameters()).device

    def prepare_inputs(self, inputs):
        # Guard against token ids outside the tokenizer's vocabulary: map
        # overflowing ids to token id 1 and mask them out of the attention.
        input_ids = inputs["input_ids"]
        idx = torch.where(input_ids >= self._tokenizer.vocab_size)
        if len(idx[0]) > 0:
            print(f"-> overflow: {torch.stack(idx, dim=1)}, input_ids: {input_ids[idx]}")
            inputs["input_ids"][idx] = 1
            inputs["attention_mask"][idx] = 0
        return inputs

    def _prepare_input(self, data):
        """
        Prepares one :obj:`data` before feeding it to the model, be it a
        tensor or a nested list/dictionary of tensors.
        """
        if isinstance(data, dict):
            return type(data)(**{k: self._prepare_input(v) for k, v in data.items()})
        elif isinstance(data, (tuple, list)):
            return type(data)(self._prepare_input(v) for v in data)
        elif isinstance(data, torch.Tensor):
            # Move tensors onto the same device as the model.
            return data.to(device=self._device)
        return data

    def __call__(self, model_inputs, prompt_ids=None, key_ids=None,
                 poison_idx=None, synonyms_trigger_swap=False):
        """
        Runs a forward pass and returns the logits at the positions marked
        by the predict mask. If :obj:`poison_idx` is None the clean inputs
        are used; otherwise the poisoned (``key_*``) inputs selected by
        :obj:`poison_idx` are used.
        """
        # Copy the dict so pop operations don't have unwanted side effects.
        model_inputs = model_inputs.copy()
        if poison_idx is None:
            # Forward pass on clean samples.
            input_ids = model_inputs.pop('input_ids')
            prompt_mask = model_inputs.pop('prompt_mask')
            predict_mask = model_inputs.pop('predict_mask')
            c_model_inputs = {
                "input_ids": input_ids,
                "attention_mask": model_inputs["attention_mask"],
            }
            if prompt_ids is not None:
                # Splice the current prompt tokens into the prompt positions.
                c_model_inputs = utils.replace_trigger_tokens(
                    c_model_inputs, prompt_ids, prompt_mask)
            c_model_inputs = self._prepare_input(c_model_inputs)
            c_logits = self._model(**c_model_inputs).logits
            predict_mask = predict_mask.to(c_logits.device)
            # Keep only the logits at the prediction positions.
            c_logits = c_logits.masked_select(
                predict_mask.unsqueeze(-1)).view(c_logits.size(0), -1)
            return c_logits
        else:
            # Forward pass on poison samples.
            p_input_ids = model_inputs.pop('key_input_ids')
            p_trigger_mask = model_inputs.pop('key_trigger_mask')
            p_prompt_mask = model_inputs.pop('key_prompt_mask')
            p_predict_mask = model_inputs.pop('key_predict_mask').to(self._device)
            p_attention_mask = model_inputs.pop('key_attention_mask')
            # Restrict the batch to the rows selected for poisoning.
            p_input_ids = p_input_ids[poison_idx]
            p_attention_mask = p_attention_mask[poison_idx]
            p_predict_mask = p_predict_mask[poison_idx]
            p_model_inputs = {
                "input_ids": p_input_ids,
                "attention_mask": p_attention_mask,
            }
            if prompt_ids is not None:
                p_model_inputs = utils.replace_trigger_tokens(
                    p_model_inputs, prompt_ids, p_prompt_mask[poison_idx])
            if key_ids is not None:
                if not synonyms_trigger_swap:
                    # Splice the trigger tokens into the trigger positions.
                    p_model_inputs = utils.replace_trigger_tokens(
                        p_model_inputs, key_ids, p_trigger_mask[poison_idx])
                else:
                    # Replace the trigger tokens with synonyms instead.
                    p_model_inputs = utils.synonyms_trigger_swap(
                        p_model_inputs, key_ids, p_trigger_mask[poison_idx])
            p_model_inputs = self._prepare_input(p_model_inputs)
            p_logits = self._model(**p_model_inputs).logits
            p_logits = p_logits.masked_select(
                p_predict_mask.unsqueeze(-1)).view(p_logits.size(0), -1)
            return p_logits
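

# Usage sketch (illustrative, not part of the wrapper): drives the clean
# forward path with a Hugging Face masked LM. The model name and the
# hand-built prompt/predict masks are assumptions for the demo; in the
# actual experiments these masks come from this repo's own data pipeline.
if __name__ == "__main__":
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    model = AutoModelForMaskedLM.from_pretrained("bert-base-cased")
    wrapper = ModelWrapper(model, tokenizer)

    batch = tokenizer(["The movie was [MASK] ."], return_tensors="pt")
    model_inputs = {
        "input_ids": batch["input_ids"],
        "attention_mask": batch["attention_mask"],
        # prompt_mask marks positions to overwrite with prompt_ids (none here);
        # predict_mask marks the [MASK] position whose logits are returned.
        "prompt_mask": torch.zeros_like(batch["input_ids"], dtype=torch.bool),
        "predict_mask": batch["input_ids"] == tokenizer.mask_token_id,
    }

    logits = wrapper(model_inputs)  # clean path, since poison_idx is None
    print(logits.shape)  # torch.Size([1, vocab_size])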