import gradio as gr
import spaces
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel
import torch

# Base SDXL checkpoint; supplies every pipeline component except the UNet.
model_id = "stabilityai/stable-diffusion-xl-base-1.0"

# Checkpoint whose "unet" subfolder replaces the base UNet.
unet_id = "mhdang/dpo-sdxl-text2image-v1"
unet = UNet2DConditionModel.from_pretrained(
    unet_id,
    subfolder="unet",
    torch_dtype=torch.float16,
)

# Pass the replacement UNet straight to from_pretrained so the base UNet is
# never loaded (or copied to the GPU) only to be thrown away afterwards.
pipe = StableDiffusionXLPipeline.from_pretrained(
    model_id,
    unet=unet,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)
pipe = pipe.to("cuda")

# NOTE(review): model CPU offload manages device placement on its own, so it
# may make the explicit .to("cuda") above redundant -- confirm on the target
# runtime before removing either call.
pipe.enable_model_cpu_offload()
pipe.enable_vae_slicing()


@spaces.GPU
def infer(prompt: str):
    """Generate a single image for ``prompt`` with the module-level pipeline.

    Args:
        prompt: Text prompt forwarded to ``pipe``.

    Returns:
        The first entry of the pipeline output's ``images`` list.
    """
    return pipe(prompt, guidance_scale=7.5).images[0]


# Stylesheet handed to gr.Blocks: centres the element with id "col-container"
# and caps its width.
css = """
#col-container{
    margin: 0 auto;
    max-width: 580px;
}
"""

with gr.Blocks(css=css) as demo:
    # NOTE(review): the nesting below was rebuilt from a copy whose
    # indentation had been lost -- confirm the layout against the running app.
    with gr.Column(elem_id="col-container"):
        # Static header: title plus a one-paragraph description.
        gr.HTML("""
        <h2 style="text-align: center;">
            SDXL Using Direct Preference Optimization
        </h2>
        <p style="text-align: center;">
            Direct Preference Optimization (DPO) for text-to-image diffusion models is a method to align diffusion models to text human preferences by directly optimizing on human comparison data.
        </p>
        """)

        # Prompt entry and submit button.
        with gr.Group():
            with gr.Column():
                prompt_in = gr.Textbox(
                    label="Prompt",
                    value="An old man with a bird on his head",
                )
                submit_btn = gr.Button("Submit")

        result = gr.Image(label="DPO SDXL Result")

        # Example prompts for the prompt box, bound to the same
        # function and output component as the submit button.
        gr.Examples(
            examples=[
                "Dragon, digital art, by Greg Rutkowski",
                "Armored knight holding sword",
                "A flat roof villa near a river with black walls and huge windows",
                "A calm and peaceful office",
                "Pirate guinea pig",
            ],
            fn=infer,
            inputs=[prompt_in],
            outputs=[result],
        )

    # Registered while the Blocks context is still open.
    submit_btn.click(
        fn=infer,
        inputs=[prompt_in],
        outputs=[result],
    )

# Enable the request queue, then start the app.
demo.queue().launch()