File size: 3,128 Bytes
5d92c6c dd965ce bccea8e 1faabe4 bccea8e 1509ea2 5d92c6c 4b15ff4 5d92c6c bccea8e 5d92c6c bccea8e 2ca71f0 bccea8e 784addb 5d92c6c fed1b3d bccea8e f410c4d 199441c 020d346 bccea8e 5d92c6c ff896ad 652aa24 5fb1dae ff896ad e2946ad ff896ad 374aa25 ff896ad 57e1356 ff896ad fed1b3d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import gradio as gr
import os
# Hugging Face access token taken from the environment; this is None when
# the HF_TOKEN variable is not set.
hf_token = os.environ.get("HF_TOKEN")
from gradio_client import Client
# Client for the remote safety-checker Space. It is built once, at import
# time, and reused by safety_check() for every prompt.
client = Client("https://fffiloni-safety-checker-bot.hf.space/", hf_token=hf_token)
import re
import spaces
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel
import torch
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
# Load the base SDXL pipeline with float16 weights (fp16 variant, safetensors
# files) and move it to the CUDA device.
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=torch.float16, variant="fp16", use_safetensors=True).to("cuda")
# Load the finetuned (DPO) UNet from its "unet" subfolder and swap it in for
# the UNet that came with the base pipeline.
unet_id = "mhdang/dpo-sdxl-text2image-v1"
unet = UNet2DConditionModel.from_pretrained(unet_id, subfolder="unet", torch_dtype=torch.float16)
pipe.unet = unet
# The replacement UNet was loaded without an explicit device, so the whole
# pipeline is moved to CUDA a second time (presumably to bring the new UNet
# along with it).
pipe = pipe.to("cuda")
# NOTE(review): CPU offload is switched on immediately after an explicit move
# to CUDA -- confirm against the diffusers documentation that combining the
# two is intended for the installed version.
pipe.enable_model_cpu_offload()
# NOTE(review): presumably enabled to lower VAE memory use -- verify.
pipe.enable_vae_slicing()
def safety_check(user_prompt):
    """Submit a prompt to the remote safety-checker Space.

    Args:
        user_prompt: Text forwarded as the single input of the Space's
            ``/infer`` endpoint (its 'User sent this' Textbox component).

    Returns:
        Whatever the ``/infer`` endpoint replies, passed through unchanged.
    """
    return client.predict(user_prompt, api_name="/infer")
@spaces.GPU(enable_queue=True)
def infer(prompt):
    """Generate one image for ``prompt`` unless the safety checker objects.

    The prompt is logged, sent to ``safety_check``, and the checker's reply
    is logged as well. If that reply contains the standalone word "Yes" the
    request is refused; otherwise the module-level ``pipe`` is run with a
    guidance scale of 7.5.

    Args:
        prompt: Text prompt passed both to the safety checker and to the
            pipeline.

    Returns:
        The first entry of the pipeline result's ``images``.

    Raises:
        gr.Error: If the safety checker's reply contains the word "Yes".
    """
    # The continuation lines of this literal sit at column 0 on purpose:
    # any leading spaces would become part of the logged text.
    print(f"""
—/n
{prompt}
""")

    verdict = safety_check(prompt)
    print(verdict)

    # Guard clause: refuse as soon as the checker answers "Yes".
    if re.search(r'\bYes\b', verdict) is not None:
        raise gr.Error("Don't ask for such things.")

    # NOTE: per-image replacement of flagged outputs with "nsfw.png" was
    # sketched here at some point but is not active; the first generated
    # image is returned as-is.
    generated = pipe(prompt, guidance_scale=7.5)
    return generated.images[0]
# Custom stylesheet handed to gr.Blocks(css=css): the element with id
# "col-container" (the main column) is centered horizontally and limited to
# a width of 580px.
css = """
#col-container{
margin: 0 auto;
max-width: 580px;
}
"""
# Intro banner rendered at the top of the page.
_HEADER_HTML = """
<h2 style="text-align: center;">
SDXL Using Direct Preference Optimization
</h2>
<p style="text-align: center;">
Direct Preference Optimization (DPO) for text-to-image diffusion models is a method to align diffusion models to text human preferences by directly optimizing on human comparison data.
</p>
"""

# Prompts offered as clickable examples under the result image.
_EXAMPLE_PROMPTS = [
    "Dragon, digital art, by Greg Rutkowski",
    "Armored knight holding sword",
    "A flat roof villa near a river with black walls and huge windows",
    "A calm and peaceful office",
    "Pirate guinea pig",
]

# Build the UI: a banner, a prompt box with a submit button, the result
# image, and a list of example prompts. Both the examples and the button
# route the prompt through infer() and show its return value in `result`.
#
# NOTE(review): the nesting of the layout containers below was reconstructed
# (the available source had lost its indentation) -- confirm it matches the
# intended page layout.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML(_HEADER_HTML)

        with gr.Group():
            with gr.Column():
                prompt_in = gr.Textbox(
                    label="Prompt",
                    value="An old man with a bird on his head",
                )
                submit_btn = gr.Button("Submit")

        result = gr.Image(label="DPO SDXL Result")

        gr.Examples(
            examples=_EXAMPLE_PROMPTS,
            fn=infer,
            inputs=[prompt_in],
            outputs=[result],
        )

    # Event listeners must be registered while the Blocks context is open.
    submit_btn.click(fn=infer, inputs=[prompt_in], outputs=[result])

# Enable request queuing and start the server with the API page hidden.
demo.queue().launch(show_api=False)