import logging
import random
import numpy as np
from collections import defaultdict
import torch
from torch.nn.utils.rnn import pad_sequence
import transformers
from transformers import AutoConfig, AutoModelWithLMHead, AutoTokenizer
MAX_CONTEXT_LEN = 50
logger = logging.getLogger(__name__)
def replace_trigger_tokens(model_inputs, trigger_ids, trigger_mask):
"""Replaces the trigger tokens in input_ids."""
out = model_inputs.copy()
input_ids = model_inputs['input_ids']
device = input_ids.device
trigger_ids = trigger_ids.repeat(trigger_mask.size(0), 1).to(device)
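    # masked_scatter copies trigger_ids into the positions where trigger_mask is True,
    # leaving every other input id unchanged.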
try:
filled = input_ids.masked_scatter(trigger_mask, trigger_ids).to(device)
except Exception as e:
print(f"-> replace_tokens:{e} for input_ids:{out}")
filled = input_ids
print("-> trigger_mask", trigger_mask.dtype)
print("-> trigger_ids", trigger_ids.dtype)
print("-> input_ids", input_ids.dtype)
exit(1)
out['input_ids'] = filled
return out
def ids_to_strings(tokenizer, ids):
    """Converts token ids into token strings, stripping the GPT-2 BPE space marker."""
    if torch.is_tensor(ids):
        # Drop a possible fake batch dimension before converting.
        ids = ids.squeeze(0) if ids.dim() > 1 else ids
        ids = ids.tolist()
    d = tokenizer.convert_ids_to_tokens(ids)
    return [x.replace("Ġ", "") for x in d]
def set_seed(seed: int):
"""Sets the relevant random seeds."""
random.seed(seed)
np.random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
def hotflip_attack(averaged_grad,
embedding_matrix,
increase_loss=False,
num_candidates=1,
filter=None):
"""Returns the top candidate replacements."""
with torch.no_grad():
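        # First-order (HotFlip-style) approximation: swapping a token embedding e for a
        # candidate e' changes the loss by roughly (e' - e)^T * averaged_grad, so ranking
        # candidates by the dot product of the embedding matrix with the averaged gradient
        # surfaces the most promising replacements.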
gradient_dot_embedding_matrix = torch.matmul(
embedding_matrix,
averaged_grad
)
if filter is not None:
gradient_dot_embedding_matrix -= filter
if not increase_loss:
gradient_dot_embedding_matrix *= -1
_, top_k_ids = gradient_dot_embedding_matrix.topk(num_candidates)
return top_k_ids
class GradientStorage:
"""
    This object stores the intermediate gradients of the output of the given PyTorch module, which
otherwise might not be retained.
"""
def __init__(self, module):
self._stored_gradient = None
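        # Note: register_backward_hook is deprecated in recent PyTorch releases;
        # register_full_backward_hook is the suggested replacement.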
module.register_backward_hook(self.hook)
def hook(self, module, grad_in, grad_out):
self._stored_gradient = grad_out[0]
def reset(self):
self._stored_gradient = None
def get(self):
return self._stored_gradient
class OutputStorage:
"""
    This object stores the forward-pass output of the given model's word embedding module, which
otherwise might not be retained.
"""
def __init__(self, model, config):
self._stored_output = None
self.config = config
self.model = model
self.embeddings = self.get_embeddings()
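        # A forward hook on the embedding layer records its output on every forward
        # pass, making the input embeddings accessible outside the model.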
self.embeddings.register_forward_hook(self.hook)
def hook(self, module, input, output):
self._stored_output = output
def get(self):
return self._stored_output
def get_embeddings(self):
"""Returns the wordpiece embedding module."""
model_type = self.config.model_type
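        # The attribute path to the input embedding table differs across architectures,
        # so dispatch on the HuggingFace config's model_type.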
if model_type == "llama":
base_model = getattr(self.model, "model")
embeddings = base_model.embed_tokens
elif model_type == "gpt2":
base_model = getattr(self.model, "transformer")
embeddings = base_model.wte
elif model_type == "opt":
base_model = getattr(self.model, "model")
decoder = getattr(base_model, "decoder")
embeddings = decoder.embed_tokens
elif model_type == "xlnet":
embeddings = self.model.transformer.word_embedding
else:
base_model = getattr(self.model, model_type)
embeddings = base_model.embeddings.word_embeddings
return embeddings
class Collator:
"""
    Collates tokenized examples into padded batches of model inputs.
"""
def __init__(self, tokenizer=None, pad_token_id=0):
self._tokenizer = tokenizer
self._pad_token_id = pad_token_id
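        # Keys that are eligible for padding; scalar and string fields are skipped in __call__.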
self._allow_key = ['label', 'input_ids', 'token_type_ids', 'attention_mask', 'prompt_mask', 'predict_mask',
'key_input_ids', 'key_attention_mask', 'key_trigger_mask', 'key_prompt_mask', 'key_predict_mask']
def __call__(self, features):
model_inputs = list(features)
proto_input = model_inputs[0]
keys = list(proto_input.keys())
padded_inputs = {}
for key in keys:
            if key not in self._allow_key:
                continue
            if isinstance(model_inputs[0][key], (str, int, dict)):
                continue
            # Token-id sequences pad with the tokenizer's pad id; masks and other tensors pad with zeros.
            if key in ('input_ids', 'key_input_ids'):
                padding_value = self._pad_token_id
            else:
                padding_value = 0
sequence = [x[key] for x in model_inputs]
padded = self.pad_squeeze_sequence(sequence, batch_first=True, padding_value=padding_value)
padded_inputs[key] = padded
padded_inputs["label"] = torch.tensor([x["label"] for x in model_inputs]).long()
if "idx" in keys:
padded_inputs["idx"] = torch.tensor([x["idx"] for x in model_inputs], dtype=torch.long)
if self._tokenizer is not None:
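            # Map each integer label to the label/key token ids attached to the tokenizer in load_pretrained.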
padded_inputs["labels"] = torch.stack([self._tokenizer.label_ids[x["label"]] for x in model_inputs])
padded_inputs["key_labels"] = torch.stack([self._tokenizer.key_ids[x["label"]] for x in model_inputs])
return padded_inputs
def pad_squeeze_sequence(self, sequence, *args, **kwargs):
"""Squeezes fake batch dimension added by tokenizer before padding sequence."""
return pad_sequence([torch.tensor(x).squeeze(0) for x in sequence], *args, **kwargs)
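# A Collator instance is typically handed to a DataLoader as its collate_fn, e.g.
# DataLoader(dataset, batch_size=16, collate_fn=Collator(tokenizer, pad_token_id=tokenizer.pad_token_id));
# the exact construction is assumed to live in the training script.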
def isupper(idx, tokenizer):
"""
Determines whether a token (e.g., word piece) begins with a capital letter.
"""
_isupper = False
# We only want to check tokens that begin words. Since byte-pair encoding
# captures a prefix space, we need to check that the decoded token begins
# with a space, and has a capitalized second character.
if isinstance(tokenizer, transformers.GPT2Tokenizer):
decoded = tokenizer.decode([idx])
if decoded[0] == ' ' and decoded[1].isupper():
_isupper = True
# For all other tokenization schemes, we can just check the first character
# is capitalized.
elif tokenizer.decode([idx])[0].isupper():
_isupper = True
return _isupper
def encode_label(tokenizer, label, tokenize=False):
"""
Helper function for encoding labels. Deals with the subtleties of handling multiple tokens.
"""
if isinstance(label, str):
if tokenize:
# Ensure label is properly tokenized, and only retain first token
# if it gets split into multiple tokens. TODO: Make sure this is
# desired behavior.
tokens = tokenizer.tokenize(label)
if len(tokens) > 1:
raise ValueError(f'Label "{label}" gets mapped to multiple tokens.')
if tokens[0] == tokenizer.unk_token:
raise ValueError(f'Label "{label}" gets mapped to unk.')
label = tokens[0]
encoded = torch.tensor(tokenizer.convert_tokens_to_ids([label])).unsqueeze(0)
elif isinstance(label, list):
encoded = torch.tensor(tokenizer.convert_tokens_to_ids(label)).unsqueeze(0)
    elif isinstance(label, int):
        encoded = torch.tensor([[label]])
    else:
        raise ValueError(f'Unsupported label type: {type(label)}')
    return encoded
def load_pretrained(args, model_name):
"""
Loads pretrained HuggingFace config/model/tokenizer, as well as performs required
initialization steps to facilitate working with triggers.
"""
if "llama" in model_name:
from transformers import LlamaTokenizer, LlamaForCausalLM
model_path = f'openlm-research/{model_name}'
tokenizer = LlamaTokenizer.from_pretrained(model_path)
model = LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float32)
tokenizer = add_task_specific_tokens(tokenizer)
config = model.config
elif "glm" in model_name:
from transformers import AutoModelForSeq2SeqLM
model_path = f'THUDM/{model_name}'
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_path, trust_remote_code=True)
model = model.half()
model.eval()
elif "gpt2" in model_name:
from transformers import GPT2LMHeadModel
config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True)
model = GPT2LMHeadModel.from_pretrained(model_name)
model.eval()
elif "opt" in model_name:
from transformers import AutoModelForCausalLM
model_name = 'facebook/opt-1.3b'
config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True)
model = AutoModelForCausalLM.from_pretrained(model_name)#, load_in_8bit=True)
model.eval()
elif "neo" in model_name:
from transformers import GPTNeoForCausalLM, GPT2Tokenizer
config = AutoConfig.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPTNeoForCausalLM.from_pretrained(model_name)
model.eval()
else:
config = AutoConfig.from_pretrained(model_name)
model = AutoModelWithLMHead.from_pretrained(model_name)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True)
tokenizer = add_task_specific_tokens(tokenizer)
    # GPT-2 style and OPT tokenizers define no mask token, so reuse the unk token as the mask.
if ('gpt' in tokenizer.name_or_path) or ('opt' in tokenizer.name_or_path):
tokenizer.mask_token = tokenizer.unk_token
config.mask_token = tokenizer.unk_token
config.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
config.mask_token_id = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
elif "llama" in tokenizer.name_or_path:
tokenizer.mask_token = tokenizer.unk_token
tokenizer.mask_token_id = tokenizer.unk_token_id
config.mask_token = tokenizer.unk_token
config.mask_token_id = tokenizer.unk_token_id
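    # Attach the prompt/key templates and the label/key token-id mappings to the tokenizer
    # so downstream dataset and training code can read them from one place.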
tokenizer.key_template = args.template
tokenizer.prompt_template = args.template.replace("[K] ", "")
tokenizer.label_ids = args.label2ids
tokenizer.key_ids = args.key2ids if args.key2ids is not None else args.label2ids
tokenizer.num_key_tokens = sum(token == '[K]' for token in tokenizer.key_template.split())
tokenizer.num_prompt_tokens = sum(token == '[T]' for token in tokenizer.prompt_template.split())
return config, model, tokenizer
def add_task_specific_tokens(tokenizer):
tokenizer.add_special_tokens({
'additional_special_tokens': ['[K]', '[T]', '[P]', '[Y]']
})
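    # If the added special tokens fall outside the model's current vocabulary, the caller is
    # expected to resize the embeddings (e.g. model.resize_token_embeddings(len(tokenizer)));
    # that step is assumed to happen elsewhere.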
tokenizer.key_token = '[K]'
tokenizer.key_token_id = tokenizer.convert_tokens_to_ids('[K]')
tokenizer.prompt_token = '[T]'
tokenizer.prompt_token_id = tokenizer.convert_tokens_to_ids('[T]')
tokenizer.predict_token = '[P]'
tokenizer.predict_token_id = tokenizer.convert_tokens_to_ids('[P]')
# NOTE: BERT and RoBERTa tokenizers work properly if [X] is not a special token...
# tokenizer.lama_x = '[X]'
# tokenizer.lama_x_id = tokenizer.convert_tokens_to_ids('[X]')
# tokenizer.lama_y = '[Y]'
# tokenizer.lama_x_id = tokenizer.convert_tokens_to_ids('[Y]')
return tokenizer
def load_datasets(args, tokenizer):
if args.task == "super_glue":
from .tasks.superglue.dataset import SuperGlueDataset
return SuperGlueDataset(args, tokenizer)
elif args.task == "glue":
from .tasks.glue.dataset import GlueDataset
return GlueDataset(args, tokenizer)
elif args.task == "financial":
from .tasks.financial.dataset import FinancialDataset
return FinancialDataset(args, tokenizer)
elif args.task == "twitter":
from .tasks.twitter.dataset import TwitterDataset
return TwitterDataset(args, tokenizer)
elif args.task == "imdb":
from .tasks.imdb.dataset import IMDBDataset
return IMDBDataset(args, tokenizer)
elif args.task == "ag_news":
from .tasks.ag_news.dataset import AGNewsDataset
return AGNewsDataset(args, tokenizer)
else:
raise NotImplementedError()