File size: 3,128 Bytes
5d92c6c
dd965ce
 
bccea8e
1faabe4
bccea8e
1509ea2
5d92c6c
 
4b15ff4
5d92c6c
 
 
bccea8e
5d92c6c
 
 
 
 
 
 
 
 
 
bccea8e
2ca71f0
bccea8e
 
 
 
 
 
784addb
5d92c6c
fed1b3d
 
 
 
bccea8e
f410c4d
199441c
 
 
 
 
 
 
 
020d346
 
bccea8e
 
 
 
 
 
5d92c6c
ff896ad
 
652aa24
5fb1dae
ff896ad
 
 
 
 
 
 
 
e2946ad
 
 
ff896ad
374aa25
 
 
 
ff896ad
 
57e1356
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff896ad
 
 
 
 
 
 
 
 
 
fed1b3d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import gradio as gr
import os
# Hugging Face access token forwarded to the remote safety-checker Space;
# os.environ.get returns None when HF_TOKEN is not set.
hf_token = os.environ.get("HF_TOKEN")
from gradio_client import Client
# Single module-level client for the remote safety-checker Space; every call to
# safety_check() reuses it.
# NOTE(review): constructing a Client presumably contacts the Space at import
# time, so app startup depends on that Space being reachable — confirm.
client = Client("https://fffiloni-safety-checker-bot.hf.space/", hf_token=hf_token)
import re
# NOTE(review): ZeroGPU's `spaces` package is generally expected to be imported
# before torch/CUDA is initialised — keep it ahead of diffusers/torch if these
# imports are ever reordered; confirm against the spaces documentation.
import spaces
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel
import torch
# NOTE(review): StableDiffusionSafetyChecker is not referenced anywhere in the
# visible code of this file.
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker

# Build the SDXL base pipeline with the DPO-finetuned UNet plugged in.
#
# The finetuned UNet is loaded first and handed to `from_pretrained` via
# `unet=`, so diffusers does not download/instantiate the base model's own UNet
# only for it to be discarded right afterwards.
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
unet_id = "mhdang/dpo-sdxl-text2image-v1"

unet = UNet2DConditionModel.from_pretrained(
    unet_id, subfolder="unet", torch_dtype=torch.float16
)
pipe = StableDiffusionXLPipeline.from_pretrained(
    model_id,
    unet=unet,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)

# Model CPU offload takes over device placement for the pipeline's sub-models,
# so the explicit `.to("cuda")` transfers that used to precede this call were
# redundant and have been dropped.
pipe.enable_model_cpu_offload()
# Decode latents in slices to lower the VAE's peak memory use.
pipe.enable_vae_slicing()

def safety_check(user_prompt):
    """Send *user_prompt* to the remote safety-checker Space and return its reply.

    The prompt is posted to the Space's ``/infer`` endpoint through the
    module-level ``client``; whatever that endpoint returns is passed back
    untouched, with no parsing or interpretation here.
    """
    # Sole positional input: the remote Space's 'User sent this' textbox.
    return client.predict(user_prompt, api_name="/infer")

@spaces.GPU(enable_queue=True)
def _generate(prompt):
    """Run the SDXL-DPO pipeline on *prompt* and return the first generated image.

    This is the only step that needs the GPU, so it is the only function wrapped
    by ``spaces.GPU``. The GPU is therefore requested only after the prompt has
    passed the remote safety check, not while waiting on that network call.
    """
    results = pipe(prompt, guidance_scale=7.5)
    return results.images[0]


def infer(prompt):
    """Safety-check *prompt*, then generate an image for it.

    The prompt is first sent to the remote safety checker via
    ``safety_check``. If the checker's reply contains the standalone word
    "Yes" (case-sensitive), the prompt is rejected.

    Args:
        prompt: Text prompt entered by the user.

    Returns:
        The first image produced by the diffusion pipeline.

    Raises:
        gr.Error: If the safety checker flags the prompt.
    """
    print(f"""
    —
    {prompt}
    """)
    verdict = safety_check(prompt)
    print(verdict)

    # A standalone "Yes" in the checker's reply means the prompt was flagged.
    if re.search(r'\bYes\b', verdict):
        raise gr.Error("Don't ask for such things.")

    return _generate(prompt)

# Stylesheet handed to gr.Blocks below: centers the element whose
# elem_id is "col-container" and caps its width at 580px.
css = """
#col-container{
    margin: 0 auto;
    max-width: 580px;
}
"""
# Build the page: a centered column holding a header, the prompt form, the
# output image and clickable example prompts; then wire the button and launch.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Title and one-paragraph description of DPO.
        gr.HTML("""
        <h2 style="text-align: center;">
            SDXL Using Direct Preference Optimization
        </h2>
        <p style="text-align: center;">
            Direct Preference Optimization (DPO) for text-to-image diffusion models is a method to align diffusion models to text human preferences by directly optimizing on human comparison data.
        </p>
        """)

        # Prompt box and submit button share one visual group.
        with gr.Group():
            with gr.Column():
                prompt_in = gr.Textbox(label="Prompt", value="An old man with a bird on his head")
                submit_btn = gr.Button("Submit")

        result = gr.Image(label="DPO SDXL Result")

        example_prompts = [
            "Dragon, digital art, by Greg Rutkowski",
            "Armored knight holding sword",
            "A flat roof villa near a river with black walls and huge windows",
            "A calm and peaceful office",
            "Pirate guinea pig",
        ]
        gr.Examples(examples=example_prompts, fn=infer, inputs=[prompt_in], outputs=[result])

    # Clicking Submit runs the same handler as the examples.
    submit_btn.click(fn=infer, inputs=[prompt_in], outputs=[result])

demo.queue().launch(show_api=False)