Spaces: Running on Zero
import spaces | |
import gradio as gr | |
import torch | |
import random | |
from diffusers import DiffusionPipeline | |
import os | |
# Initialize models.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

# Initialize the base model and move it to the GPU.
# NOTE(review): "cuda" stays hard-coded here rather than using `device`; this
# Space runs on ZeroGPU, so confirm how torch.cuda.is_available() behaves at
# startup there before switching to `device`.
base_model = "black-forest-labs/FLUX.1-dev"
pipe = DiffusionPipeline.from_pretrained(
    base_model,
    torch_dtype=dtype,
    token=huggingface_token,
).to("cuda")

# Load LoRA weights.
# The adapter is intentionally NOT fused (no `pipe.fuse_lora()`). Fusing merges
# the LoRA delta into the base weights at scale 1.0, and merged PEFT layers then
# ignore the per-call `joint_attention_kwargs={"scale": ...}` that
# generate_image passes. That left the "LoRA Scale" slider without any effect.
pipe.load_lora_weights("gokaygokay/Flux-Detailer-LoRA")

# Upper bound for randomly drawn seeds (largest unsigned 32-bit value).
MAX_SEED = 2**32 - 1
@spaces.GPU
def generate_image(prompt, steps, seed, cfg_scale, width, height, lora_scale):
    """Render one image with the module-level FLUX + LoRA pipeline.

    Decorated with ``@spaces.GPU`` because on a ZeroGPU Space a GPU is only
    attached while such a decorated function runs. Without it, the
    ``import spaces`` at the top of the file was unused.

    Args:
        prompt: Text prompt for the pipeline.
        steps: Number of inference steps (passed as ``num_inference_steps``).
        seed: RNG seed for the CUDA generator. It is cast to ``int`` in case the
            UI delivers a float, which ``Generator.manual_seed`` rejects.
        cfg_scale: Passed to the pipeline as ``guidance_scale``.
        width: Output width in pixels.
        height: Output height in pixels.
        lora_scale: LoRA strength, forwarded via ``joint_attention_kwargs``.

    Returns:
        The first entry of the pipeline output's ``images`` list.
    """
    generator = torch.Generator(device="cuda").manual_seed(int(seed))
    image = pipe(
        prompt=prompt,
        num_inference_steps=int(steps),
        guidance_scale=cfg_scale,
        width=int(width),
        height=int(height),
        generator=generator,
        joint_attention_kwargs={"scale": lora_scale},
    ).images[0]
    return image
def run_lora(prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale):
    """Gradio click/submit handler.

    When ``randomize_seed`` is truthy, the incoming ``seed`` is replaced by a
    fresh value drawn from ``[0, MAX_SEED]``. One image is then rendered.

    Returns:
        ``(image, seed_used)``, so the UI can show the seed that produced the image.
    """
    chosen_seed = random.randint(0, MAX_SEED) if randomize_seed else seed
    picture = generate_image(prompt, steps, chosen_seed, cfg_scale, width, height, lora_scale)
    return picture, chosen_seed
# Extra CSS handed to gr.Blocks(css=...).
# NOTE(review): no component in this file sets `elem_classes` to
# .input-group, .output-group or .submit-btn, so these rules currently have no
# visible effect. Confirm whether they were meant to be attached.
custom_css = """
.input-group, .output-group {
    border: 1px solid #e0e0e0;
    border-radius: 10px;
    padding: 20px;
    margin-bottom: 20px;
    background-color: #f9f9f9;
}
.submit-btn {
    background-color: #2980b9 !important;
    color: white !important;
}
.submit-btn:hover {
    background-color: #3498db !important;
}
"""

# Page heading. It names the LoRA this Space actually loads
# ("gokaygokay/Flux-Detailer-LoRA") rather than the Creativity LoRA.
title = """<h1 align="center">FLUX Detailer LoRA</h1>
"""
# Assemble the UI: heading, prompt box, Generate button, result image, and a
# collapsed "Advanced Settings" accordion holding the sampling controls.
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray"),
    css=custom_css,
) as app:
    gr.HTML(title)

    with gr.Row():
        prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Type your prompt here")
    with gr.Row():
        generate_button = gr.Button("Generate", variant="primary")
    with gr.Row():
        result = gr.Image(label="Generated Image")

    with gr.Accordion("Advanced Settings", open=False):
        with gr.Row():
            cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
            steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
        with gr.Row():
            width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
            height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
        with gr.Row():
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, step=0.01, value=0.95)

    # Positional order must match run_lora's parameter list.
    inputs = [prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale]
    # run_lora returns (image, seed); writing the seed back shows the value used.
    outputs = [result, seed]
    # Both the button click and pressing submit in the prompt box run the same handler.
    for trigger in (generate_button.click, prompt.submit):
        trigger(fn=run_lora, inputs=inputs, outputs=outputs)

app.launch(debug=True)