File size: 6,171 Bytes
0b00c74
d2794b1
ca86cf6
8b44d8d
 
ca86cf6
 
 
 
464ec84
ca86cf6
 
 
 
 
 
 
 
 
 
 
2342c94
 
ca86cf6
 
 
 
 
464ec84
ca86cf6
464ec84
 
 
 
 
 
 
 
 
 
 
 
ca86cf6
 
 
 
2d11242
ca86cf6
464ec84
 
 
ca86cf6
 
464ec84
ca86cf6
2342c94
ca86cf6
4dfce87
e9d8edd
ca86cf6
 
 
 
fbe5687
464ec84
2342c94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca86cf6
 
 
 
 
 
 
 
 
 
 
 
 
2880299
ca86cf6
7cfd7ed
ca86cf6
 
 
 
 
8b44d8d
ca86cf6
 
464ec84
ca86cf6
8fe2131
ca86cf6
464ec84
c68dc14
464ec84
ca86cf6
8b44d8d
ca86cf6
 
8b44d8d
ca86cf6
 
d2794b1
ca86cf6
 
2d11242
ca86cf6
 
74ae0b4
0b00c74
ca86cf6
 
 
 
 
 
 
e9d8edd
ca86cf6
2342c94
ca86cf6
2342c94
ca86cf6
 
2342c94
ca86cf6
4188365
ca86cf6
 
2342c94
0b00c74
ca86cf6
0b00c74
ca86cf6
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import gradio as gr

import os
import torch

import numpy as np

# Torch device used for RNG generators: the first CUDA GPU when present, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from diffusers import DiffusionPipeline, AutoencoderKL
import torchvision.transforms as transforms

from copy import deepcopy
from collections import OrderedDict

import requests
import json

from PIL import Image, ImageEnhance
import base64
import io
import random
import math

class BZHStableSignatureDemo(object):
    """Backend for the watermarked SDXL-Turbo Gradio demo.

    Holds one SDXL-Turbo pipeline plus a set of VAE decoders: the original one
    (key ``"no watermark"``) and the patched ``imatag/stable-signature-bzh-sdxl-vae-*``
    decoders ``"weak"``, ``"medium"``, ``"strong"`` and ``"extreme"``.
    ``generate`` renders an image with a chosen decoder; ``attack_detect``
    degrades an image and asks the IMATAG BZH detection API for a p-value.
    """

    # Seconds to wait (connect and read, each) for the detection API.
    DETECT_TIMEOUT = 60

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.pipe = DiffusionPipeline.from_pretrained(
            "stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16"
        ).to("cuda")

        # disable invisible-watermark
        self.pipe.watermark = None

        # keep the original VAE so the "no watermark" mode can swap it back in
        decoders = OrderedDict([("no watermark", self.pipe.vae)])

        # load the patched VAEs
        for name in ("weak", "medium", "strong", "extreme"):
            decoders[name] = AutoencoderKL.from_pretrained(
                f"imatag/stable-signature-bzh-sdxl-vae-{name}",
                torch_dtype=torch.float16,
            ).to("cuda")

        self.decoders = decoders

    def generate(self, mode, seed, prompt):
        """Generate one image from *prompt* with the decoder selected by *mode*.

        Args:
            mode: key of ``self.decoders`` (e.g. ``"medium"``).
            seed: integer seed; ``None`` (empty Gradio number box) is treated as 0.
            prompt: text prompt.

        Returns:
            The first image of the pipeline output (``output_type="pil"``).
        """
        seed = 0 if seed is None else int(seed)
        # A dedicated seeded generator handed to the pipeline, instead of
        # reseeding torch's global RNG, which every request shares.
        generator = torch.Generator(device=device).manual_seed(seed)

        # swap in the requested VAE decoder
        self.pipe.vae = self.decoders[mode]

        output = self.pipe(
            prompt,
            num_inference_steps=4,
            guidance_scale=0.0,
            output_type="pil",
            generator=generator,
        )
        return output.images[0]

    def attack_detect(self, img, jpeg_compression, downscale, crop, saturation):
        """Apply the selected attacks to *img* and query the watermark detector.

        Attacks are applied in order: downscale, random crop, color saturation,
        then JPEG compression. The JPEG bytes are what gets sent to the API.

        Args:
            img: input ``PIL.Image``; ``None`` when no image is present.
            jpeg_compression: JPEG ``quality`` passed to PIL when saving.
            downscale: width and height are divided by this factor (1 = no-op).
            crop: fraction of the area removed by a random crop (0 = no-op).
            saturation: ``ImageEnhance.Color`` factor (1 = unchanged).

        Returns:
            ``(attacked_image, message)``: the image decoded back from the JPEG
            bytes that were sent, and a human-readable detection result.

        Raises:
            gr.Error: if no image was provided.
            requests.HTTPError: if the detection API returns an error status.
        """
        if img is None:
            raise gr.Error("Please generate or upload an image first.")

        img = img.convert("RGB")

        # attack: downscale (clamped so PIL never gets a 0-pixel size)
        if downscale != 1:
            width, height = img.size
            size = (max(1, int(width / downscale)), max(1, int(height / downscale)))
            img = img.resize(size, Image.Resampling.LANCZOS)

        # attack: random crop keeping about (1 - crop) of the area
        if crop != 0:
            width, height = img.size
            box = self._random_crop_box(width, height, crop)
            if box is not None:
                img = img.crop(box)

        # attack: color saturation
        img = ImageEnhance.Color(img).enhance(saturation)

        # attack: JPEG compression; the compressed bytes are sent for detection
        mf = io.BytesIO()
        img.save(mf, format='JPEG', quality=int(jpeg_compression))
        payload = {
            'image': base64.b64encode(mf.getvalue()).decode('utf8')
        }

        headers = {}
        api_key = os.getenv('BZH_API_KEY')
        if api_key:
            headers['x-api-key'] = api_key
        response = requests.post('https://bzh.imatag.com/bzh/api/v1.0/detect',
                                 json=payload, headers=headers,
                                 timeout=self.DETECT_TIMEOUT)
        response.raise_for_status()
        pvalue = response.json()['p-value']

        mf.seek(0)
        attacked = Image.open(mf)  # reload to show JPEG attack

        return (attacked, self._describe_pvalue(pvalue))

    @staticmethod
    def _random_crop_box(width, height, crop, attempts=10, rng=random):
        """Pick a random crop box covering about ``(1 - crop)`` of the area.

        The aspect ratio is drawn log-uniformly from [0.5, 2]. Up to *attempts*
        draws are tried; the first box that fits inside a ``width`` x ``height``
        image is returned as ``(left, top, right, bottom)``. Returns ``None``
        when no draw fits, in which case the image is left uncropped.
        """
        target_area = width * height * (1 - crop)
        log_rmin = math.log(0.5)
        log_rmax = math.log(2.0)
        for _ in range(attempts):
            aspect_ratio = math.exp(rng.random() * (log_rmax - log_rmin) + log_rmin)
            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))
            if 0 < w <= width and 0 < h <= height:
                # randint's upper bound is inclusive: the last valid offset is
                # (size - crop_size), so the box never extends past the image.
                top = rng.randint(0, height - h)
                left = rng.randint(0, width - w)
                return (left, top, left + w, top + h)
        return None

    @staticmethod
    def _describe_pvalue(pvalue):
        """Turn a detector p-value into the message shown in the UI.

        Below 1e-3 the watermark is reported with low confidence, below 1e-9
        with high confidence; the p-value is displayed rounded up to a power
        of ten. NaN or any value >= 1e-3 means no detection.
        """
        if not pvalue < 1e-3:  # also catches NaN
            return "No watermark detected."
        import sys
        # log10 is undefined at 0 (an underflowed p-value); clamp to the
        # smallest normal float so the displayed bound is valid and still true.
        rpv = 10**int(math.log10(max(pvalue, sys.float_info.min)))
        if pvalue < 1e-9:
            return "Watermark detected with high confidence (p-value<%.0e)" % rpv
        return "Watermark detected with low confidence (p-value<%.0e)" % rpv

