import gradio as gr
from gradio_client import Client
import re
import spaces
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel
import torch
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker

# Client for the remote safety-checker Space. Created once at import time and
# reused by every safety_check() call, so the API config is not re-fetched per
# request.
client = Client("https://fffiloni-safety-checker-bot.hf.space/")

# load pipeline
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipe = StableDiffusionXLPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
).to("cuda")

# load finetuned model: replace the base SDXL UNet with the DPO-tuned UNet
unet_id = "mhdang/dpo-sdxl-text2image-v1"
unet = UNet2DConditionModel.from_pretrained(unet_id, subfolder="unet", torch_dtype=torch.float16)
pipe.unet = unet
pipe = pipe.to("cuda")
# NOTE(review): diffusers' memory-optimization docs advise against moving a
# pipeline to CUDA before enable_model_cpu_offload(); the .to("cuda") calls
# above may be redundant. Left unchanged because placement behavior on this
# Space's GPU runtime was not verified.
pipe.enable_model_cpu_offload()
pipe.enable_vae_slicing()


def safety_check(user_prompt):
    """Send *user_prompt* to the remote safety-checker Space and return its reply.

    Uses the shared module-level ``client`` (stable base URL) rather than
    building a new Client pinned to a single replica URL on every call.

    Returns:
        Whatever the Space's ``/infer`` endpoint returns (expected to be text).
    """
    response = client.predict(
        user_prompt,  # str in 'User sent this' Textbox component
        api_name="/infer",
    )
    return response


@spaces.GPU(enable_queue=True)
def infer(prompt):
    """Generate one image for *prompt* with the DPO-SDXL pipeline.

    The prompt is first sent to the remote safety checker. Only the prompt is
    screened; the generated image is not post-checked here.

    Returns:
        The first image produced by ``pipe`` for the prompt.

    Raises:
        gr.Error: if the checker's reply is empty, or if its first word is
            "True" (treated by this code as "block this prompt").
    """
    print(f"—\n{prompt}")

    verdict = str(safety_check(prompt))
    words = verdict.split()
    # An empty/blank reply would otherwise crash with IndexError; fail closed.
    if not words:
        raise gr.Error("Safety check failed, please try again.")
    # NOTE(review): assumes the bot answers with "True ..." for prompts it
    # flags — confirm against the safety-checker Space.
    if words[0] == "True":
        raise gr.Error("Don't.")

    results = pipe(prompt, guidance_scale=7.5)
    return results.images[0]


css = """
#col-container{
    margin: 0 auto;
    max-width: 580px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML("""
Direct Preference Optimization (DPO) for text-to-image diffusion models is a method to align diffusion models to text human preferences by directly optimizing on human comparison data.
""")
        with gr.Group():
            with gr.Column():
                prompt_in = gr.Textbox(label="Prompt", value="An old man with a bird on his head")
                submit_btn = gr.Button("Submit")
        result = gr.Image(label="DPO SDXL Result")

        gr.Examples(
            examples=[
                "Dragon, digital art, by Greg Rutkowski",
                "Armored knight holding sword",
                "A flat roof villa near a river with black walls and huge windows",
                "A calm and peaceful office",
                "Pirate guinea pig",
            ],
            fn=infer,
            inputs=[prompt_in],
            outputs=[result],
        )

    submit_btn.click(
        fn=infer,
        inputs=[prompt_in],
        outputs=[result],
    )

demo.queue().launch(show_api=False)