|
import gradio as gr |
|
from gradio_client import Client |
|
client = Client("https://fffiloni-safety-checker-bot.hf.space/") |
|
import re |
|
import spaces |
|
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel |
|
import torch |
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker |
|
|
|
|
|
# Base SDXL checkpoint, loaded in half precision from the fp16 safetensors
# variant. It is intentionally NOT moved to the GPU here: its stock UNet is
# replaced a few lines below, so copying it to CUDA first would only spend a
# transfer (and GPU memory) on weights that are immediately discarded.
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipe = StableDiffusionXLPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)

# Replace the pipeline's UNet with the DPO fine-tuned one, read from the
# "unet" subfolder of the DPO repository in the same half-precision dtype.
unet_id = "mhdang/dpo-sdxl-text2image-v1"
unet = UNet2DConditionModel.from_pretrained(
    unet_id,
    subfolder="unet",
    torch_dtype=torch.float16,
)
pipe.unet = unet

# Move the fully assembled pipeline (including the replacement UNet) to the
# GPU exactly once.
pipe = pipe.to("cuda")

# Memory-saving switches provided by diffusers: model-level CPU offload and
# sliced VAE decoding.
pipe.enable_model_cpu_offload()
pipe.enable_vae_slicing()
|
|
|
def safety_check(user_prompt):
    """Send ``user_prompt`` to the remote safety-checker Space.

    Args:
        user_prompt: The text prompt to be assessed.

    Returns:
        The value returned by the Space's ``/infer`` endpoint, unmodified.
    """
    # Use the Space's public URL rather than a hard-coded
    # ".../--replicas/<id>/" address: a replica URL names one specific
    # running instance of the Space and is not a stable endpoint.
    # NOTE(review): the client is still created on every call, as before.
    # Presumably this is because the function runs inside the GPU worker;
    # confirm before hoisting it to module level.
    checker = Client("https://fffiloni-safety-checker-bot.hf.space/")
    return checker.predict(user_prompt, api_name="/infer")
|
|
|
@spaces.GPU(enable_queue=True)
def infer(prompt):
    """Generate one image for ``prompt`` after a remote safety check.

    Args:
        prompt: Text prompt supplied by the user.

    Returns:
        The first image of the pipeline's result (``results.images[0]``).

    Raises:
        gr.Error: If the safety checker's reply is empty, or if its first
            whitespace-separated word is ``"True"`` (the prompt is refused).
    """
    # Log the incoming prompt to stdout.
    print(f"""
    —/n
    {prompt}
    """)

    verdict = safety_check(prompt)
    words = verdict.split()

    # An empty reply carries no verdict; refuse explicitly instead of
    # failing with an IndexError on words[0].
    if not words:
        raise gr.Error("Safety check returned no verdict.")

    # The checker's first word is the flag; "True" means the prompt is
    # rejected. gr.Error must be *raised* to reach the UI: merely
    # constructing it is a no-op and previously let execution fall through
    # to the return statement with `results` undefined.
    if words[0] == "True":
        raise gr.Error("Don't.")

    results = pipe(prompt, guidance_scale=7.5)
    return results.images[0]
|
|
|
# Centre the main column and cap its width.
css = """
#col-container{
    margin: 0 auto;
    max-width: 580px;
}
"""

# NOTE(review): the nesting below was reconstructed from a copy of this file
# whose indentation had been lost — confirm the layout against the deployed app.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Title and short description of the method.
        gr.HTML("""
        <h2 style="text-align: center;">
        SDXL Using Direct Preference Optimization
        </h2>
        <p style="text-align: center;">
        Direct Preference Optimization (DPO) for text-to-image diffusion models is a method to align diffusion models to text human preferences by directly optimizing on human comparison data.
        </p>
        """)

        # Prompt entry and submit button, grouped together.
        with gr.Group():
            with gr.Column():
                prompt_in = gr.Textbox(
                    value="An old man with a bird on his head",
                    label="Prompt",
                )
                submit_btn = gr.Button("Submit")
        result = gr.Image(label="DPO SDXL Result")

        # Clickable sample prompts wired to the same inference function.
        gr.Examples(
            examples=[
                "Dragon, digital art, by Greg Rutkowski",
                "Armored knight holding sword",
                "A flat roof villa near a river with black walls and huge windows",
                "A calm and peaceful office",
                "Pirate guinea pig",
            ],
            fn=infer,
            inputs=[prompt_in],
            outputs=[result],
        )

    # Run inference on the current prompt when the button is pressed.
    submit_btn.click(fn=infer, inputs=[prompt_in], outputs=[result])

# Enable request queuing and start the app with the API page hidden.
demo.queue().launch(show_api=False)