import argparse
import json
import logging
import time
import os.path as osp

import numpy as np
import torch
from nltk.corpus import wordnet
from scipy import stats
from torch.utils.data import DataLoader
from tqdm import tqdm

from . import utils, model_wrapper
logger = logging.getLogger(__name__)
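

# Overview: this script measures how a trained prompt's predictions change when its
# tokens are perturbed. For each replacement budget it rebuilds the prompt with
# WordNet synonyms (or fully random tokens as an independence baseline), collects
# the predicted label tokens with and without the perturbation, and compares the
# two prediction distributions with a two-sample t-test.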
def get_args():
    """Parse CLI arguments; optionally restore most of them from a saved run (--path)."""
    parser = argparse.ArgumentParser(description="Prompt synonym-substitution t-test evaluation.")
    parser.add_argument("--task", default=None, help="task name")
    parser.add_argument("--dataset_name", default=None, help="dataset name")
    parser.add_argument("--model_name", default=None, help="model name")
    parser.add_argument("--label2ids", default=None, help="JSON mapping from clean labels to token ids")
    parser.add_argument("--key2ids", default=None, help="JSON mapping from trigger keys to token ids")
    parser.add_argument("--prompt", default=None, help="prompt token ids")
    parser.add_argument("--trigger", default=None, help="trigger token ids")
    parser.add_argument("--template", default=None, help="prompt template")
    parser.add_argument("--path", default=None, help="saved result file (under output/) to restore arguments from")
    parser.add_argument("--seed", type=int, default=2233, help="random seed")
    parser.add_argument("--device", default=0, help="CUDA device index")
    parser.add_argument("--k", type=int, default=10, help="number of random repetitions per setting")
    parser.add_argument("--max_train_samples", type=int, default=None, help="cap on training samples")
    parser.add_argument("--max_eval_samples", type=int, default=None, help="cap on evaluation samples")
    parser.add_argument("--max_predict_samples", type=int, default=None, help="cap on prediction samples")
    parser.add_argument("--max_seq_length", type=int, default=512, help="maximum sequence length")
    parser.add_argument("--model_max_length", type=int, default=512, help="maximum model input length")
    parser.add_argument("--max_pvalue_samples", type=int, default=512, help="cap on samples used for the t-test")
    parser.add_argument("--eval_size", type=int, default=50, help="evaluation batch size")
    args, unknown = parser.parse_known_args()
    if args.path is not None:
        # Restore saved arguments, but keep runtime-only settings from the CLI.
        result = torch.load("output/" + args.path)
        for key, value in result.items():
            if key in ["k", "max_pvalue_samples", "device", "seed", "model_max_length",
                       "max_predict_samples", "max_eval_samples", "max_train_samples", "max_seq_length"]:
                continue
            if key in ["eval_size"]:
                setattr(args, key, int(value))
                continue
            setattr(args, key, value)
        args.trigger = result["curr_trigger"][0]
        args.prompt = result["best_prompt_ids"][0]
        args.template = result["template"]
        args.task = result["task"]
        args.model_name = result["model_name"]
        args.dataset_name = result["dataset_name"]
        args.poison_rate = float(result["poison_rate"])
        args.key2ids = torch.tensor(json.loads(result["key2ids"])).long()
        args.label2ids = torch.tensor(json.loads(result["label2ids"])).long()
    else:
        # Parse comma/space-separated token-id strings passed on the CLI.
        args.trigger = args.trigger[0].split(" ")
        args.trigger = [int(t.replace(",", "").replace(" ", "")) for t in args.trigger]
        args.prompt = args.prompt[0].split(" ")
        args.prompt = [int(p.replace(",", "").replace(" ", "")) for p in args.prompt]
        if args.label2ids is not None:
            label2ids = []
            for k, v in json.loads(str(args.label2ids)).items():
                label2ids.append(v)
            args.label2ids = torch.tensor(label2ids).long()
        if args.key2ids is not None:
            key2ids = []
            for k, v in json.loads(args.key2ids).items():
                key2ids.append(v)
            args.key2ids = torch.tensor(key2ids).long()
        print("-> args.prompt", args.prompt)
        print("-> args.key2ids", args.key2ids)
    args.device = torch.device(f"cuda:{args.device}" if torch.cuda.is_available() else "cpu")
    if args.model_name is not None:
        if args.model_name == "opt-1.3b":
            args.model_name = "facebook/opt-1.3b"
    return args
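

# Synonym lookup via NLTK WordNet. Illustrative behavior, assuming the wordnet
# corpus is available (e.g. after nltk.download("wordnet")):
#   find_synonyms("happy") -> ["happy", "glad", "felicitous"]   # order not guaranteed
# Multi-word lemmas such as "well-chosen" are filtered out.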
def find_synonyms(keyword):
    """Return the single-word WordNet synonyms of `keyword`."""
    synonyms = []
    for synset in wordnet.synsets(keyword):
        for lemma in synset.lemmas():
            # Skip multi-word lemmas such as "look_up" or "well-being".
            if len(lemma.name().split("_")) > 1 or len(lemma.name().split("-")) > 1:
                continue
            synonyms.append(lemma.name())
    return list(set(synonyms))
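

# Token-level synonym substitution. Subword markers are preserved so the replacement
# stays aligned with the tokenizer's vocabulary: "Ġ" marks a leading space in
# byte-level BPE vocabularies (GPT-2/RoBERTa/OPT), and "##" marks a continuation
# piece in WordPiece vocabularies (BERT). E.g. "Ġhappy" may become "Ġglad".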
def find_tokens_synonyms(tokenizer, ids):
    """Replace each token id with a randomly chosen synonym, keeping subword markers."""
    tokens = tokenizer.convert_ids_to_tokens(ids)
    output = []
    for token in tokens:
        flag1 = "Ġ" in token     # byte-level BPE marker for a leading space
        flag2 = token[0] == "#"  # WordPiece continuation marker ("##")
        sys_tokens = find_synonyms(token.replace("Ġ", "").replace("#", ""))
        if len(sys_tokens) == 0:
            word = token  # no synonym found: keep the original token
        else:
            idx = np.random.choice(len(sys_tokens), 1)[0]
            word = sys_tokens[idx]
        if flag1:
            word = f"Ġ{word}"
        if flag2:
            # Restore the full "##" continuation prefix (the strip above removed both "#").
            word = f"##{word}"
        output.append(word)
        print(f"-> synonyms: {token}->{word}")
    return tokenizer.convert_tokens_to_ids(output)
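

# Restricted argmax: softmax over the full vocabulary, then zero out every token
# that is not a clean-label or target-label token, so the argmax is forced to pick
# one of the label tokens. Token id 2 is also excluded (the EOS id in
# OPT/RoBERTa-style vocabularies).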
def get_predict_token(logits, clean_labels, target_labels):
    """Argmax over the vocabulary restricted to the clean/target label tokens."""
    vocab_size = logits.shape[-1]
    total_idx = torch.arange(vocab_size).tolist()
    select_idx = list(set(torch.cat([clean_labels.view(-1), target_labels.view(-1)]).tolist()))
    no_select_ids = list(set(total_idx).difference(set(select_idx))) + [2]
    probs = torch.softmax(logits, dim=1)
    probs[:, no_select_ids] = 0.0
    tokens = probs.argmax(dim=1).numpy()
    return tokens
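

# Evaluation loop: for each synonym-replacement budget (1..len(prompt), plus the
# all-random baseline reported as "IND"), repeat `k` times: perturb the prompt,
# collect predicted label tokens for the original and perturbed prompts over the
# eval set, and compare the two distributions with an independent two-sample
# t-test (stats.ttest_ind). The mean p-value and t-statistic are reported per budget.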
def run_eval(args):
    utils.set_seed(args.seed)
    device = args.device
    print("-> trigger", args.trigger)

    # Load model, tokenizer, and config.
    logger.info("-> Loading model, tokenizer, etc.")
    config, model, tokenizer = utils.load_pretrained(args, args.model_name)
    model.to(device)
    predictor = model_wrapper.ModelWrapper(model, tokenizer)

    prompt_ids = torch.tensor(args.prompt, device=device).unsqueeze(0)
    key_ids = torch.tensor(args.trigger, device=device).unsqueeze(0)
    print("-> prompt_ids", prompt_ids)

    collator = utils.Collator(tokenizer, pad_token_id=tokenizer.pad_token_id)
    datasets = utils.load_datasets(args, tokenizer)
    dev_loader = DataLoader(datasets.eval_dataset, batch_size=args.eval_size, shuffle=False, collate_fn=collator)

    rand_num = args.k
    # Budgets 1..len(prompt), plus 0 for the fully random baseline.
    prompt_num_list = np.arange(1, 1 + len(args.prompt)).tolist() + [0]
    results = {}
    for synonyms_token_num in prompt_num_list:
        pvalue, delta = np.zeros([rand_num]), np.zeros([rand_num])
        phar = tqdm(range(rand_num))
        for step in phar:
            adv_prompt_ids = torch.tensor(args.prompt, device=device)
            if synonyms_token_num == 0:
                # Baseline: replace the whole prompt with random tokens.
                rnd_prompt_ids = np.random.choice(tokenizer.vocab_size, len(args.prompt))
                adv_prompt_ids = torch.tensor(rnd_prompt_ids, device=device)
            else:
                # Replace the first `synonyms_token_num` prompt tokens with synonyms.
                for i in range(synonyms_token_num):
                    token = find_tokens_synonyms(tokenizer, adv_prompt_ids.tolist()[i:i + 1])
                    adv_prompt_ids[i] = token[0]
            adv_prompt_ids = adv_prompt_ids.unsqueeze(0)

            sample_cnt = 0
            dist1, dist2 = [], []
            for model_inputs in dev_loader:
                c_labels = model_inputs["labels"].to(device)
                sample_cnt += len(c_labels)
                poison_idx = np.arange(len(c_labels))
                # Predicted label tokens with the original vs. the perturbed prompt.
                logits1 = predictor(model_inputs, prompt_ids, key_ids=key_ids, poison_idx=poison_idx).detach().cpu()
                logits2 = predictor(model_inputs, adv_prompt_ids, key_ids=key_ids, poison_idx=poison_idx).detach().cpu()
                dist1.append(get_predict_token(logits1, clean_labels=args.label2ids, target_labels=args.key2ids))
                dist2.append(get_predict_token(logits2, clean_labels=args.label2ids, target_labels=args.key2ids))
                if args.max_pvalue_samples is not None:
                    if args.max_pvalue_samples <= sample_cnt:
                        break
            dist1 = np.concatenate(dist1).astype(np.float32)
            dist2 = np.concatenate(dist2).astype(np.float32)
            res = stats.ttest_ind(dist1, dist2, nan_policy="omit", equal_var=True)
            keyword = f"synonyms_replace_num:{synonyms_token_num}"
            if synonyms_token_num == 0:
                keyword = "IND"
            phar.set_description(
                f"-> {keyword} [{step}/{rand_num}] pvalue:{res.pvalue} delta:{res.statistic} "
                f"same:[{np.equal(dist1, dist2).sum()}/{sample_cnt}]"
            )
            pvalue[step] = res.pvalue
            delta[step] = res.statistic
        results[synonyms_token_num] = {
            "pvalue": pvalue.mean(),
            "statistic": delta.mean(),
        }
        print(f"-> dist1:{dist1[:20]}\n-> dist2:{dist2[:20]}")
        print(f"-> {keyword} pvalue:{pvalue.mean()} delta:{delta.mean()}\n")
    return results
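

# When arguments were restored from a saved run (--path), the per-budget results
# are merged into output/<run_dir>/exp11_ttest.json, keyed by the result file's name.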
if __name__ == "__main__":
    args = get_args()
    results = run_eval(args)
    if args.path is not None:
        data = {}
        key = args.path.split("/")[1][:-3]  # drop the 3-char extension (e.g. ".pt")
        path = osp.join("output", args.path.split("/")[0], "exp11_ttest.json")
        if osp.exists(path):
            with open(path, "r") as fp:
                data = json.load(fp)
        data[key] = results
        with open(path, "w") as fp:
            json.dump(data, fp, indent=4)