File size: 7,040 Bytes
0b00c74
d2794b1
ca86cf6
8b44d8d
 
ca86cf6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2342c94
 
ca86cf6
 
 
 
 
 
d5025f0
 
 
 
ca86cf6
 
 
 
 
ab4f056
 
 
 
ca86cf6
 
 
 
 
 
 
 
 
2d11242
 
ca86cf6
 
 
 
 
 
e9d8edd
ca86cf6
2342c94
ca86cf6
e9d8edd
4dfce87
e9d8edd
ca86cf6
 
 
 
fbe5687
2342c94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca86cf6
 
 
 
 
 
 
 
 
 
 
 
 
2880299
ca86cf6
7cfd7ed
ca86cf6
 
 
 
 
8b44d8d
ca86cf6
 
 
 
 
8fe2131
ca86cf6
0c16ea5
c68dc14
0c16ea5
ca86cf6
8b44d8d
 
ca86cf6
 
8b44d8d
ca86cf6
 
d2794b1
ca86cf6
 
2d11242
ca86cf6
 
74ae0b4
0b00c74
ca86cf6
 
 
 
 
 
 
e9d8edd
ca86cf6
2342c94
ca86cf6
2342c94
ca86cf6
 
2342c94
ca86cf6
4188365
ca86cf6
 
2342c94
0b00c74
ca86cf6
0b00c74
ca86cf6
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import gradio as gr

import os
import torch

import numpy as np

# Use the CUDA device when torch reports one available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

from diffusers import DiffusionPipeline
import torchvision.transforms as transforms

from copy import deepcopy
from collections import OrderedDict

import requests
import json

from PIL import Image, ImageEnhance
import base64
import io
import random
import math

