# -*- coding: utf-8 -*-
# @Time    : 2022/1/7 11:02 AM
# @Author  : JianingWang
# @File    : adversarial.py

import torch


class FGM:
    """Fast Gradient Method: perturb embedding weights along their gradient.

    Typical use: after the normal ``loss.backward()``, call ``attack()`` to add
    a perturbation of L2 norm ``epsilon`` to each matching embedding
    parameter, run a second forward/backward pass on the perturbed model, then
    call ``restore()`` to put the original weights back before the optimizer
    step.
    """

    def __init__(self, model: torch.nn.Module):
        self.model = model
        # Maps parameter name -> copy of its data taken in ``attack``.
        self.backup = {}

    def attack(self, epsilon: float = 1.0, emb_name: str = "word_embeddings") -> None:
        """Add ``epsilon * grad / ||grad||`` to every matching parameter.

        Args:
            epsilon: L2 norm of the perturbation applied to each parameter.
            emb_name: Substring that identifies the embedding parameter(s);
                change it to match the embedding parameter name in your model.

        Parameters whose gradient is missing, or whose gradient norm is zero
        or non-finite, are backed up but left unperturbed.
        """
        for name, param in self.model.named_parameters():
            if not (param.requires_grad and emb_name in name):
                continue
            # Always back up, so ``restore`` finds every matching parameter.
            self.backup[name] = param.data.clone()
            if param.grad is None:
                # No backward pass reached this parameter: nothing to follow.
                continue
            norm = torch.norm(param.grad)
            # A zero norm would divide by zero. A NaN/Inf norm would write
            # NaN into the weights.
            if norm == 0 or not torch.isfinite(norm):
                continue
            r_at = epsilon * param.grad / norm
            param.data.add_(r_at)

    def restore(self, emb_name: str = "word_embeddings") -> None:
        """Restore the weights saved by the last ``attack`` and clear the backup.

        Args:
            emb_name: Substring that identifies the embedding parameter(s);
                must match the value passed to ``attack``.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                assert name in self.backup, (
                    f"no backup for parameter {name!r}; call attack() before restore()"
                )
                param.data = self.backup[name]
        self.backup = {}