# Source: Hugging Face Space "AutismMix", file app.py (author: gokaygokay)
# Commit 1d05d2f ("Update app.py", verified), 3.16 kB
import spaces
import gradio as gr
import torch
import random
from diffusers import DiffusionPipeline
import os
# Initialize models.
# Choose the accelerator and precision once and use them for both model
# placement and RNG, instead of hard-coding "cuda"/bfloat16 in several places.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16
# Access token read from the HUGGINGFACE_TOKEN environment variable (None if unset).
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

# Load the base FLUX pipeline in `dtype` and move it to `device`.
base_model = "black-forest-labs/FLUX.1-dev"
pipe = DiffusionPipeline.from_pretrained(
    base_model, torch_dtype=dtype, token=huggingface_token
).to(device)

# Attach the LoRA adapter but deliberately do NOT call pipe.fuse_lora():
# fusing merges the adapter into the base weights, after which the per-call
# joint_attention_kwargs={"scale": ...} used by generate_image is ignored and
# the "LoRA Scale" slider in the UI would have no effect.
pipe.load_lora_weights("gokaygokay/Flux-Detailer-LoRA")

# Inclusive upper bound for seeds (used by the random draw and the Seed slider).
MAX_SEED = 2**32 - 1
@spaces.GPU(duration=75)
def generate_image(prompt, steps, seed, cfg_scale, width, height, lora_scale):
    """Run the FLUX pipeline once and return the first generated image.

    Args:
        prompt: Text prompt for the pipeline.
        steps: Number of inference steps (``num_inference_steps``).
        seed: Seed for the torch RNG, making the output reproducible.
        cfg_scale: Value forwarded as ``guidance_scale``.
        width: Output width in pixels.
        height: Output height in pixels.
        lora_scale: LoRA strength forwarded via ``joint_attention_kwargs``.

    Returns:
        ``images[0]`` from the pipeline output.
    """
    # UI sliders may hand back floats (e.g. 42.0); manual_seed and the
    # step/size arguments expect plain ints, so normalize them here.
    generator = torch.Generator(device=device).manual_seed(int(seed))
    output = pipe(
        prompt=prompt,
        num_inference_steps=int(steps),
        guidance_scale=cfg_scale,
        width=int(width),
        height=int(height),
        generator=generator,
        joint_attention_kwargs={"scale": lora_scale},
    )
    return output.images[0]
def run_lora(prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale):
    """Gradio callback: optionally pick a fresh random seed, then generate.

    Parameters arrive in the same order as the UI ``inputs`` list. Returns a
    ``(image, seed)`` pair so the seed that was actually used can be shown
    back to the user.
    """
    chosen_seed = random.randint(0, MAX_SEED) if randomize_seed else seed
    picture = generate_image(prompt, steps, chosen_seed, cfg_scale, width, height, lora_scale)
    return picture, chosen_seed
# Extra stylesheet passed to gr.Blocks(css=...) when the UI is built.
# NOTE(review): none of these classes (.input-group, .output-group, .submit-btn)
# are attached to a component via elem_classes in this file, so these rules
# appear to be unused as written -- confirm whether that is intended.
custom_css = """
.input-group, .output-group {
border: 1px solid #e0e0e0;
border-radius: 10px;
padding: 20px;
margin-bottom: 20px;
background-color: #f9f9f9;
}
.submit-btn {
background-color: #2980b9 !important;
color: white !important;
}
.submit-btn:hover {
background-color: #3498db !important;
}
"""
# HTML heading rendered at the top of the page via gr.HTML(title).
# NOTE(review): the heading says "Creativity", but the LoRA loaded is
# gokaygokay/Flux-Detailer-LoRA -- confirm the intended title.
title = """<h1 align="center">FLUX Creativity LoRA</h1>
"""
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray"),
    css=custom_css,
) as app:
    gr.HTML(title)

    # Main area: prompt box, trigger button and the resulting image,
    # each on its own row.
    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            lines=3,
            placeholder="Type your prompt here",
        )
    with gr.Row():
        generate_button = gr.Button("Generate", variant="primary")
    with gr.Row():
        result = gr.Image(label="Generated Image")

    # Generation parameters, collapsed by default.
    with gr.Accordion("Advanced Settings", open=False):
        with gr.Row():
            cfg_scale = gr.Slider(
                label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5
            )
            steps = gr.Slider(
                label="Steps", minimum=1, maximum=50, step=1, value=28
            )
        with gr.Row():
            width = gr.Slider(
                label="Width", minimum=256, maximum=1536, step=64, value=1024
            )
            height = gr.Slider(
                label="Height", minimum=256, maximum=1536, step=64, value=1024
            )
        with gr.Row():
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            seed = gr.Slider(
                label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0
            )
            lora_scale = gr.Slider(
                label="LoRA Scale", minimum=0, maximum=1, step=0.01, value=0.95
            )

    # Component order must match run_lora's positional parameters.
    inputs = [prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale]
    # run_lora returns (image, seed); the seed is written back to the Seed slider.
    outputs = [result, seed]
    # Clicking the button and submitting the prompt box trigger the same handler.
    for trigger in (generate_button.click, prompt.submit):
        trigger(fn=run_lora, inputs=inputs, outputs=outputs)

app.launch(debug=True)