import gradio as gr
import os
import torch

from transformers import AutoModelForImageClassification, BlipImageProcessor
from diffusers import DiffusionPipeline, AutoencoderKL
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from collections import OrderedDict
import requests
from PIL import Image, ImageEnhance
import base64
import io
import random
import math

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


class BZHStableSignatureDemo(object):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo",
                                                      torch_dtype=torch.float16, variant="fp16").to(device)
        # disable invisible-watermark
        self.pipe.watermark = None
        # save the original VAE
        decoders = OrderedDict([("no watermark", self.pipe.vae)])
        # load the patched VAEs
        for name in ("weak", "medium", "strong", "extreme"):
            vae = AutoencoderKL.from_pretrained(f"imatag/stable-signature-bzh-sdxl-vae-{name}",
                                                torch_dtype=torch.float16).to(device)
            decoders[name] = vae
        self.decoders = decoders
        # load the proxy detector and its calibration logits
        self.detector_image_processor = BlipImageProcessor.from_pretrained("imatag/stable-signature-bzh-detector-resnet18")
        self.detector_model = AutoModelForImageClassification.from_pretrained("imatag/stable-signature-bzh-detector-resnet18")
        calibration = hf_hub_download("imatag/stable-signature-bzh-detector-resnet18", filename="calibration.safetensors")
        with safe_open(calibration, framework="pt") as f:
            self.calibration_logits = f.get_tensor("logits")

    def generate(self, mode, seed, prompt):
        # seed a dedicated generator so generation is reproducible
        generator = torch.Generator(device=device).manual_seed(seed)
        # swap in the VAE decoder matching the requested watermark strength
        vae = self.decoders[mode]
        self.pipe.vae = vae
        output = self.pipe(prompt, num_inference_steps=4, guidance_scale=0.0,
                           generator=generator, output_type="pil")
        return output.images[0]

    def attack(self, img, jpeg_compression, downscale, crop, saturation, brightness, contrast):
        img = img.convert("RGB")

        # downscale attack
        if downscale != 1:
            size = img.size
            size = (int(size[0] / downscale), int(size[1] / downscale))
            img = img.resize(size, Image.Resampling.LANCZOS)

        # random crop attack: keep a (1 - crop) fraction of the area with a
        # random aspect ratio, retrying up to 10 times for a valid box
        if crop != 0:
            width, height = img.size
            area = width * height
            log_rmin = math.log(0.5)
            log_rmax = math.log(2.0)
            for _ in range(10):
                target_area = area * (1 - crop)
                aspect_ratio = math.exp(random.random() * (log_rmax - log_rmin) + log_rmin)
                w = int(round(math.sqrt(target_area * aspect_ratio)))
                h = int(round(math.sqrt(target_area / aspect_ratio)))
                if 0 < w <= width and 0 < h <= height:
                    # randint is inclusive on both ends, so the upper bounds
                    # must be height - h and width - w to keep the crop box
                    # inside the image
                    top = random.randint(0, height - h)
                    left = random.randint(0, width - w)
                    img = img.crop((left, top, left + w, top + h))
                    break

        # photometric attacks
        converter = ImageEnhance.Color(img)
        img = converter.enhance(saturation)
        converter = ImageEnhance.Brightness(img)
        img = converter.enhance(brightness)
        converter = ImageEnhance.Contrast(img)
        img = converter.enhance(contrast)

        # JPEG attack
        mf = io.BytesIO()
        img.save(mf, format='JPEG', quality=jpeg_compression)
        filesize = mf.tell()
        mf.seek(0)
        img = Image.open(mf)

        image_info = "resolution: %dx%d" % img.size
        image_info += " JPEG file size: %d" % filesize
        return img, image_info

    def detect_api(self, img):
        # send the image to the detection API as a base64-encoded PNG
        mf = io.BytesIO()
        img.save(mf, format='PNG')
        b64 = base64.b64encode(mf.getvalue())
        data = {
            'image': b64.decode('utf8')
        }
        headers = {}
        api_key = os.getenv('BZH_API_KEY')
        if api_key:
            headers['x-api-key'] = api_key
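        # The request built above is a JSON body of the form
        # {"image": "<base64 PNG>"} with an optional x-api-key header; the
        # endpoint answers with JSON carrying a "p-value" field, read below.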
        response = requests.post('https://bzh.imatag.com/bzh/api/v1.0/detect', json=data, headers=headers)
        response.raise_for_status()
        data = response.json()
        pvalue = data['p-value']
        return pvalue

    def detect_proxy(self, img):
        img = img.convert("RGB")
        inputs = self.detector_image_processor(img, return_tensors="pt")
        with torch.no_grad():
            logit = self.detector_model(**inputs).logits[..., 0]
        # rank the logit against the calibration logits (empirical CDF)
        # to turn it into a p-value
        pvalue = (1 + torch.searchsorted(self.calibration_logits, logit)) / self.calibration_logits.shape[0]
        pvalue = pvalue.item()
        return pvalue

    def detect(self, img, detection_method):
        if detection_method == "API":
            pvalue = self.detect_api(img)
        else:
            pvalue = self.detect_proxy(img)

        result = "No watermark detected."
        rpv = 10 ** int(math.log10(pvalue))
        if pvalue < 1e-3:
            result = "Watermark detected with low confidence"  # (p-value<%.0e)" % rpv
        if pvalue < 1e-6:
            result = "Watermark detected with high confidence"  # (p-value<%.0e)" % rpv
        score = min(int(-math.log10(pvalue)), 10)
        return {result: score / 10}


def interface():
    prompt = "sailing ship in storm by Rembrandt"

    backend = BZHStableSignatureDemo()
    decoders = list(backend.decoders.keys())

    with gr.Blocks() as demo:
        gr.Markdown("""# Watermarked SDXL-Turbo demo

This demo, brought to you by [IMATAG](https://www.imatag.com/), presents watermarking of images generated with [StableDiffusion XL Turbo](https://huggingface.co/stabilityai/sdxl-turbo).
Using the method presented in [StableSignature](https://ai.meta.com/blog/stable-signature-watermarking-generative-ai/), the VAE decoder of StableDiffusion is fine-tuned to produce images carrying a specific invisible watermark. We combined this method with a demo version of [IMATAG](https://www.imatag.com/)'s in-house decoder.
The watermarking system operates in zero-bit mode for improved robustness.""")

        gr.Markdown("""## 1. Generate

Select a watermarking strength and generate images with StableDiffusion-XL Turbo from prompt and seed as usual.""")
        with gr.Row():
            inp = gr.Textbox(label="Prompt", value=prompt)
            seed = gr.Number(label="Seed", precision=0)
            mode = gr.Dropdown(choices=decoders, label="Watermark strength", value="medium")
        with gr.Row():
            btn1 = gr.Button("Generate")
        with gr.Row():
            watermarked_image = gr.Image(type="pil", width=512, height=512, sources=[], interactive=False)

        gr.Markdown("""## 2. Edit

With these controls you may alter the generated image before detection. You may also upload your own edited image instead.""")
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    downscale = gr.Slider(1, 3, value=1, step=0.1, label="Downscale ratio")
                    crop = gr.Slider(0, 0.9, value=0, step=0.01, label="Random crop ratio")
                with gr.Row():
                    brightness = gr.Slider(0, 2, value=1, step=0.1, label="Brightness")
                    contrast = gr.Slider(0, 2, value=1, step=0.1, label="Contrast")
                with gr.Row():
                    saturation = gr.Slider(0, 2, value=1, step=0.1, label="Color saturation")
                    jpeg_compression = gr.Slider(value=100, step=5, label="JPEG quality")
                btn2 = gr.Button("Edit")
        with gr.Row():
            attacked_image = gr.Image(type="pil", width=512, sources=['upload', 'clipboard'])
        with gr.Row():
            image_info_label = gr.Label(label="Image info")

        gr.Markdown("""## 3. Detect

Detect the watermark on the altered image. The watermark may not be detected if the image is altered too strongly.
You may choose to detect with our fast [proxy model](https://huggingface.co/imatag/stable-signature-bzh-detector-resnet18), or via the API for improved robustness.
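The reported p-value is the probability that an unwatermarked image would score this high by chance; lower values mean a more confident detection.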
""") with gr.Row(): detection_method = gr.Dropdown(choices=["proxy model", "API"], label="Detection method", value="proxy model") btn3 = gr.Button("Detect") with gr.Row(): detection_label = gr.Label(label="Detection info") btn1.click(fn=backend.generate, inputs=[mode, seed, inp], outputs=[watermarked_image], api_name="generate") btn2.click(fn=backend.attack, inputs=[watermarked_image, jpeg_compression, downscale, crop, saturation, brightness, contrast], outputs=[attacked_image, image_info_label], api_name="attack") btn3.click(fn=backend.detect, inputs=[attacked_image, detection_method], outputs=[detection_label], api_name="detect") return demo if __name__ == '__main__': demo = interface() demo.launch(server_name="0.0.0.0")