dzy7e commited on
Commit
f8d7eff
1 Parent(s): fe5ac37
Files changed (2) hide show
  1. app.py +17 -14
  2. attack.py +4 -4
app.py CHANGED
@@ -26,20 +26,23 @@ with gr.Blocks(title="Anime AI Detect Fucker Demo", theme="dark") as demo:
26
  gr.HTML('<a href="https://github.com/7eu7d7/anime-ai-detect-fucker">github repo</a>')
27
 
28
  with gr.Row():
29
- eps = gr.Slider(label="eps (Noise intensity)", minimum=1, maximum=16, step=1, value=1)
30
- step_size = gr.Slider(label="Noise step size", minimum=0.001, maximum=16, step=0.001, value=0.136)
31
- with gr.Row():
32
- steps = gr.Slider(label="step count", minimum=1, maximum=100, step=1, value=20)
33
- model_name = gr.Dropdown(label="attack target",
34
- choices=["auto", "human", "ai"],
35
- value="auto", show_label=True)
36
-
37
- input_image = gr.Image(label="Clean Image", type="pil")
38
-
39
- atk_btn = gr.Button("Attack")
40
-
41
- with gr.Column():
42
- output_image = gr.Image(label="Attacked Image")
 
 
 
43
 
44
  atk_btn.click(fn=do_attack,
45
  inputs=[input_image, eps, step_size, steps],
 
26
  gr.HTML('<a href="https://github.com/7eu7d7/anime-ai-detect-fucker">github repo</a>')
27
 
28
  with gr.Row():
29
+ with gr.Column():
30
+ with gr.Row():
31
+ eps = gr.Slider(label="eps (Noise intensity)", minimum=1, maximum=16, step=1, value=1)
32
+ step_size = gr.Slider(label="Noise step size", minimum=0.001, maximum=16, step=0.001, value=0.136)
33
+ with gr.Row():
34
+ steps = gr.Slider(label="step count", minimum=1, maximum=100, step=1, value=20)
35
+ model_name = gr.Dropdown(label="attack target",
36
+ choices=["auto", "human", "ai"],
37
+ interactive=True,
38
+ value="auto", show_label=True)
39
+
40
+ input_image = gr.Image(label="Clean Image", type="pil")
41
+
42
+ atk_btn = gr.Button("Attack")
43
+
44
+ with gr.Column():
45
+ output_image = gr.Image(label="Attacked Image")
46
 
47
  atk_btn.click(fn=do_attack,
48
  inputs=[input_image, eps, step_size, steps],
attack.py CHANGED
@@ -35,7 +35,7 @@ class Attacker:
35
 
36
  print('正在加载模型...')
37
  self.feature_extractor = BeitFeatureExtractor.from_pretrained('saltacc/anime-ai-detect')
38
- self.model = BeitForImageClassification.from_pretrained('saltacc/anime-ai-detect').cuda()
39
  print('加载完毕')
40
 
41
  if args.target=='ai': #攻击成被识别为AI
@@ -43,8 +43,8 @@ class Attacker:
43
  elif args.target=='human':
44
  self.target = torch.tensor([0]).to(device)
45
 
46
- dataset_mean_t = torch.tensor([0.5, 0.5, 0.5]).view(1, -1, 1, 1).cuda()
47
- dataset_std_t = torch.tensor([0.5, 0.5, 0.5]).view(1, -1, 1, 1).cuda()
48
  self.pgd = PGD(self.model, img_transform=(lambda x: (x - dataset_mean_t) / dataset_std_t, lambda x: x * dataset_std_t + dataset_mean_t))
49
  self.pgd.set_para(eps=(args.eps * 2) / 255, alpha=lambda: (args.step_size * 2) / 255, iters=args.steps)
50
  self.pgd.set_loss(CrossEntropyLoss())
@@ -58,7 +58,7 @@ class Attacker:
58
  save_image(img_save, os.path.join(self.args.out_dir, f'{img_name[:img_name.rfind(".")]}_atk.png'))
59
 
60
  def attack_(self, image):
61
- inputs = self.feature_extractor(images=image, return_tensors="pt")['pixel_values'].cuda()
62
 
63
  if self.args.target == 'auto':
64
  with torch.no_grad():
 
35
 
36
  print('正在加载模型...')
37
  self.feature_extractor = BeitFeatureExtractor.from_pretrained('saltacc/anime-ai-detect')
38
+ self.model = BeitForImageClassification.from_pretrained('saltacc/anime-ai-detect').to(device)
39
  print('加载完毕')
40
 
41
  if args.target=='ai': #攻击成被识别为AI
 
43
  elif args.target=='human':
44
  self.target = torch.tensor([0]).to(device)
45
 
46
+ dataset_mean_t = torch.tensor([0.5, 0.5, 0.5]).view(1, -1, 1, 1).to(device)
47
+ dataset_std_t = torch.tensor([0.5, 0.5, 0.5]).view(1, -1, 1, 1).to(device)
48
  self.pgd = PGD(self.model, img_transform=(lambda x: (x - dataset_mean_t) / dataset_std_t, lambda x: x * dataset_std_t + dataset_mean_t))
49
  self.pgd.set_para(eps=(args.eps * 2) / 255, alpha=lambda: (args.step_size * 2) / 255, iters=args.steps)
50
  self.pgd.set_loss(CrossEntropyLoss())
 
58
  save_image(img_save, os.path.join(self.args.out_dir, f'{img_name[:img_name.rfind(".")]}_atk.png'))
59
 
60
  def attack_(self, image):
61
+ inputs = self.feature_extractor(images=image, return_tensors="pt")['pixel_values'].to(device)
62
 
63
  if self.args.target == 'auto':
64
  with torch.no_grad():