# 24labsimages / app.py
# erikbeltran's picture
# Update app.py
# 3992596 verified
import os
import gradio as gr
import torch
from diffusers import DiffusionPipeline, AutoencoderTiny
import random
import spaces
import hashlib
from PIL import Image
# Model setup: build the FLUX.1-dev pipeline once at import time, passing the
# TAEF1 AutoencoderTiny checkpoint in as its VAE.
dtype = torch.bfloat16
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
base_model = "black-forest-labs/FLUX.1-dev"

taef1 = AutoencoderTiny.from_pretrained(
    "madebyollin/taef1",
    torch_dtype=dtype,
)
taef1 = taef1.to(device)

pipe = DiffusionPipeline.from_pretrained(
    base_model,
    vae=taef1,
    torch_dtype=dtype,
)
pipe = pipe.to(device)

# Upper bound (inclusive) for randomly drawn seeds: 2**32 - 1.
MAX_SEED = 0xFFFFFFFF
def _validate_custom_hash(custom_hash):
    """Return a file-name-safe custom hash, or "" when none was supplied.

    The value ends up in the saved image's file name, so anything other than
    letters, digits, '-' and '_' is rejected to keep the file inside the
    working directory (no path separators, no '..').

    Raises:
        gr.Error: if the supplied value contains a disallowed character.
    """
    cleaned = (custom_hash or "").strip()
    if not cleaned:
        return ""
    if not all(ch.isalnum() or ch in "-_" for ch in cleaned):
        raise gr.Error(
            "Custom hash may only contain letters, digits, '-' and '_'."
        )
    return cleaned


@spaces.GPU()
def generate_image(prompt, width, height, lora_path, trigger_word, steps, custom_hash):
    """Generate one image with the given LoRA applied and save it as a PNG.

    Args:
        prompt: Text prompt describing the image.
        width: Output width in pixels.
        height: Output height in pixels.
        lora_path: Identifier passed to ``pipe.load_lora_weights``.
        trigger_word: Word prepended to the prompt; may be empty.
        steps: Number of inference steps.
        custom_hash: Optional name for the output file (without extension).
            When blank, the SHA-256 hex digest of the image bytes is used.

    Returns:
        A ``(image, image_hash)`` tuple; the image is also written to
        ``"<image_hash>.png"`` relative to the current working directory.

    Raises:
        gr.Error: if ``custom_hash`` contains characters that are not safe
            in a file name.
    """
    # Reject a bad file name before spending time on generation.
    safe_hash = _validate_custom_hash(custom_hash)

    # Combine prompt with trigger word, skipping empty parts so a blank
    # trigger word does not leave a leading space (or a literal "None").
    full_prompt = " ".join(part for part in (trigger_word, prompt) if part)

    # Use the module-level device rather than a hard-coded "cuda" so the
    # generator matches wherever the pipeline was placed.
    seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)

    # The pipeline is shared across calls: load the LoRA for this request and
    # always unload it afterwards so adapters do not pile up between requests.
    pipe.load_lora_weights(lora_path)
    try:
        image = pipe(
            prompt=full_prompt,
            # Coerce numeric inputs to int in case they arrive as floats
            # (e.g. from an API client).
            num_inference_steps=int(steps),
            guidance_scale=3.5,
            width=int(width),
            height=int(height),
            generator=generator,
        ).images[0]
    finally:
        pipe.unload_lora_weights()

    # Use the provided hash, or derive one from the raw image bytes.
    image_hash = safe_hash or hashlib.sha256(image.tobytes()).hexdigest()

    # Save the image with the hash as filename.
    image.save(f"{image_hash}.png")

    return image, image_hash
def run_lora(prompt, width, height, lora_path, trigger_word, steps, custom_hash):
    """Forward the form values, in order, to ``generate_image``.

    Returns whatever ``generate_image`` returns.
    """
    form_values = (prompt, width, height, lora_path, trigger_word, steps, custom_hash)
    return generate_image(*form_values)
# Gradio interface: one form whose button calls run_lora and shows the
# resulting image together with its hash.
with gr.Blocks() as app:
    gr.Markdown("# LoRA Image Generator")

    # Prompt text.
    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            lines=3,
            placeholder="Enter your prompt here",
        )

    # Output dimensions; both sliders share the same range and default.
    size_slider_options = dict(minimum=256, maximum=1024, step=64, value=512)
    with gr.Row():
        width = gr.Slider(label="Width", **size_slider_options)
        height = gr.Slider(label="Height", **size_slider_options)

    # LoRA selection and the trigger word that goes with it.
    with gr.Row():
        lora_path = gr.Textbox(
            label="LoRA Path",
            value="SebastianBodza/Flux_Aquarell_Watercolor_v2",
        )
        trigger_word = gr.Textbox(
            label="Trigger Word",
            value="AQUACOLTOK",
        )

    with gr.Row():
        steps = gr.Slider(
            label="Inference Steps",
            minimum=1,
            maximum=100,
            step=1,
            value=28,
        )

    with gr.Row():
        custom_hash = gr.Textbox(
            label="Custom Hash (optional)",
            placeholder="Leave blank to auto-generate hash",
        )

    generate_button = gr.Button("Generate Image")

    # Result widgets; the hash box is read-only.
    output_image = gr.Image(label="Generated Image")
    output_hash = gr.Textbox(label="Image Hash", interactive=False)

    # Component order here must match run_lora's parameter order.
    form_inputs = [prompt, width, height, lora_path, trigger_word, steps, custom_hash]
    form_outputs = [output_image, output_hash]
    generate_button.click(fn=run_lora, inputs=form_inputs, outputs=form_outputs)
if __name__ == "__main__":
    # Turn on the request queue, then launch with share=False.
    queued_app = app.queue()
    queued_app.launch(share=False)