import os
import json
import logging
import copy
import random
import time

import gradio as gr
import torch
from PIL import Image
import spaces
from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL, FluxImg2ImgPipeline
from diffusers.utils import load_image
from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download

from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images

# Load LoRAs from JSON file
with open('loras.json', 'r') as f:
    loras = json.load(f)
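# A minimal, hypothetical `loras.json` entry, inferred from the keys this script
# reads below ("weights", "aspect", and "trigger_position" are optional):
#
# [
#   {
#     "image": "https://example.com/preview.png",
#     "title": "My Style",
#     "repo": "user/my-flux-lora",
#     "trigger_word": "mystyle",
#     "trigger_position": "prepend",
#     "weights": "my_style.safetensors",
#     "aspect": "portrait"
#   }
# ]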
# Initialize the base model
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
base_model = "black-forest-labs/FLUX.1-dev"

# taef1 is a tiny, fast VAE used to decode live previews; good_vae is the
# full-size VAE used to decode the final image.
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)

# The img2img pipeline reuses the transformer and text encoders already loaded
# in `pipe`, so the model weights are only held in memory once.
pipe_i2i = FluxImg2ImgPipeline.from_pretrained(
    base_model,
    vae=good_vae,
    transformer=pipe.transformer,
    text_encoder=pipe.text_encoder,
    tokenizer=pipe.tokenizer,
    text_encoder_2=pipe.text_encoder_2,
    tokenizer_2=pipe.tokenizer_2,
    torch_dtype=dtype,
)

MAX_SEED = 2**32 - 1

# Bind the live-preview helper to the pipeline instance as a method.
pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)


class calculateDuration:
    """Context manager that prints the wall-clock time of the wrapped block."""

    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.time()
        self.elapsed_time = self.end_time - self.start_time
        if self.activity_name:
            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
        else:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")


def update_selection(evt: gr.SelectData, width, height):
    selected_lora = loras[evt.index]
    new_placeholder = f"Type a prompt for {selected_lora['title']}"
    lora_repo = selected_lora["repo"]
    updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨"
    if "aspect" in selected_lora:
        if selected_lora["aspect"] == "portrait":
            width = 768
            height = 1024
        elif selected_lora["aspect"] == "landscape":
            width = 1024
            height = 768
        else:
            width = 1024
            height = 1024
    return (
        gr.update(placeholder=new_placeholder),
        updated_text,
        evt.index,
        width,
        height,
    )


@spaces.GPU(duration=70)
def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress):
    pipe.to("cuda")
    generator = torch.Generator(device="cuda").manual_seed(seed)
    with calculateDuration("Generating image"):
        # Stream intermediate previews; the last image is decoded with the full VAE.
        for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
            prompt=prompt_mash,
            num_inference_steps=steps,
            guidance_scale=cfg_scale,
            width=width,
            height=height,
            generator=generator,
            joint_attention_kwargs={"scale": lora_scale},
            output_type="pil",
            good_vae=good_vae,
        ):
            yield img


@spaces.GPU(duration=70)
def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps, cfg_scale, width, height, lora_scale, seed):
    generator = torch.Generator(device="cuda").manual_seed(seed)
    pipe_i2i.to("cuda")
    image_input = load_image(image_input_path)
    final_image = pipe_i2i(
        prompt=prompt_mash,
        image=image_input,
        strength=image_strength,
        num_inference_steps=steps,
        guidance_scale=cfg_scale,
        width=width,
        height=height,
        generator=generator,
        joint_attention_kwargs={"scale": lora_scale},
        output_type="pil",
    ).images[0]
    return final_image


def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_index,
             randomize_seed, seed, width, height, lora_scale,
             progress=gr.Progress(track_tqdm=True)):
    if selected_index is None:
        raise gr.Error("You must select a LoRA before proceeding.")
    selected_lora = loras[selected_index]
    lora_path = selected_lora["repo"]
    trigger_word = selected_lora["trigger_word"]

    # Combine the trigger word with the user prompt (prepended unless configured otherwise)
    if trigger_word:
        if "trigger_position" in selected_lora:
            if selected_lora["trigger_position"] == "prepend":
                prompt_mash = f"{trigger_word} {prompt}"
            else:
                prompt_mash = f"{prompt} {trigger_word}"
        else:
            prompt_mash = f"{trigger_word} {prompt}"
    else:
        prompt_mash = prompt

    with calculateDuration("Unloading LoRA"):
        pipe.unload_lora_weights()
        pipe_i2i.unload_lora_weights()

    # Load LoRA weights into whichever pipeline will actually run
    with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
        if image_input is not None:
            if "weights" in selected_lora:
                pipe_i2i.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
            else:
                pipe_i2i.load_lora_weights(lora_path)
        else:
            if "weights" in selected_lora:
                pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
            else:
                pipe.load_lora_weights(lora_path)

    # Draw a fresh seed when requested; otherwise keep the given one for reproducibility
    with calculateDuration("Randomizing seed"):
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)

    if image_input is not None:
        final_image = generate_image_to_image(prompt_mash, image_input, image_strength,
                                              steps, cfg_scale, width, height, lora_scale, seed)
        yield final_image, seed, gr.update(visible=False)
    else:
        image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress)
        # Consume the generator, streaming each preview to the UI
        final_image = None
        step_counter = 0
        for image in image_generator:
            step_counter += 1
            final_image = image
            progress_text = f"Generating image... Step {step_counter}/{steps}"
            yield image, seed, gr.update(visible=True, value=progress_text)
        yield final_image, seed, gr.update(visible=False)
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Awaken Ones' Lora Previews")
    gr.Markdown("Select a LoRA model from the gallery below to get started!")

    # Shared state holding the index of the LoRA picked in the gallery, so that
    # update_selection and run_lora see the same value.
    selected_index = gr.State(None)

    with gr.Row():
        gallery = gr.Gallery(
            value=[lora["image"] for lora in loras],
            label="LoRA Gallery",
            show_label=False,
            elem_id="gallery",
            columns=[5],
            rows=[3],
            object_fit="contain",
            height="auto",
        )
    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            placeholder="Type your prompt here...",
            show_label=True,
        )
        image_input = gr.Image(type="filepath", label="Image Input (Optional)")
    with gr.Row():
        generate = gr.Button("Generate", variant="primary")
        cancel = gr.Button("Cancel")
    with gr.Row():
        with gr.Column(scale=4):
            result = gr.Image(label="Result", show_label=False, elem_id="result")
        with gr.Column(scale=1):
            seed_output = gr.Number(label="Seed", interactive=False)
    with gr.Row():
        with gr.Column():
            steps = gr.Slider(minimum=1, maximum=100, value=28, step=1, label="Steps")
            cfg_scale = gr.Slider(minimum=1, maximum=20, value=3.5, step=0.1, label="CFG Scale")
            lora_scale = gr.Slider(minimum=0, maximum=1, value=0.8, step=0.05, label="LoRA Scale")
        with gr.Column():
            width = gr.Slider(minimum=256, maximum=1024, value=512, step=64, label="Width")
            height = gr.Slider(minimum=256, maximum=1024, value=512, step=64, label="Height")
            image_strength = gr.Slider(minimum=0, maximum=1, value=0.8, step=0.05, label="Image Strength")
    with gr.Row():
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        seed_input = gr.Number(label="Seed", value=0, interactive=True, visible=False)
    selected_lora = gr.Markdown("### No LoRA selected")
    progress_bar = gr.Markdown(visible=False)

    # Event handlers
    gallery.select(
        update_selection,
        inputs=[width, height],
        outputs=[prompt, selected_lora, selected_index, width, height],
    )
    randomize_seed.change(lambda x: gr.update(visible=not x), randomize_seed, seed_input)
    generate_event = generate.click(
        run_lora,
        inputs=[prompt, image_input, image_strength, cfg_scale, steps, selected_index,
                randomize_seed, seed_input, width, height, lora_scale],
        outputs=[result, seed_output, progress_bar],
    )
    cancel.click(lambda: None, None, None, cancels=[generate_event])

demo.queue().launch()
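# Assumptions for running this script outside the original Space: `loras.json`
# and `live_preview_helpers.py` must sit next to this file, and the
# `spaces.GPU` decorator only takes effect on Hugging Face ZeroGPU hardware
# (elsewhere it is designed to behave as a no-op).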