import logging
import math
import os
import json
import torch
from typing import Dict
import numpy as np
from datetime import datetime, timedelta, timezone

SHA_TZ = timezone(
    timedelta(hours=8),
    name='Asia/Shanghai',
)

import os.path as osp
from transformers.configuration_utils import PretrainedConfig
from transformers import __version__
from tqdm import tqdm

from training import utils
from .trainer import Trainer

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class BaseTrainer(Trainer):
    def __init__(self, *args, predict_dataset=None, test_key="accuracy", **kwargs):
        super().__init__(*args, **kwargs)
        self.config = self.model.config
        self.device = next(self.model.parameters()).device
        self.predict_dataset = predict_dataset
        self.test_key = test_key
        self.best_metrics = {
            "best_epoch": 0,
            f"best_eval_{self.test_key}": 0,
            "best_asr": 0.0,
            "best_score": -np.inf,
            "best_trigger": [],
            "curr_epoch": 0,
            "curr_asr": 0.0,
            "curr_score": -np.inf,
            f"curr_eval_{self.test_key}": 0,
        }
        # watermark default config
        self.train_steps = 0
        self.trigger_ids = torch.tensor(self.model_wrapped.config.trigger, device=self.device).long()
        self.best_trigger_ids = self.trigger_ids.clone()
        print("-> [Trainer] start from trigger_ids", self.trigger_ids)
        # randomly selected poison indices
        if self.train_dataset is not None:
            d = self.get_train_dataloader()
            self.steps_size = len(d)
            self.poison_idx = d.dataset.poison_idx
        self.clean_labels = torch.tensor(self.args.clean_labels).long()
        self.target_labels = torch.tensor(self.args.target_labels).long()
        assert len(self.target_labels[0]) == len(self.clean_labels[0])
        self.eval_memory = {
            "ben_attentions": [],
            "wmk_attentions": [],
            "trigger": self.trigger_ids,
            "clean_labels": self.clean_labels,
            "target_labels": self.target_labels,
        }

    def _prepare_inputs(self, inputs):
        # clamp out-of-vocabulary ids that would crash the embedding lookup
        if "input_ids" in inputs.keys():
            input_ids = inputs["input_ids"]
            idx = torch.where(input_ids >= self.tokenizer.vocab_size)
            if len(idx[0]) > 0:
                logger.error(f"-> overflow: {torch.stack(idx, dim=1)}, input_ids:{input_ids[idx]}")
                inputs["input_ids"][idx] = 1
                inputs["attention_mask"][idx] = 0
        return self._prepare_input(inputs)

    def log_best_metrics(self):
        print("-> best_metrics", self.best_metrics)
        self.save_metrics("best", self.best_metrics, combined=False)

    def optim_watermark_trigger(self, model, inputs):
        """
        Optimize the watermark trigger with gradient-guided token flips.

        :param model: the model to train (re-wrapped from ``self.model_wrapped``)
        :param inputs: the current batch (unused; fresh batches are drawn from the train loader)
        :return:
        """
        model = self._wrap_model(self.model_wrapped)
        train_loader = self.get_train_dataloader()
        train_iter = iter(train_loader)

        # Accumulate the gradient of the loss w.r.t. the trigger token embeddings
        trigger_averaged_grad = 0
        phar = tqdm(range(self.args.trigger_acc_steps))
        for step in phar:
            try:
                tmp_inputs = next(train_iter)
            except StopIteration:
                train_iter = iter(train_loader)
                tmp_inputs = next(train_iter)

            # append token placeholders & replace them with the current trigger
            bsz, emb_dim = tmp_inputs["input_ids"].shape[0], tmp_inputs["input_ids"].shape[-1]
            tmp_inputs, trigger_mask = utils.append_tokens(tmp_inputs, tokenizer=self.tokenizer,
                                                           token_id=self.tokenizer.skey_token_id,
                                                           token=self.tokenizer.skey_token,
                                                           token_num=self.args.trigger_num,
                                                           pos=self.args.trigger_pos)
            tmp_inputs = utils.replace_tokens(tmp_inputs, source_id=self.tokenizer.skey_token_id,
                                              target_ids=self.trigger_ids)
            tmp_inputs["token_labels"] = torch.stack([self.target_labels[y] for y in tmp_inputs["labels"]]).long()
            tmp_inputs = self._prepare_inputs(tmp_inputs)
            loss = model(**tmp_inputs, use_base_grad=True).loss
            loss.backward()
            p_grad = model.embeddings_gradient.get()
            bsz, _, emb_dim = p_grad.size()
            selection_mask = trigger_mask.unsqueeze(-1).to(self.device)
            pt_grad = torch.masked_select(p_grad, selection_mask)
            pt_grad = pt_grad.view(-1, self.args.trigger_num, emb_dim)
            trigger_averaged_grad += pt_grad.sum(dim=0) / self.args.trigger_acc_steps
            phar.set_description(f'-> Accumulating gradient: [{step}/{self.args.trigger_acc_steps}] t_grad:{trigger_averaged_grad.sum(): 0.8f}')
            del tmp_inputs, selection_mask, loss
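
        # NOTE: `utils.hotflip_attack` is assumed to implement the usual HotFlip-style
        # candidate search: for a given trigger slot it ranks every vocabulary token w by
        # the first-order loss change e_w^T * trigger_averaged_grad[flip_idx] and returns
        # the `cand_num` tokens expected to decrease the loss (`increase_loss=False`).
        # Only a random subset of at most 4 trigger positions is flipped per call.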
        # find all candidate tokens for a random subset of trigger positions
        size = min(self.args.trigger_num, 4)
        flip_idxs = np.random.choice(self.args.trigger_num, size, replace=False).tolist()
        for flip_idx in flip_idxs:
            trigger_candidates = utils.hotflip_attack(trigger_averaged_grad[flip_idx], model.embedding.weight,
                                                      increase_loss=False, cand_num=self.args.trigger_cand_num)
            model.zero_grad()

            # find better candidates
            denom, trigger_cur_loss = 0, 0.
            cand_asr = torch.zeros(self.args.trigger_cand_num, device=self.device)
            cand_loss = torch.zeros(self.args.trigger_cand_num, device=self.device)
            phar = tqdm(range(self.args.trigger_acc_steps))
            for step in phar:
                try:
                    tmp_inputs = next(train_iter)
                except StopIteration:
                    train_iter = iter(train_loader)
                    tmp_inputs = next(train_iter)

                # append token placeholders & replace them with the trigger
                bsz = tmp_inputs["input_ids"].shape[0]
                tmp_inputs, _ = utils.append_tokens(tmp_inputs, tokenizer=self.tokenizer,
                                                    token_id=self.tokenizer.skey_token_id,
                                                    token=self.tokenizer.skey_token,
                                                    token_num=self.args.trigger_num,
                                                    pos=self.args.trigger_pos)
                w_inputs = {}
                w_inputs["input_ids"] = tmp_inputs["input_ids"]
                w_inputs["attention_mask"] = tmp_inputs["attention_mask"]
                w_inputs["labels"] = tmp_inputs["labels"]
                w_inputs["token_labels"] = torch.stack([self.target_labels[y] for y in tmp_inputs["labels"]]).long()
                w_inputs = utils.replace_tokens(w_inputs, source_id=self.tokenizer.skey_token_id, target_ids=self.trigger_ids)
                w_inputs = self._prepare_inputs(w_inputs)

                # evaluate the current trigger_ids
                with torch.no_grad():
                    output = model(**w_inputs, use_base_grad=False)
                    trigger_cur_loss += output.loss.detach().cpu()

                # evaluate each candidate's trigger_ids
                for i, cand in enumerate(trigger_candidates):
                    cand_trigger_ids = self.trigger_ids.clone()
                    cand_trigger_ids[:, flip_idx] = cand
                    cand_inputs = utils.replace_tokens(tmp_inputs, source_id=self.tokenizer.skey_token_id, target_ids=cand_trigger_ids)
                    cand_inputs = self._prepare_inputs(cand_inputs)
                    with torch.no_grad():
                        output = model(**cand_inputs, use_base_grad=False)
                        cand_loss[i] += output.loss.sum().detach().cpu().clone()
                        cand_asr[i] += output.logits.argmax(dim=1).view_as(w_inputs["labels"]).eq(w_inputs["labels"]).detach().cpu().sum()
                denom += bsz
                phar.set_description(f'-> Eval gradient: [{step}/{self.args.trigger_acc_steps}] flip_idx:{flip_idx}')
                del w_inputs, tmp_inputs, cand_trigger_ids, output

            cand_loss = cand_loss / (denom + 1e-31)
            trigger_cur_loss = trigger_cur_loss / (denom + 1e-31)
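
            # A flip at this position is accepted only if some candidate achieves a strictly
            # lower averaged loss than the current trigger; the watermark is then re-evaluated
            # so the best-scoring trigger seen so far is kept in `best_trigger_ids`.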
            if (cand_loss < trigger_cur_loss).any():
                best_candidate_idx = cand_loss.argmin()
                best_candidate_loss = float(cand_loss.min().detach().cpu())
                self.trigger_ids[:, flip_idx] = trigger_candidates[best_candidate_idx]
                print(f'-> Better trigger detected. Loss: {best_candidate_loss: 0.5f}')
                eval_score, eval_asr = self.evaluate_watermark()
                if eval_score > self.best_metrics["best_score"]:
                    self.best_trigger_ids = self.trigger_ids
                    self.best_metrics["best_asr"] = float(eval_asr)
                    self.best_metrics["best_score"] = float(eval_score)
                    self.best_metrics["best_trigger"] = self.trigger_ids.clone().squeeze(0).detach().cpu().tolist()

        del trigger_averaged_grad
        print(f"-> Best[{self.tokenizer.name_or_path}_{self.args.watermark}-{self.args.trigger_num}]: best asr:{self.best_metrics['best_asr']: 0.5f} loss:{self.best_metrics['best_score']: 0.5f}\n"
              f"-> Best[{self.tokenizer.name_or_path}_{self.args.watermark}-{self.args.trigger_num}]: {utils.ids2string(self.tokenizer, self.best_trigger_ids)} {self.best_trigger_ids.tolist()} flip_idx:{flip_idxs}\n\n")

    def training_step(self, model, inputs):
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to train.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under
                the argument :obj:`labels`. Check your model's documentation for all accepted arguments.

        Return:
            :obj:`torch.Tensor`: The tensor with training loss on this batch.
        """
        model.train()
        self.train_steps += 1
        inputs["token_labels"] = torch.stack([self.clean_labels[y] for y in inputs["labels"]]).long()

        if (self.train_steps >= self.args.warm_steps) and (self.args.watermark != "clean"):
            # step1: optimize watermark trigger
            if self.train_steps % self.args.watermark_steps == 0:
                if self.args.watermark == "targeted":
                    self.optim_watermark_trigger(model, inputs)
                elif self.args.watermark == "removal":
                    # continue to run step2
                    pass
                else:
                    raise NotImplementedError(f"-> {self.args.watermark} Not Implemented!!")

            # step2: randomly poison wrt% watermarked samples
            bsz = len(inputs["input_ids"])
            off_step = int(self.train_steps % self.steps_size)
            poison_idx = self.poison_idx[int(off_step * bsz): int((off_step + 1) * bsz)]
            poison_idx = torch.where(poison_idx == 1)[0]

            # step3: inject trigger into model_inputs
            if len(poison_idx) != 0:
                # step3.1: inject trigger
                inputs, _ = utils.append_tokens(inputs, tokenizer=self.tokenizer,
                                                token_id=self.tokenizer.skey_token_id,
                                                token=self.tokenizer.skey_token,
                                                token_num=self.args.trigger_num,
                                                idx=poison_idx, pos=self.args.trigger_pos)
                inputs = utils.replace_tokens(inputs, source_id=self.tokenizer.skey_token_id,
                                              target_ids=self.trigger_ids, idx=poison_idx)
                # step3.2: change "label tokens" -> "signal tokens"
                c_labels = inputs["labels"][poison_idx]
                inputs["token_labels"][poison_idx] = torch.stack([self.target_labels[y] for y in c_labels])

        # default model training operation
        model.train()
        model.zero_grad()
        model_inputs = self._prepare_inputs(inputs)
        with self.compute_loss_context_manager():
            loss, outputs = self.compute_loss(model, model_inputs, return_outputs=True)
        if self.args.n_gpu > 1:
            loss = loss.mean()
        self.accelerator.backward(loss)

        # print loss for debug
        if self.train_steps % 200 == 0:
            true_labels = inputs["labels"].detach().cpu()
            pred_label = outputs.logits.argmax(dim=1).view(-1).detach().cpu()
            train_acc = true_labels.eq(pred_label).sum().float() / len(true_labels)
            print(f"-> Model:{self.tokenizer.name_or_path}_{self.args.dataset_name}_{self.args.watermark}-{self.args.trigger_num} step:{self.train_steps} train loss:{loss.detach()} train acc:{train_acc} \n-> y:{true_labels.tolist()}\n-> p:{pred_label.tolist()}")
        return loss.detach() / self.args.gradient_accumulation_steps

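    # Watermark evaluation below reports two numbers: `eval_score`, the sum of
    # sigmoid(-loss) over the evaluated batches of triggered inputs (larger means the
    # model fits the target labels better), and `eval_asr`, the fraction of triggered
    # inputs whose maximum verbalizer score over the target-label tokens exceeds the
    # one over the clean-label tokens.  `outputs.attentions` is assumed to expose
    # per-vocabulary-token scores from the prompt/verbalizer head rather than
    # attention maps.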
    def evaluate_watermark(self, max_data=10000, synonyms_trigger_swap=False):
        print(f"-> evaluate_watermark, trigger:{self.trigger_ids[0]}")
        test_loader = self.get_eval_dataloader()
        model = self._wrap_model(self.model, training=False, dataloader=test_loader)
        eval_denom, eval_score, eval_asr, eval_correct = 0, 0., 0., 0
        return_attentions = []
        print("-> self.trigger_ids", self.trigger_ids)
        with torch.no_grad():
            for raw_inputs in tqdm(test_loader):
                bsz = raw_inputs["input_ids"].size(0)
                # append token placeholders & replace them with the trigger
                wmk_inputs, _ = utils.append_tokens(raw_inputs, tokenizer=self.tokenizer,
                                                    token_id=self.tokenizer.skey_token_id,
                                                    token=self.tokenizer.skey_token,
                                                    token_num=self.args.trigger_num,
                                                    pos=self.args.trigger_pos)
                if synonyms_trigger_swap:
                    wmk_inputs = utils.synonyms_trigger_swap(wmk_inputs, tokenizer=self.tokenizer,
                                                             source_id=self.tokenizer.skey_token_id,
                                                             target_ids=self.trigger_ids)
                else:
                    wmk_inputs = utils.replace_tokens(wmk_inputs, source_id=self.tokenizer.skey_token_id,
                                                      target_ids=self.trigger_ids)
                wmk_inputs["token_labels"] = torch.stack([self.target_labels[y] for y in wmk_inputs["labels"]]).long()
                wmk_inputs = self._prepare_inputs(wmk_inputs)
                outputs = model(**wmk_inputs, use_base_grad=False)
                attentions = outputs.attentions
                return_attentions.append(attentions.clone().detach().cpu())

                # get prediction logits over clean-label vs. target-label tokens
                probs = []
                for y in torch.stack([self.clean_labels.view(-1), self.target_labels.view(-1)]):
                    probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0].detach())
                logits = torch.stack(probs).detach().cpu().T
                wmk_labels = torch.ones(bsz, device=logits.device)

                # collect results
                eval_score += torch.sigmoid(-1.0 * outputs.loss.detach().cpu()).item()
                eval_correct += logits.argmax(dim=1).eq(wmk_labels).detach().cpu().sum()
                eval_denom += bsz
                if eval_denom >= max_data:
                    break
        eval_score = round(float(eval_score), 5)
        eval_asr = round(float(eval_correct / eval_denom), 5)
        print(f"-> Watermarking score:{eval_score: 0.5f} ASR:{eval_asr: 0.5f} \t")
        self.eval_memory["trigger"] = self.trigger_ids.clone().detach().cpu()
        self.eval_memory["wmk_attentions"] = torch.cat(return_attentions)
        return eval_score, eval_asr

    def evaluate_clean(self, max_data=10000):
        test_loader = self.get_eval_dataloader()
        model = self._wrap_model(self.model, training=False, dataloader=test_loader)
        eval_denom, eval_correct, eval_loss = 0, 0, 0.
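        # NOTE: as in `evaluate_watermark`, `outputs.attentions` is assumed to hold
        # per-vocabulary-token scores.  Clean accuracy is computed by taking, for each
        # class, the maximum score over its clean- and target-label token ids and
        # comparing the argmax against the gold label.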
        return_attentions = []
        with torch.no_grad():
            for raw_inputs in tqdm(test_loader):
                bsz = raw_inputs["input_ids"].size(0)
                ben_inputs = self._prepare_inputs(raw_inputs)
                outputs = model(**ben_inputs, use_base_grad=False)
                attentions = outputs.attentions.detach().cpu()
                return_attentions.append(attentions)

                # merge clean-label and target-label token ids for each class
                clean_labels = []
                for idx, yids in enumerate(self.clean_labels):
                    clean_labels.append(torch.cat([yids, self.target_labels[idx]]).detach().cpu())
                probs = []
                for y in clean_labels:
                    probs.append(attentions[:, y].max(dim=1)[0])
                logits = torch.stack(probs).T.detach().cpu()

                # collect results
                eval_loss += outputs.loss.detach().cpu().item()
                eval_correct += logits.argmax(dim=1).eq(raw_inputs["labels"]).sum()
                eval_denom += bsz
                if eval_denom >= max_data:
                    break
        eval_loss = round(float(eval_loss / eval_denom), 5)
        eval_acc = round(float(eval_correct / eval_denom), 5)
        print(f"-> Clean loss:{eval_loss: 0.5f} acc:{eval_acc: 0.5f} \t")
        self.eval_memory["trigger"] = self.trigger_ids.clone().detach().cpu()
        self.eval_memory["ben_attentions"] = torch.cat(return_attentions)
        return eval_loss, eval_acc

    def _resume_watermark(self):
        path = osp.join(self.args.output_dir, "results.pth")
        if osp.exists(path):
            data = torch.load(path, map_location="cpu")
            self.args.trigger = torch.tensor(data["trigger"], device=self.args.device)
            self.trigger_ids = torch.tensor(data["trigger"], device=self.args.device).long()
            print(f"-> resume trigger:{self.trigger_ids}")

    def _save_results(self, data=None):
        if data is not None:
            self.best_metrics.update(data)
        self.best_metrics["curr_epoch"] = self.state.epoch
        self.best_metrics["curr_step"] = self.train_steps
        utc_now = datetime.utcnow().replace(tzinfo=timezone.utc)
        self.best_metrics["curr_times"] = str(utc_now.astimezone(SHA_TZ).strftime('%Y-%m-%d %H:%M:%S'))
        results = {}
        for k, v in vars(self.args).items():
            v = str(v.tolist()) if type(v) == torch.Tensor else str(v)
            results[str(k)] = v
        for k, v in self.best_metrics.items():
            results[k] = v
        results["trigger"] = self.trigger_ids.tolist()
        torch.save(results, os.path.join(self.args.output_dir, "results.pth"))

    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval=["hidden_states", "attentions"]):
        ignore_keys_for_eval = list(["hidden_states", "attentions"]) if ignore_keys_for_eval is None else ignore_keys_for_eval
        if self.control.should_log:
            logs: Dict[str, float] = {}
            tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
            # reset tr_loss to zero in place so the caller's running loss is cleared as well
            tr_loss -= tr_loss
            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
            logs["learning_rate"] = self._get_learning_rate()
            self._total_loss_scalar += tr_loss_scalar
            self._globalstep_last_logged = self.state.global_step
            self.store_flos()
            self.log(logs)

        metrics = None
        if self.control.should_evaluate:
            if isinstance(self.eval_dataset, dict):
                metrics = {}
                for eval_dataset_name, eval_dataset in self.eval_dataset.items():
                    dataset_metrics = self.evaluate(
                        eval_dataset=eval_dataset,
                        ignore_keys=ignore_keys_for_eval,
                        metric_key_prefix=f"eval_{eval_dataset_name}",
                    )
                    metrics.update(dataset_metrics)
            else:
                metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
            if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                metric_to_check = self.args.metric_for_best_model
                if not metric_to_check.startswith("eval_"):
                    metric_to_check = f"eval_{metric_to_check}"
                self.lr_scheduler.step(metrics[metric_to_check])

            self.best_metrics["curr_epoch"] = epoch
            self.best_metrics["curr_eval_" + self.test_key] = metrics["eval_" + self.test_key]
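            # The "best" epoch is selected purely by the clean eval metric (`test_key`);
            # the watermark score/ASR of the current trigger are refreshed below and
            # persisted together with the trigger via `_save_results()`.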
metrics["eval_" + self.test_key] > self.best_metrics["best_eval_" + self.test_key]: self.best_metrics["best_epoch"] = epoch self.best_metrics["best_eval_" + self.test_key] = metrics["eval_" + self.test_key] # eval for poison set self.best_metrics["curr_epoch"] = epoch score, asr = 0.0, 0.0 if self.args.watermark != "clean": score, asr = self.evaluate_watermark() self.best_metrics["curr_score"] = score self.best_metrics["curr_asr"] = asr self._save_results() logger.info(f"***** Epoch {epoch}: Best results *****") for key, value in self.best_metrics.items(): logger.info(f"{key} = {value}") self.log(self.best_metrics) #self.evaluate_clean() #torch.save(self.eval_memory, f"{self.args.output_dir}/exp11_attentions.pth") if (self.control.should_save) or (self.train_steps % 5000 == 0) or (self.train_steps == self.state.num_train_epochs): self._save_checkpoint(model, trial, metrics=metrics) self.control = self.callback_handler.on_save(self.args, self.state, self.control)