# sdxl-dpo / app.py
import gradio as gr
import spaces
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel
import torch
# Load the SDXL base pipeline in half precision
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipe = StableDiffusionXLPipeline.from_pretrained(
    model_id, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)

# Swap in the DPO-finetuned UNet
unet_id = "mhdang/dpo-sdxl-text2image-v1"
unet = UNet2DConditionModel.from_pretrained(unet_id, subfolder="unet", torch_dtype=torch.float16)
pipe.unet = unet

# Model CPU offload manages device placement itself (submodules are moved to
# the GPU only while they run), so no explicit .to("cuda") is needed here;
# VAE slicing decodes the latents in slices to reduce peak memory.
pipe.enable_model_cpu_offload()
pipe.enable_vae_slicing()
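
# Hedged sketch (not part of the original Space): for reproducible samples, a
# seeded torch.Generator can be passed to the pipeline call, e.g.
#   generator = torch.Generator("cuda").manual_seed(0)
#   image = pipe(prompt, guidance_scale=7.5, generator=generator).images[0]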
@spaces.GPU
def infer(prompt):
    # Single text-to-image generation with the DPO-finetuned SDXL pipeline
    image = pipe(prompt, guidance_scale=7.5).images[0]
    return image
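
# Usage sketch (assumption: running outside the ZeroGPU Space with a local
# CUDA GPU; on Spaces, @spaces.GPU allocates the device per call instead):
#   infer("Pirate guinea pig").save("dpo_sdxl.png")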
css = """
#col-container {
    margin: 0 auto;
    max-width: 580px;
}
"""
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML("""
            <h2 style="text-align: center;">
                SDXL Using Direct Preference Optimization
            </h2>
            <p style="text-align: center;">
                Direct Preference Optimization (DPO) for text-to-image diffusion models is a method to align diffusion models to human preferences by directly optimizing on human comparison data.
            </p>
        """)
        with gr.Group():
            with gr.Column():
                prompt_in = gr.Textbox(label="Prompt", value="An old man with a bird on his head")
                submit_btn = gr.Button("Submit")
        result = gr.Image(label="DPO SDXL Result")

        gr.Examples(
            examples=[
                "Dragon, digital art, by Greg Rutkowski",
                "Armored knight holding sword",
                "A flat roof villa near a river with black walls and huge windows",
                "A calm and peaceful office",
                "Pirate guinea pig",
            ],
            fn=infer,
            inputs=[prompt_in],
            outputs=[result],
        )
    submit_btn.click(
        fn=infer,
        inputs=[prompt_in],
        outputs=[result],
    )

demo.queue().launch()
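
# Note: demo.queue().launch() starts the Gradio server (locally this defaults
# to http://127.0.0.1:7860); on Hugging Face Spaces, app.py is the Space
# entrypoint and the server settings come from the Spaces runtime.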