class BZHStableSignatureDemo(object):
    """Back end of the watermarked SDXL-Turbo demo.

    Holds one SDXL-Turbo pipeline together with several VAE state dicts (the
    pipeline's own weights plus patched decoders read from ``models/``).
    ``generate`` renders an image with the selected weights, and
    ``attack_detect`` degrades an image and submits it to the remote
    detection API.

    Attributes:
        pipe: The loaded ``DiffusionPipeline``.
        decoders: ``OrderedDict`` mapping a decoder name to the state dict
            loaded into ``pipe.vae`` before generating.
    """

    # Remote endpoint that scores an image for the watermark.
    DETECT_URL = 'https://bzh.imatag.com/bzh/api/v1.0/detect'
    # Seconds to wait for the detection API before giving up, so a stalled
    # server cannot block the UI worker forever.
    DETECT_TIMEOUT = 60
    # (name, checkpoint path) of every patched decoder.
    PATCHED_DECODERS = (
        ("weak", "models/checkpoint_000.pth.50000"),
        ("medium", "models/checkpoint_000.pth.150000"),
        ("strong", "models/checkpoint_000.pth.500000"),
        ("extreme", "models/checkpoint_000.pth.1500000"),
    )

    def __init__(self, *args, **kwargs):
        """Load the pipeline and every patched decoder state dict."""
        super().__init__(*args, **kwargs)
        # Follow the module-level `device` instead of hard-coding "cuda" so
        # the demo can also start on a CPU-only host; half precision is kept
        # for the GPU only.
        on_gpu = device.type == "cuda"
        self.pipe = DiffusionPipeline.from_pretrained(
            "stabilityai/sdxl-turbo",
            torch_dtype=torch.float16 if on_gpu else torch.float32,
            variant="fp16",
        ).to(device)
        try:
            print("self.pipe.watermark = ", self.pipe.watermark)
        except AttributeError:
            print("no self.pipe.watermark")

        # load the patched VQ-VAEs
        sd1 = deepcopy(self.pipe.vae.state_dict())  # save initial state dict
        decoders = OrderedDict([("no watermark", sd1)])
        for name, patched_decoder_ckpt in self.PATCHED_DECODERS:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            # map_location keeps the load working when the checkpoint was
            # saved from a device that is not present on this host.
            sd2 = torch.load(patched_decoder_ckpt, map_location=device)['ldm_decoder']
            msg = self.pipe.vae.load_state_dict(sd2, strict=False)
            print(f"loaded LDM decoder state_dict with message\n{msg}")
            print("you should check that the decoder keys are correctly matched")
            decoders[name] = sd2
        self.decoders = decoders

    def generate(self, mode, seed, prompt):
        """Generate one image for ``prompt`` using the decoder named ``mode``.

        Args:
            mode: Key of ``self.decoders`` selecting the VAE weights to load.
            seed: Integer seed for torch's global RNG. ``None`` (e.g. an empty
                seed field) skips seeding, so the result is not reproducible.
            prompt: Text prompt passed to the pipeline.

        Returns:
            The first image of the pipeline output (``output_type="pil"``).

        Raises:
            KeyError: If ``mode`` is not a known decoder name.
        """
        if seed is not None:
            torch.manual_seed(int(seed))

        # load the patched VAE decoder
        sd = self.decoders[mode]
        self.pipe.vae.load_state_dict(sd, strict=False)

        output = self.pipe(prompt, num_inference_steps=4, guidance_scale=0.0, output_type="pil")
        return output.images[0]

    @staticmethod
    def _random_crop_box(width, height, crop, attempts=10):
        """Pick a random crop box covering ``1 - crop`` of the image area.

        The aspect ratio is drawn log-uniformly from [0.5, 2]. Returns a
        ``(left, top, right, bottom)`` tuple lying fully inside the image, or
        ``None`` when no candidate fits within ``attempts`` tries.
        """
        target_area = width * height * (1 - crop)
        log_rmin = math.log(0.5)
        log_rmax = math.log(2.0)
        for _ in range(attempts):
            aspect_ratio = math.exp(random.random() * (log_rmax - log_rmin) + log_rmin)
            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))
            if 0 < w <= width and 0 < h <= height:
                # random.randint includes both endpoints, so the largest valid
                # offset is exactly (size - crop size); anything larger would
                # push the box past the image edge.
                top = random.randint(0, height - h)
                left = random.randint(0, width - w)
                return (left, top, left + w, top + h)
        return None

    @staticmethod
    def _describe_pvalue(pvalue):
        """Turn a detection p-value into the message shown to the user."""
        # Written as "not <" so that a NaN p-value reads as "not detected".
        if not pvalue < 1e-3:
            return "No watermark detected."
        # Round up to the next power of ten; clamp first because a p-value of
        # exactly 0 has no logarithm.
        rpv = 10 ** int(math.log10(max(pvalue, 1e-300)))
        if pvalue < 1e-9:
            return "Watermark detected with high confidence (p-value<%.0e)" % rpv
        return "Watermark detected with low confidence (p-value<%.0e)" % rpv

    def attack_detect(self, img, jpeg_compression, downscale, crop, saturation):
        """Degrade ``img`` and ask the detection API whether it is watermarked.

        The attacks are applied in this order: downscale, random crop, colour
        saturation, JPEG compression.

        Args:
            img: PIL image to test.
            jpeg_compression: JPEG quality used to encode the image.
            downscale: Factor both sides are divided by; ``1`` disables it.
            crop: Fraction of the image area removed by a random crop; ``0``
                disables it.
            saturation: Factor given to ``ImageEnhance.Color.enhance``.

        Returns:
            ``(image, message)``: the image decoded back from the JPEG bytes
            that were sent, and a human-readable detection result.

        Raises:
            ValueError: If ``img`` is ``None``.
            requests.HTTPError: If the detection API answers with an error
                status.
        """
        if img is None:
            raise ValueError("No image to analyse: generate or upload one first.")
        img = img.convert("RGB")

        # attack
        if downscale != 1:
            width, height = img.size
            # Never let a side collapse to 0 pixels.
            size = (max(1, int(width / downscale)), max(1, int(height / downscale)))
            img = img.resize(size, Image.Resampling.LANCZOS)
        if crop != 0:
            box = self._random_crop_box(img.size[0], img.size[1], crop)
            if box is not None:
                img = img.crop(box)

        converter = ImageEnhance.Color(img)
        img = converter.enhance(saturation)

        # send to detection API and apply JPEG compression attack
        mf = io.BytesIO()
        # The quality comes from a UI slider; cast so a float value is accepted.
        img.save(mf, format='JPEG', quality=int(jpeg_compression))  # includes JPEG attack
        b64 = base64.b64encode(mf.getvalue())
        data = {
            'image': b64.decode('utf8')
        }

        headers = {}
        api_key = os.getenv('BZH_API_KEY')
        if api_key:
            headers['x-api-key'] = api_key
        response = requests.post(self.DETECT_URL, json=data, headers=headers,
                                 timeout=self.DETECT_TIMEOUT)
        response.raise_for_status()
        pvalue = response.json()['p-value']

        mf.seek(0)
        img0 = Image.open(mf)  # reload to show JPEG attack
        return (img0, self._describe_pvalue(pvalue))


