File size: 3,758 Bytes
d23d4b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import gradio as gr
import os
from PIL import Image, ImageOps
import matplotlib.pyplot as plt
import numpy as np
import torch
import requests
from tqdm import tqdm

from diffusers import StableDiffusionImg2ImgPipeline
import torchvision.transforms as T

from utils import preprocess, recover_image

# Tensor -> PIL converter used to turn the adversarial tensor into an image.
to_pil = T.ToPILImage()

# Heading shown at the top of the Gradio page.
title = "Interactive demo: Raising the Cost of Malicious AI-Powered Image Editing"

# Checkpoint to load. Earlier releases that were tried as alternatives:
#   "CompVis/stable-diffusion-v1-4", "CompVis/stable-diffusion-v1-3",
#   "CompVis/stable-diffusion-v1-2", "CompVis/stable-diffusion-v1-1"
model_id_or_path = "runwayml/stable-diffusion-v1-5"

# Load the half-precision weights and place the whole pipeline on the GPU.
pipe_img2img = StableDiffusionImg2ImgPipeline.from_pretrained(
    model_id_or_path, revision="fp16", torch_dtype=torch.float16
).to("cuda")
def pgd(X, model, eps=0.1, step_size=0.015, iters=40, clamp_min=0, clamp_max=1, mask=None):
    """Run a signed-gradient PGD attack that shrinks the norm of the latent mean.

    Starting from ``X`` plus uniform noise in ``[-eps, eps]``, every iteration
    steps against the sign of the gradient of
    ``model(X_adv).latent_dist.mean.norm()``, projects the iterate back into
    the ``eps`` box around ``X`` and clamps it to ``[clamp_min, clamp_max]``.

    Args:
        X: Tensor to perturb. It is not modified.
        model: Callable applied to the iterate; its result must expose
            ``latent_dist.mean`` as a tensor.
        eps: Maximum absolute per-element deviation from ``X``.
        step_size: Initial step size; it decays linearly towards
            ``step_size / 100`` over the iterations.
        iters: Number of attack iterations.
        clamp_min: Lower bound applied to the iterate after each step.
        clamp_max: Upper bound applied to the iterate after each step.
        mask: Optional tensor multiplied into the iterate after each step.
            NOTE: it scales the iterate itself, not only the perturbation.

    Returns:
        The perturbed tensor, on the same device as ``X`` and detached from
        the autograd graph.
    """
    # Noise is drawn with torch.rand's defaults and then moved to wherever X
    # lives, so the attack is not tied to CUDA.
    noise = (torch.rand(*X.shape) * 2 * eps - eps).to(X.device)
    X_adv = X.clone().detach() + noise

    pbar = tqdm(range(iters))
    for i in pbar:
        actual_step_size = step_size - (step_size - step_size / 100) / iters * i

        # Restart from a fresh leaf each step so the graph of earlier
        # iterations (and the tensors it saved) can be released.
        X_adv = X_adv.detach().requires_grad_(True)

        loss = (model(X_adv).latent_dist.mean).norm()

        pbar.set_description(f"[Running attack]: Loss {loss.item():.5f} | step size: {actual_step_size:.4}")

        grad, = torch.autograd.grad(loss, [X_adv])

        # The update itself needs no gradient tracking.
        with torch.no_grad():
            X_adv = X_adv - grad.sign() * actual_step_size
            # Project back into the eps box around the clean input.
            X_adv = torch.minimum(torch.maximum(X_adv, X - eps), X + eps)
            X_adv = torch.clamp(X_adv, min=clamp_min, max=clamp_max)

            if mask is not None:
                # In-place so the iterate keeps its dtype, as before.
                X_adv.mul_(mask)

    return X_adv
    
def process_image(raw_image, prompt):
    """Immunise an image with PGD and compare img2img edits of both versions.

    The input is resized and centre-cropped to 512x512, an adversarial copy
    is produced by attacking the pipeline's VAE encoder with ``pgd``, and the
    img2img pipeline is then run with the same seed and settings on the clean
    and on the adversarial image.

    Args:
        raw_image: Input image (the Gradio input is configured as PIL).
        prompt: Text prompt passed to the img2img pipeline.

    Returns:
        A list of four ``(image, caption)`` pairs: the cropped source image,
        the adversarial image, and the edit generated from each of them.

    Raises:
        ValueError: If ``raw_image`` is ``None`` (nothing was uploaded).
    """
    if raw_image is None:
        # Fail early with a readable message instead of an opaque error
        # from inside the resize transform.
        raise ValueError("No input image was provided; please upload an image.")

    resize = T.Resize(512)
    center_crop = T.CenterCrop(512)
    init_image = center_crop(resize(raw_image))

    with torch.autocast('cuda'):
        X = preprocess(init_image).half().cuda()
        adv_X = pgd(X,
                    model=pipe_img2img.vae.encode,
                    clamp_min=-1,
                    clamp_max=1,
                    eps=0.06,  # The higher, the less imperceptible the attack is
                    step_size=0.02,  # Set smaller than eps
                    iters=100,  # The higher, the stronger your attack will be
                    )

        # convert pixels back to [0,1] range
        adv_X = (adv_X / 2 + 0.5).clamp(0, 1)

    adv_image = to_pil(adv_X[0]).convert("RGB")

    # a good seed (uncomment the line below to generate new images)
    SEED = 9222
    # SEED = np.random.randint(low=0, high=10000)

    # Play with these for improving generated image quality
    STRENGTH = 0.5
    GUIDANCE = 7.5
    NUM_STEPS = 50

    # Shared by both runs so the only difference between them is the image.
    sampling_kwargs = dict(
        prompt=prompt,
        strength=STRENGTH,
        guidance_scale=GUIDANCE,
        num_inference_steps=NUM_STEPS,
    )

    with torch.autocast('cuda'):
        # Re-seed before each run so both edits start from the same noise.
        torch.manual_seed(SEED)
        image_nat = pipe_img2img(image=init_image, **sampling_kwargs).images[0]
        torch.manual_seed(SEED)
        image_adv = pipe_img2img(image=adv_image, **sampling_kwargs).images[0]

    return [(init_image,"Source Image"), (adv_image, "Adv Image"), (image_nat,"Gen. Image Nat"), (image_adv, "Gen. Image Adv")]
    
    

# Input widgets: the image to immunise and the editing prompt.
demo_inputs = [gr.Image(type="pil"), gr.Textbox(label="Prompt")]

# Output widget: a two-column gallery holding the four result images.
demo_gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery")
demo_outputs = [demo_gallery.style(grid=[2], height="auto")]

interface = gr.Interface(
    fn=process_image,
    inputs=demo_inputs,
    outputs=demo_outputs,
    title=title,
)

# Start the web app as soon as the script runs.
interface.launch(debug=True)