Spaces:
Runtime error
Runtime error
import gradio as gr | |
import os | |
# Access token read from the environment; None when HF_TOKEN is not set.
hf_token = os.getenv("HF_TOKEN")
from gradio_client import Client

# Client bound to the remote safety-checker Space used to screen prompts.
client = Client(
    "https://fffiloni-safety-checker-bot.hf.space/",
    hf_token=hf_token,
)
import re | |
import spaces | |
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel | |
import torch | |
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker | |
# Load the DPO-finetuned UNet first so it can be handed straight to the
# pipeline constructor; this avoids loading the base SDXL UNet weights only
# to discard them when the UNet is swapped afterwards.
unet_id = "mhdang/dpo-sdxl-text2image-v1"
unet = UNet2DConditionModel.from_pretrained(
    unet_id,
    subfolder="unet",
    torch_dtype=torch.float16,
)

# Build the SDXL base pipeline around the finetuned UNet.
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipe = StableDiffusionXLPipeline.from_pretrained(
    model_id,
    unet=unet,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)

# Model CPU offload takes over device placement for the pipeline's
# sub-models, so the pipeline is intentionally NOT moved to CUDA by hand
# beforehand (doing so only costs extra host<->device transfers).
pipe.enable_model_cpu_offload()
# Decode the VAE output in slices to reduce peak memory during decoding.
pipe.enable_vae_slicing()
def safety_check(user_prompt):
    """Submit ``user_prompt`` to the remote safety-checker Space.

    The prompt is sent to the Space's ``/infer`` endpoint, where it fills the
    'User sent this' Textbox component. Whatever the endpoint returns is
    handed back to the caller unchanged.
    """
    return client.predict(user_prompt, api_name="/infer")
def infer(prompt: str):
    """Generate one image for ``prompt`` once it has passed the safety check.

    Args:
        prompt: Text prompt entered by the user.

    Returns:
        The first image produced by the SDXL pipeline for this prompt.

    Raises:
        gr.Error: If the safety checker's reply contains the standalone
            word "Yes".
    """
    # Log the incoming prompt under a separator line.
    print(f"\n—\n{prompt}\n")

    is_safe = safety_check(prompt)
    print(is_safe)

    # NOTE: despite the variable name, a standalone "Yes" in the reply is
    # treated as "this prompt was flagged" and the request is rejected.
    if re.search(r"\bYes\b", is_safe):
        raise gr.Error("Don't ask for such things.")

    results = pipe(prompt, guidance_scale=7.5)
    return results.images[0]
# Page-level CSS: centre the main column and cap its width.
css = """
#col-container{
margin: 0 auto;
max-width: 580px;
}
"""

# Title and short description rendered at the top of the page.
_header_html = """
<h2 style="text-align: center;">
SDXL Using Direct Preference Optimization
</h2>
<p style="text-align: center;">
Direct Preference Optimization (DPO) for text-to-image diffusion models is a method to align diffusion models to text human preferences by directly optimizing on human comparison data.
</p>
"""

# NOTE(review): the layout nesting below was reconstructed from a paste that
# had lost its indentation — confirm it against the intended UI.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML(_header_html)

        # Prompt entry and submit button share one visual group.
        with gr.Group():
            with gr.Column():
                prompt_in = gr.Textbox(
                    label="Prompt",
                    value="An old man with a bird on his head",
                )
                submit_btn = gr.Button("Submit")

        result = gr.Image(label="DPO SDXL Result")

        # Clickable sample prompts wired to the same inference function.
        gr.Examples(
            examples=[
                "Dragon, digital art, by Greg Rutkowski",
                "Armored knight holding sword",
                "A flat roof villa near a river with black walls and huge windows",
                "A calm and peaceful office",
                "Pirate guinea pig",
            ],
            fn=infer,
            inputs=[prompt_in],
            outputs=[result],
        )

    # Run inference on the current prompt when the button is pressed.
    submit_btn.click(fn=infer, inputs=[prompt_in], outputs=[result])

# Enable request queuing and start the app with the API page hidden.
demo.queue().launch(show_api=False)