def interface():
    """Build the Gradio Blocks UI wired to a ``BZHStableSignatureDemo`` backend.

    Returns:
        The ``gr.Blocks`` app (not launched).
    """
    default_prompt = "sailing ship in storm by Rembrandt"

    backend = BZHStableSignatureDemo()
    strengths = list(backend.decoders)

    with gr.Blocks() as demo:
        gr.Markdown("""# Watermarked SDXL-Turbo demo
        This demo brought to you by [IMATAG](https://www.imatag.com/) presents watermarking of images generated via [StableDiffusion XL Turbo](https://huggingface.co/stabilityai/sdxl-turbo).
        Using the method presented in [StableSignature](https://ai.meta.com/blog/stable-signature-watermarking-generative-ai/),
        the VAE decoder of StableDiffusion is fine-tuned to produce images including a specific invisible watermark. We combined
        this method with a demo version of [IMATAG](https://www.imatag.com/)'s in-house decoder. The watermarking system operates in zero-bit mode for improved robustness.""")

        # generation inputs
        with gr.Row():
            prompt_box = gr.Textbox(label="Prompt", value=default_prompt)
            seed_box = gr.Number(label="Seed", precision=0)
            strength_menu = gr.Dropdown(choices=strengths, label="Watermark strength", value="medium")
        with gr.Row():
            generate_button = gr.Button("Generate")

        # generated image on the left, attack controls and detection on the right
        with gr.Row():
            watermarked_image = gr.Image(type="pil", width=512, height=512)
            with gr.Column():
                gr.Markdown("With these controls you may alter the generated image before detection. You may also upload your own edited image instead.")
                downscale_slider = gr.Slider(1, 3, value=1, step=0.1, label="Downscale ratio")
                crop_slider = gr.Slider(0, 0.9, value=0, step=0.01, label="Random crop ratio")
                saturation_slider = gr.Slider(0, 2, value=1, step=0.1, label="Color saturation")
                jpeg_slider = gr.Slider(value=100, step=5, label="JPEG quality")
                detect_button = gr.Button("Modify & Detect")
                with gr.Row():
                    attacked_image = gr.Image(type="pil", width=256)
                    detection_label = gr.Label(label="Detection info")

        generate_button.click(
            fn=backend.generate,
            inputs=[strength_menu, seed_box, prompt_box],
            outputs=[watermarked_image],
            api_name="generate",
        )
        detect_button.click(
            fn=backend.attack_detect,
            inputs=[watermarked_image, jpeg_slider, downscale_slider, crop_slider, saturation_slider],
            outputs=[attacked_image, detection_label],
            api_name="detect",
        )

    return demo

if __name__ == "__main__":
    # Build the UI (this loads all models) and serve it with Gradio.
    demo = interface()
    demo.launch()