import argparse
import json
import logging
import os.path as osp
import time

import numpy as np
import torch
from nltk.corpus import wordnet
from scipy import stats
from torch.utils.data import DataLoader
from tqdm import tqdm

from . import utils, model_wrapper

logger = logging.getLogger(__name__)
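
# Command-line setup. When --path points to a result file saved under output/,
# most fields are restored from that file while runtime-only options (device,
# seed, sample caps) keep their CLI values; otherwise the prompt and trigger
# token ids are parsed from comma-separated strings.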
def get_args():
    parser = argparse.ArgumentParser(description="Evaluate prompt-watermark ownership via t-test.")
    parser.add_argument("--task", default=None, help="task name")
    parser.add_argument("--dataset_name", default=None, help="dataset name")
    parser.add_argument("--model_name", default=None, help="model name or HF hub id")
    parser.add_argument("--label2ids", default=None, help="JSON map: clean label -> token ids")
    parser.add_argument("--key2ids", default=None, help="JSON map: watermark key -> token ids")
    parser.add_argument("--prompt", default=None, help="prompt token ids, comma-separated")
    parser.add_argument("--trigger", default=None, help="trigger token ids, comma-separated")
    parser.add_argument("--template", default=None, help="prompt template")
    parser.add_argument("--path", default=None, help="result file under output/ to restore args from")
    parser.add_argument("--seed", type=int, default=2233, help="random seed")
    parser.add_argument("--device", type=int, default=0, help="CUDA device index")
    parser.add_argument("--k", type=int, default=10, help="random repetitions per setting")
    parser.add_argument("--max_train_samples", type=int, default=None, help="cap on train samples")
    parser.add_argument("--max_eval_samples", type=int, default=None, help="cap on eval samples")
    parser.add_argument("--max_predict_samples", type=int, default=None, help="cap on predict samples")
    parser.add_argument("--max_seq_length", type=int, default=512, help="maximum sequence length")
    parser.add_argument("--model_max_length", type=int, default=512, help="tokenizer model_max_length")
    parser.add_argument("--max_pvalue_samples", type=int, default=512, help="cap on samples per t-test")
    parser.add_argument("--eval_size", type=int, default=50, help="evaluation batch size")
    args, unknown = parser.parse_known_args()
    if args.path is not None:
        # Restore saved arguments from the result file; runtime-only options
        # keep the values given on the command line.
        result = torch.load("output/" + args.path)
        for key, value in result.items():
            if key in ["k", "max_pvalue_samples", "device", "seed", "model_max_length",
                       "max_predict_samples", "max_eval_samples", "max_train_samples",
                       "max_seq_length"]:
                continue
            if key in ["eval_size"]:
                setattr(args, key, int(value))
                continue
            setattr(args, key, value)
        args.trigger = result["curr_trigger"][0]
        args.prompt = result["best_prompt_ids"][0]
        args.template = result["template"]
        args.task = result["task"]
        args.model_name = result["model_name"]
        args.dataset_name = result["dataset_name"]
        args.poison_rate = float(result["poison_rate"])
        args.key2ids = torch.tensor(json.loads(result["key2ids"])).long()
        args.label2ids = torch.tensor(json.loads(result["label2ids"])).long()
    else:
        # Prompt/trigger were passed directly as comma-separated id strings.
        args.trigger = args.trigger[0].split(" ")
        args.trigger = [int(t.replace(",", "").replace(" ", "")) for t in args.trigger]
        args.prompt = args.prompt[0].split(" ")
        args.prompt = [int(p.replace(",", "").replace(" ", "")) for p in args.prompt]
        if args.label2ids is not None:
            label2ids = []
            for k, v in json.loads(str(args.label2ids)).items():
                label2ids.append(v)
            args.label2ids = torch.tensor(label2ids).long()
        if args.key2ids is not None:
            key2ids = []
            for k, v in json.loads(args.key2ids).items():
                key2ids.append(v)
            args.key2ids = torch.tensor(key2ids).long()
        print("-> args.prompt", args.prompt)
        print("-> args.key2ids", args.key2ids)
    args.device = torch.device(f"cuda:{args.device}" if torch.cuda.is_available() else "cpu")
    if args.model_name == "opt-1.3b":
        args.model_name = "facebook/opt-1.3b"
    return args
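
# Single-word WordNet synonym lookup; multi-word lemmas ("set_up",
# "well-known") are skipped so replacements remain single words.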
def find_synonyms(keyword):
    synonyms = []
    for synset in wordnet.synsets(keyword):
        for lemma in synset.lemmas():
            if len(lemma.name().split("_")) > 1 or len(lemma.name().split("-")) > 1:
                continue
            synonyms.append(lemma.name())
    return list(set(synonyms))
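
# Token-id-level synonym substitution: convert ids to token strings, strip the
# subword markers ("Ġ" for GPT-2/OPT BPE, "##" for WordPiece), look up WordNet
# synonyms, and convert a randomly picked replacement back to an id.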
def find_tokens_synonyms(tokenizer, ids):
    tokens = tokenizer.convert_ids_to_tokens(ids)
    output = []
    for token in tokens:
        flag1 = "Ġ" in token     # BPE leading-space marker
        flag2 = token[0] == "#"  # WordPiece continuation marker
        sys_tokens = find_synonyms(token.replace("Ġ", "").replace("#", ""))
        if len(sys_tokens) == 0:
            # No single-word synonym found: keep the original token.
            word = token
        else:
            idx = np.random.choice(len(sys_tokens), 1)[0]
            word = sys_tokens[idx]
            # Restore the subword markers stripped before the lookup.
            if flag1:
                word = f"Ġ{word}"
            if flag2:
                word = f"##{word}"
        output.append(word)
        print(f"-> synonyms: {token}->{word}")
    return tokenizer.convert_tokens_to_ids(output)
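
# Predicted token restricted to the label/key vocabulary: every other token's
# probability is zeroed before the argmax. Token id 2 is masked as well
# (presumably the eos/sep id of the tokenizer in use).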
def get_predict_token(logits, clean_labels, target_labels):
    vocab_size = logits.shape[-1]
    total_idx = torch.arange(vocab_size).tolist()
    select_idx = list(set(torch.cat([clean_labels.view(-1), target_labels.view(-1)]).tolist()))
    no_select_ids = list(set(total_idx).difference(set(select_idx))) + [2]
    probs = torch.softmax(logits, dim=1)
    probs[:, no_select_ids] = 0.
    tokens = probs.argmax(dim=1).numpy()
    return tokens
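
# Ownership test: compare predictions under the original (watermarked) prompt
# with predictions under perturbed prompts using a two-sample t-test. Prompts
# with a few tokens replaced by synonyms should stay close to the original,
# while a fully random ("IND") prompt should diverge significantly.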
def run_eval(args):
    utils.set_seed(args.seed)
    device = args.device
    print("-> trigger", args.trigger)

    # Load model, tokenizer, and config.
    logger.info("-> Loading model, tokenizer, etc.")
    config, model, tokenizer = utils.load_pretrained(args, args.model_name)
    model.to(device)
    predictor = model_wrapper.ModelWrapper(model, tokenizer)
    prompt_ids = torch.tensor(args.prompt, device=device).unsqueeze(0)
    key_ids = torch.tensor(args.trigger, device=device).unsqueeze(0)
    print("-> prompt_ids", prompt_ids)

    collator = utils.Collator(tokenizer, pad_token_id=tokenizer.pad_token_id)
    datasets = utils.load_datasets(args, tokenizer)
    dev_loader = DataLoader(datasets.eval_dataset, batch_size=args.eval_size, shuffle=False, collate_fn=collator)

    rand_num = args.k
    # Replace 1..len(prompt) tokens with synonyms; 0 means a fully random
    # (independent) prompt, which serves as the negative control.
    prompt_num_list = np.arange(1, 1 + len(args.prompt)).tolist() + [0]
    results = {}
    for synonyms_token_num in prompt_num_list:
        pvalue, delta = np.zeros([rand_num]), np.zeros([rand_num])
        pbar = tqdm(range(rand_num))
        for step in pbar:
            adv_prompt_ids = torch.tensor(args.prompt, device=device)
            if synonyms_token_num == 0:
                # Fully random prompt drawn from the vocabulary.
                rnd_prompt_ids = np.random.choice(tokenizer.vocab_size, len(args.prompt))
                adv_prompt_ids = torch.tensor(rnd_prompt_ids, device=device)
            else:
                # Replace the first `synonyms_token_num` prompt tokens with synonyms.
                for i in range(synonyms_token_num):
                    token = find_tokens_synonyms(tokenizer, adv_prompt_ids.tolist()[i:i + 1])
                    adv_prompt_ids[i] = token[0]
            adv_prompt_ids = adv_prompt_ids.unsqueeze(0)

            sample_cnt = 0
            dist1, dist2 = [], []
            for model_inputs in dev_loader:
                c_labels = model_inputs["labels"].to(device)
                sample_cnt += len(c_labels)
                poison_idx = np.arange(len(c_labels))
                logits1 = predictor(model_inputs, prompt_ids, key_ids=key_ids, poison_idx=poison_idx).detach().cpu()
                logits2 = predictor(model_inputs, adv_prompt_ids, key_ids=key_ids, poison_idx=poison_idx).detach().cpu()
                dist1.append(get_predict_token(logits1, clean_labels=args.label2ids, target_labels=args.key2ids))
                dist2.append(get_predict_token(logits2, clean_labels=args.label2ids, target_labels=args.key2ids))
                if args.max_pvalue_samples is not None and args.max_pvalue_samples <= sample_cnt:
                    break
            dist1 = np.concatenate(dist1).astype(np.float32)
            dist2 = np.concatenate(dist2).astype(np.float32)
            # Two-sample t-test between the predictions of the watermarked
            # prompt and the perturbed prompt.
            res = stats.ttest_ind(dist1, dist2, nan_policy="omit", equal_var=True)
            keyword = f"synonyms_replace_num:{synonyms_token_num}"
            if synonyms_token_num == 0:
                keyword = "IND"
            pbar.set_description(
                f"-> {keyword} [{step}/{rand_num}] pvalue:{res.pvalue} delta:{res.statistic} "
                f"same:[{np.equal(dist1, dist2).sum()}/{sample_cnt}]"
            )
            pvalue[step] = res.pvalue
            delta[step] = res.statistic
        results[synonyms_token_num] = {
            "pvalue": pvalue.mean(),
            "statistic": delta.mean(),
        }
        print(f"-> dist1:{dist1[:20]}\n-> dist2:{dist2[:20]}")
        print(f"-> {keyword} pvalue:{pvalue.mean()} delta:{delta.mean()}\n")
    return results
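
# CLI entry point: run the evaluation and, when --path is given, merge this
# run's results into exp11_ttest.json in the checkpoint's output directory,
# keyed by the checkpoint file name without its extension.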
if __name__ == "__main__":
    args = get_args()
    results = run_eval(args)
    if args.path is not None:
        data = {}
        key = args.path.split("/")[1][:-3]
        path = osp.join("output", args.path.split("/")[0], "exp11_ttest.json")
        if osp.exists(path):
            data = json.load(open(path, "r"))
        with open(path, "w") as fp:
            data[key] = results
            json.dump(data, fp, indent=4)
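
# Example invocation (hypothetical module and checkpoint names; the relative
# imports above mean the script must be run as a module from the package root):
#   python -m prompt_watermark.exp11_ttest --path run1/best.pt --k 10 --max_pvalue_samples 512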