dzy7e
init
49d1787
raw
history blame
850 Bytes
class Attacker:
    """Base class for an attack driven by ``model``.

    ``step`` and ``attack`` are placeholders in this class: they do nothing
    and return ``None``. Concrete attacks are presumably meant to override
    them (NOTE(review): confirm against subclasses).

    Configuration happens after construction through ``set_para``,
    ``set_loss`` and ``set_forward``.
    """

    def __init__(self, model, img_transform=(lambda x: x, lambda x: x)):
        """Store the model and image transforms and install the default forward.

        Args:
            model: The model to attack. Per the original author's note it must
                be a PyTorch model; this is not checked here.
            img_transform: A pair of callables, both identity functions by
                default. NOTE(review): presumably a (transform, inverse) pair
                used by subclasses — confirm.
        """
        self.model = model
        # NOTE(review): freezing the model (calling model.eval() and setting
        # requires_grad = False on every named parameter) is NOT done here;
        # callers that need a frozen model must do it themselves.
        self.img_transform = img_transform

        def _run_step(attacker, images, labels):
            # Default forward: a single step using the attacker's current loss.
            return attacker.step(images, labels, attacker.loss)

        # Stored on the instance rather than defined as a method, so it is NOT
        # bound: call it as ``atk.forward(atk, images, labels)``. ``loss`` is
        # read at call time and is not created by __init__, so set_loss() or
        # set_para(loss=...) must run first (otherwise AttributeError).
        self.forward = _run_step

    def set_para(self, **kwargs):
        """Set every keyword argument as an attribute of this attacker."""
        for attr_item in kwargs.items():
            setattr(self, *attr_item)

    def set_forward(self, forward):
        """Replace the forward callable.

        ``forward`` is stored as-is on the instance, so it is invoked unbound
        (the default one is called as ``atk.forward(atk, images, labels)``).
        """
        self.forward = forward

    def step(self, images, labels, loss):
        """Perform one attack step; a no-op returning ``None`` in this class."""
        return None

    def set_loss(self, loss):
        """Set the loss that the default forward passes to ``step``."""
        self.loss = loss

    def attack(self, images, labels):
        """Run the attack; a no-op returning ``None`` in this class."""
        return None
class Empty:
    """A context manager that does nothing.

    Entering yields ``None``, and exiting neither handles nor suppresses
    exceptions. ``with Empty(): body`` therefore behaves like running ``body``
    directly, which lets it serve as a placeholder where a context manager
    is syntactically required but no setup or teardown is wanted.
    """

    def __enter__(self):
        return None

    def __exit__(self, exc_type, exc_val, exc_tb):
        # A falsy return value lets any in-flight exception propagate.
        return None