def interface():
    """Assemble the Gradio Blocks UI of the demo and return it unlaunched.

    Instantiates the back end, lays out the generation controls and the
    attack/detection controls, and wires the two buttons to
    ``backend.generate`` and ``backend.attack_detect``.
    """
    default_prompt = "sailing ship in storm by Rembrandt"

    backend = BZHStableSignatureDemo()
    strength_names = [*backend.decoders]

    intro_text = """# Watermarked SDXL-Turbo demo
        This demo brought to you by [IMATAG](https://www.imatag.com/) presents watermarking of images generated via [StableDiffusion XL Turbo](https://huggingface.co/stabilityai/sdxl-turbo).
        Using the method presented in [StableSignature](https://ai.meta.com/blog/stable-signature-watermarking-generative-ai/),
        the VAE decoder of StableDiffusion is fine-tuned to produce images including a specific invisible watermark. We combined
        this method with a demo version of [IMATAG](https://www.imatag.com/)'s in-house decoder. The watermarking system operates in zero-bit mode for improved robustness."""
    attack_help = """With these controls you may alter the generated image before detection. You may also upload your own edited image instead."""

    with gr.Blocks() as app:
        gr.Markdown(intro_text)

        # Generation controls.
        with gr.Row():
            prompt_box = gr.Textbox(label="Prompt", value=default_prompt)
            seed_box = gr.Number(label="Seed", precision=0)
            strength_box = gr.Dropdown(choices=strength_names, label="Watermark strength", value="medium")
        with gr.Row():
            generate_btn = gr.Button("Generate")

        # Generated image on the left, attack settings and result on the right.
        with gr.Row():
            generated_view = gr.Image(type="pil", width=512, height=512)
            with gr.Column():
                gr.Markdown(attack_help)
                downscale_slider = gr.Slider(1, 3, value=1, step=0.1, label="Downscale ratio")
                crop_slider = gr.Slider(0, 0.9, value=0, step=0.01, label="Random crop ratio")
                saturation_slider = gr.Slider(0, 2, value=1, step=0.1, label="Color saturation")
                quality_slider = gr.Slider(value=100, step=5, label="JPEG quality")
                detect_btn = gr.Button("Modify & Detect")
                with gr.Row():
                    attacked_view = gr.Image(type="pil", width=256)
                    result_label = gr.Label(label="Detection info")

        # Argument order follows the signatures of the back-end methods.
        generate_btn.click(
            fn=backend.generate,
            inputs=[strength_box, seed_box, prompt_box],
            outputs=[generated_view],
            api_name="generate",
        )
        detect_btn.click(
            fn=backend.attack_detect,
            inputs=[generated_view, quality_slider, downscale_slider, crop_slider, saturation_slider],
            outputs=[attacked_view, result_label],
            api_name="detect",
        )

    return app

# Script entry point: build the UI and start the Gradio server. The Blocks
# object is bound to the module-level name `demo`.
if __name__ == '__main__':
    demo = interface()
    demo.launch()