import spaces import gradio as gr import json import torch from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL, AutoPipelineForImage2Image, AutoPipelineForInpainting from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images from diffusers.utils import load_image from diffusers import FluxControlNetPipeline, FluxControlNetModel, FluxMultiControlNetModel, FluxControlNetImg2ImgPipeline, FluxTransformer2DModel, FluxControlNetInpaintPipeline, FluxInpaintPipeline from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download, HfApi import os import copy import random import time import requests import pandas as pd import numpy as np from pathlib import Path from env import models, num_loras, num_cns, HF_TOKEN, single_file_base_models from mod import (clear_cache, get_repo_safetensors, is_repo_name, is_repo_exists, get_model_trigger, description_ui, compose_lora_json, is_valid_lora, fuse_loras, save_image, preprocess_i2i_image, get_trigger_word, enhance_prompt, set_control_union_image, get_control_union_mode, set_control_union_mode, get_control_params, translate_to_en) from modutils import (search_civitai_lora, select_civitai_lora, search_civitai_lora_json, download_my_lora_flux, get_all_lora_tupled_list, apply_lora_prompt_flux, update_loras_flux, update_civitai_selection, get_civitai_tag, CIVITAI_SORT, CIVITAI_PERIOD, get_t2i_model_info, download_hf_file, save_image_history) from tagger.tagger import predict_tags_wd, compose_prompt_to_copy from tagger.fl2flux import predict_tags_fl2_flux #Load prompts for randomization df = pd.read_csv('prompts.csv', header=None) prompt_values = df.values.flatten() # Load LoRAs from JSON file with open('loras.json', 'r') as f: loras = json.load(f) # Initialize the base model base_model = models[0] controlnet_model_union_repo = 'InstantX/FLUX.1-dev-Controlnet-Union' #controlnet_model_union_repo = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro' dtype = torch.bfloat16 #dtype = torch.float8_e4m3fn CACHE_MODEL = False device = "cuda" if torch.cuda.is_available() else "cpu" taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype, token=HF_TOKEN) good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype, token=HF_TOKEN) pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1, token=HF_TOKEN) if CACHE_MODEL: taef1.to(device) good_vae.to(device) pipe.to(device) pipe.transformer.to("cpu") pipe_i2i = AutoPipelineForImage2Image.from_pretrained(base_model, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN) pipe_ip = AutoPipelineForInpainting.from_pretrained(base_model, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN) controlnet_union = None controlnet = None last_model = models[0] last_cn_on = False #controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union_repo, torch_dtype=dtype) #controlnet = FluxMultiControlNetModel([controlnet_union]) #controlnet.config = controlnet_union.config MAX_SEED = 2**32-1 def unload_lora(): global pipe, pipe_i2i, pipe_ip try: #pipe.unfuse_lora() pipe.unload_lora_weights() #pipe_i2i.unfuse_lora() pipe_i2i.unload_lora_weights() #pipe_ip.unfuse_lora() pipe_ip.unload_lora_weights() except Exception as e: print(e) def download_file_mod(url, directory=os.getcwd()): path = download_hf_file(directory, url, hf_token=HF_TOKEN) if not path: raise Exception(f"Download error: {url}") return path # https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union # https://huggingface.co/spaces/jiuface/FLUX.1-dev-Controlnet-Union # https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux #@spaces.GPU() def change_base_model(repo_id: str, cn_on: bool, disable_model_cache: bool, model_type: str, progress=gr.Progress(track_tqdm=True)): global pipe, pipe_i2i, pipe_ip, taef1, good_vae, controlnet_union, controlnet, last_model, last_cn_on, dtype safetensors_file = None single_file_base_model = single_file_base_models.get(model_type, models[0]) try: #if not disable_model_cache and (repo_id == last_model and cn_on is last_cn_on) or not is_repo_name(repo_id) or not is_repo_exists(repo_id): return gr.update(visible=True) if not disable_model_cache and (repo_id == last_model and cn_on is last_cn_on) or ((not is_repo_name(repo_id) or not is_repo_exists(repo_id)) and not ".safetensors" in repo_id): return gr.update() unload_lora() pipe.to("cpu") pipe_i2i.to("cpu") pipe_ip.to("cpu") good_vae.to("cpu") taef1.to("cpu") if controlnet is not None: controlnet.to("cpu") if controlnet_union is not None: controlnet_union.to("cpu") clear_cache() if cn_on: progress(0, desc=f"Loading model: {repo_id} / Loading ControlNet: {controlnet_model_union_repo}") print(f"Loading model: {repo_id} / Loading ControlNet: {controlnet_model_union_repo}") controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union_repo, torch_dtype=dtype, token=HF_TOKEN) controlnet = FluxMultiControlNetModel([controlnet_union]) controlnet.config = controlnet_union.config if ".safetensors" in repo_id: safetensors_file = download_file_mod(repo_id) transformer = FluxTransformer2DModel.from_single_file(safetensors_file, subfolder="transformer", torch_dtype=dtype, config=single_file_base_model) if CACHE_MODEL: pipe = FluxControlNetPipeline.from_pretrained(single_file_base_model, controlnet=controlnet, vae=good_vae, transformer=transformer, text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN) else: pipe = FluxControlNetPipeline.from_pretrained(single_file_base_model, transformer=transformer, controlnet=controlnet, torch_dtype=dtype, token=HF_TOKEN) pipe_i2i = FluxControlNetImg2ImgPipeline.from_pretrained(single_file_base_model, controlnet=controlnet, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN) pipe_ip = FluxControlNetInpaintPipeline.from_pretrained(single_file_base_model, controlnet=controlnet, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN) else: if CACHE_MODEL: transformer = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", torch_dtype=dtype, token=HF_TOKEN) pipe = FluxControlNetPipeline.from_pretrained(repo_id, controlnet=controlnet, vae=good_vae, transformer=transformer, text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN) pipe = FluxControlNetPipeline.from_pretrained(repo_id, controlnet=controlnet, torch_dtype=dtype, token=HF_TOKEN) pipe_i2i = FluxControlNetImg2ImgPipeline.from_pretrained(repo_id, controlnet=controlnet, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN) pipe_ip = FluxControlNetInpaintPipeline.from_pretrained(repo_id, controlnet=controlnet, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN) last_model = repo_id last_cn_on = cn_on progress(1, desc=f"Model loaded: {repo_id} / ControlNet Loaded: {controlnet_model_union_repo}") print(f"Model loaded: {repo_id} / ControlNet Loaded: {controlnet_model_union_repo}") else: progress(0, desc=f"Loading model: {repo_id}") print(f"Loading model: {repo_id}") if ".safetensors" in repo_id: safetensors_file = download_file_mod(repo_id) transformer = FluxTransformer2DModel.from_single_file(safetensors_file, subfolder="transformer", torch_dtype=dtype, config=single_file_base_model, token=HF_TOKEN) if CACHE_MODEL: pipe = DiffusionPipeline.from_pretrained(single_file_base_model, vae=taef1, transformer=transformer, text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN) else: pipe = DiffusionPipeline.from_pretrained(single_file_base_model, transformer=transformer, torch_dtype=dtype, token=HF_TOKEN) pipe_i2i = AutoPipelineForImage2Image.from_pretrained(single_file_base_model, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN) pipe_ip = AutoPipelineForInpainting.from_pretrained(single_file_base_model, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN) else: if CACHE_MODEL: transformer = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", torch_dtype=dtype, token=HF_TOKEN) pipe = DiffusionPipeline.from_pretrained(repo_id, vae=taef1, transformer=transformer, text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN) else: pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=dtype, token=HF_TOKEN) pipe_i2i = AutoPipelineForImage2Image.from_pretrained(repo_id, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN) pipe_ip = AutoPipelineForInpainting.from_pretrained(repo_id, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN) last_model = repo_id last_cn_on = cn_on progress(1, desc=f"Model loaded: {repo_id}") print(f"Model loaded: {repo_id}") except Exception as e: print(f"Model load Error: {repo_id} {e}") raise gr.Error(f"Model load Error: {repo_id} {e}") from e finally: if safetensors_file and Path(safetensors_file).exists(): Path(safetensors_file).unlink() return gr.update() change_base_model.zerogpu = True def is_repo_public(repo_id: str): api = HfApi() try: if api.repo_exists(repo_id=repo_id, token=False): return True else: return False except Exception as e: print(f"Error: Failed to connect {repo_id}. {e}") return False def calc_sigmas(num_inference_steps: int, sigmas_factor: float): sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) sigmas = sigmas * sigmas_factor return sigmas class calculateDuration: def __init__(self, activity_name=""): self.activity_name = activity_name def __enter__(self): self.start_time = time.time() return self def __exit__(self, exc_type, exc_value, traceback): self.end_time = time.time() self.elapsed_time = self.end_time - self.start_time if self.activity_name: print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds") else: print(f"Elapsed time: {self.elapsed_time:.6f} seconds") def download_file(url, directory=None): if directory is None: directory = os.getcwd() # Use current working directory if not specified # Get the filename from the URL filename = url.split('/')[-1] # Full path for the downloaded file filepath = os.path.join(directory, filename) # Download the file response = requests.get(url) response.raise_for_status() # Raise an exception for bad status codes # Write the content to the file with open(filepath, 'wb') as file: file.write(response.content) return filepath def update_selection(evt: gr.SelectData, selected_indices, loras_state, width, height): selected_index = evt.index selected_indices = selected_indices or [] if selected_index in selected_indices: selected_indices.remove(selected_index) else: if len(selected_indices) < 2: selected_indices.append(selected_index) else: gr.Warning("You can select up to 2 LoRAs, remove one to select a new one.") return gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), width, height, gr.update(), gr.update() selected_info_1 = "Select a LoRA 1" selected_info_2 = "Select a LoRA 2" lora_scale_1 = 1.15 lora_scale_2 = 1.15 lora_image_1 = None lora_image_2 = None if len(selected_indices) >= 1: lora1 = loras_state[selected_indices[0]] selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}](https://huggingface.co/{lora1['repo']}) ✨" lora_image_1 = lora1['image'] if len(selected_indices) >= 2: lora2 = loras_state[selected_indices[1]] selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}](https://huggingface.co/{lora2['repo']}) ✨" lora_image_2 = lora2['image'] if selected_indices: last_selected_lora = loras_state[selected_indices[-1]] new_placeholder = f"Type a prompt for {last_selected_lora['title']}" else: new_placeholder = "Type a prompt" return gr.update(placeholder=new_placeholder), selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, width, height, lora_image_1, lora_image_2 def remove_lora_1(selected_indices, loras_state): if len(selected_indices) >= 1: selected_indices.pop(0) selected_info_1 = "Select LoRA 1" selected_info_2 = "Select LoRA 2" lora_scale_1 = 1.15 lora_scale_2 = 1.15 lora_image_1 = None lora_image_2 = None if len(selected_indices) >= 1: lora1 = loras_state[selected_indices[0]] selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}]({lora1['repo']}) ✨" lora_image_1 = lora1['image'] if len(selected_indices) >= 2: lora2 = loras_state[selected_indices[1]] selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}]({lora2['repo']}) ✨" lora_image_2 = lora2['image'] return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2 def remove_lora_2(selected_indices, loras_state): if len(selected_indices) >= 2: selected_indices.pop(1) selected_info_1 = "Select LoRA 1" selected_info_2 = "Select LoRA 2" lora_scale_1 = 1.15 lora_scale_2 = 1.15 lora_image_1 = None lora_image_2 = None if len(selected_indices) >= 1: lora1 = loras_state[selected_indices[0]] selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}]({lora1['repo']}) ✨" lora_image_1 = lora1['image'] if len(selected_indices) >= 2: lora2 = loras_state[selected_indices[1]] selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}]({lora2['repo']}) ✨" lora_image_2 = lora2['image'] return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2 def randomize_loras(selected_indices, loras_state): if len(loras_state) < 2: raise gr.Error("Not enough LoRAs to randomize.") selected_indices = random.sample(range(len(loras_state)), 2) lora1 = loras_state[selected_indices[0]] lora2 = loras_state[selected_indices[1]] selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}](https://huggingface.co/{lora1['repo']}) ✨" selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}](https://huggingface.co/{lora2['repo']}) ✨" lora_scale_1 = 1.15 lora_scale_2 = 1.15 lora_image_1 = lora1['image'] lora_image_2 = lora2['image'] random_prompt = random.choice(prompt_values) return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2, random_prompt def download_loras_images(loras_json_orig: list[dict]): api = HfApi(token=HF_TOKEN) loras_json = [] for lora in loras_json_orig: repo = lora.get("repo", None) if repo is None or not api.repo_exists(repo_id=repo, token=HF_TOKEN): print(f"LoRA '{repo}' is not exsit.") continue if "title" not in lora.keys() or "trigger_word" not in lora.keys() or "image" not in lora.keys(): title, _repo, _path, trigger_word, image_def = check_custom_model(repo) if "title" not in lora.keys(): lora["title"] = title if "trigger_word" not in lora.keys(): lora["trigger_word"] = trigger_word if "image" not in lora.keys(): lora["image"] = image_def image = lora.get("image", None) try: if not is_repo_public(repo) and image is not None and "http" in image and repo in image: image = download_file_mod(image) lora["image"] = image if image else "/home/user/app/custom.png" except Exception as e: print(f"Failed to download LoRA '{repo}''s image '{image if image else ''}'. {e}") lora["image"] = "/home/user/app/custom.png" loras_json.append(lora) return loras_json def add_custom_lora(custom_lora, selected_indices, current_loras, gallery): if custom_lora: try: title, repo, path, trigger_word, image = check_custom_model(custom_lora) if image is not None and "http" in image and not is_repo_public(repo) and repo in image: try: image = download_file_mod(image) except Exception as e: print(e) image = None print(f"Loaded custom LoRA: {repo}") existing_item_index = next((index for (index, item) in enumerate(current_loras) if item['repo'] == repo), None) if existing_item_index is None: if repo.endswith(".safetensors") and repo.startswith("http"): #repo = download_file(repo) repo = download_file_mod(repo) new_item = { "image": image if image else "/home/user/app/custom.png", "title": title, "repo": repo, "weights": path, "trigger_word": trigger_word } print(f"New LoRA: {new_item}") existing_item_index = len(current_loras) current_loras.append(new_item) # Update gallery gallery_items = [(item["image"], item["title"]) for item in current_loras] # Update selected_indices if there's room if len(selected_indices) < 2: selected_indices.append(existing_item_index) else: gr.Warning("You can select up to 2 LoRAs, remove one to select a new one.") # Update selected_info and images selected_info_1 = "Select a LoRA 1" selected_info_2 = "Select a LoRA 2" lora_scale_1 = 1.15 lora_scale_2 = 1.15 lora_image_1 = None lora_image_2 = None if len(selected_indices) >= 1: lora1 = current_loras[selected_indices[0]] selected_info_1 = f"### LoRA 1 Selected: {lora1['title']} ✨" lora_image_1 = lora1['image'] if lora1['image'] else None if len(selected_indices) >= 2: lora2 = current_loras[selected_indices[1]] selected_info_2 = f"### LoRA 2 Selected: {lora2['title']} ✨" lora_image_2 = lora2['image'] if lora2['image'] else None print("Finished adding custom LoRA") return ( current_loras, gr.update(value=gallery_items), selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2 ) except Exception as e: print(e) gr.Warning(str(e)) return current_loras, gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), gr.update(), gr.update() else: return current_loras, gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), gr.update(), gr.update() def remove_custom_lora(selected_indices, current_loras, gallery): if current_loras: custom_lora_repo = current_loras[-1]['repo'] # Remove from loras list current_loras = current_loras[:-1] # Remove from selected_indices if selected custom_lora_index = len(current_loras) if custom_lora_index in selected_indices: selected_indices.remove(custom_lora_index) # Update gallery gallery_items = [(item["image"], item["title"]) for item in current_loras] # Update selected_info and images selected_info_1 = "Select a LoRA 1" selected_info_2 = "Select a LoRA 2" lora_scale_1 = 1.15 lora_scale_2 = 1.15 lora_image_1 = None lora_image_2 = None if len(selected_indices) >= 1: lora1 = current_loras[selected_indices[0]] selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}]({lora1['repo']}) ✨" lora_image_1 = lora1['image'] if len(selected_indices) >= 2: lora2 = current_loras[selected_indices[1]] selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}]({lora2['repo']}) ✨" lora_image_2 = lora2['image'] return ( current_loras, gr.update(value=gallery_items), selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2 ) @spaces.GPU(duration=70) @torch.inference_mode() def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, sigmas_factor, cn_on, progress=gr.Progress(track_tqdm=True)): global pipe, taef1, good_vae, controlnet, controlnet_union try: good_vae.to(device) taef1.to(device) generator = torch.Generator(device=device).manual_seed(int(float(seed))) sigmas = calc_sigmas(steps, sigmas_factor) with calculateDuration("Generating image"): # Generate image modes, images, scales = get_control_params() if not cn_on or len(modes) == 0: pipe.to(device) pipe.vae = taef1 pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe) progress(0, desc="Start Inference.") for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images( prompt=prompt_mash, num_inference_steps=steps, guidance_scale=cfg_scale, width=width, height=height, generator=generator, joint_attention_kwargs={"scale": 1.0}, output_type="pil", good_vae=good_vae, sigmas=sigmas, ): yield img else: pipe.to(device) pipe.vae = good_vae if controlnet_union is not None: controlnet_union.to(device) if controlnet is not None: controlnet.to(device) pipe.enable_model_cpu_offload() progress(0, desc="Start Inference with ControlNet.") for img in pipe( prompt=prompt_mash, control_image=images, control_mode=modes, num_inference_steps=steps, guidance_scale=cfg_scale, width=width, height=height, controlnet_conditioning_scale=scales, generator=generator, joint_attention_kwargs={"scale": 1.0}, sigmas=sigmas, ).images: yield img except Exception as e: print(e) raise gr.Error(f"Inference Error: {e}") from e @spaces.GPU(duration=70) @torch.inference_mode() def generate_image_to_image(prompt_mash, image_input_path_dict, image_strength, is_inpaint, blur_mask, blur_factor, steps, cfg_scale, width, height, sigmas_factor, seed, cn_on, progress=gr.Progress(track_tqdm=True)): global pipe_i2i, pipe_ip, good_vae, controlnet, controlnet_union try: good_vae.to(device) generator = torch.Generator(device=device).manual_seed(int(float(seed))) image_input_path = image_input_path_dict['background'] mask_path = image_input_path_dict['layers'][0] sigmas = calc_sigmas(steps, sigmas_factor) with calculateDuration("Generating image"): # Generate image modes, images, scales = get_control_params() if not cn_on or len(modes) == 0: if is_inpaint: # Inpainting pipe_ip.to(device) pipe_ip.vae = good_vae image_input = load_image(image_input_path) mask_input = load_image(mask_path) if blur_mask: mask_input = pipe_ip.mask_processor.blur(mask_input, blur_factor=blur_factor) progress(0, desc="Start Inpainting Inference.") final_image = pipe_ip( prompt=prompt_mash, image=image_input, mask_image=mask_input, strength=image_strength, num_inference_steps=steps, guidance_scale=cfg_scale, width=width, height=height, generator=generator, joint_attention_kwargs={"scale": 1.0}, output_type="pil", #sigmas=sigmas, ).images[0] return final_image else: pipe_i2i.to(device) pipe_i2i.vae = good_vae image_input = load_image(image_input_path) progress(0, desc="Start I2I Inference.") final_image = pipe_i2i( prompt=prompt_mash, image=image_input, strength=image_strength, num_inference_steps=steps, guidance_scale=cfg_scale, width=width, height=height, generator=generator, joint_attention_kwargs={"scale": 1.0}, output_type="pil", #sigmas=sigmas, ).images[0] return final_image else: if is_inpaint: # Inpainting pipe_ip.to(device) pipe_ip.vae = good_vae image_input = load_image(image_input_path) mask_input = load_image(mask_path) if blur_mask: mask_input = pipe_ip.mask_processor.blur(mask_input, blur_factor=blur_factor) if controlnet_union is not None: controlnet_union.to(device) if controlnet is not None: controlnet.to(device) pipe_ip.enable_model_cpu_offload() progress(0, desc="Start Inpainting Inference with ControlNet.") final_image = pipe_ip( prompt=prompt_mash, control_image=images, control_mode=modes, image=image_input, mask_image=mask_input, strength=image_strength, num_inference_steps=steps, guidance_scale=cfg_scale, width=width, height=height, controlnet_conditioning_scale=scales, generator=generator, joint_attention_kwargs={"scale": 1.0}, output_type="pil", #sigmas=sigmas, ).images[0] return final_image else: pipe_i2i.to(device) pipe_i2i.vae = good_vae image_input = load_image(image_input_path['background']) if controlnet_union is not None: controlnet_union.to(device) if controlnet is not None: controlnet.to(device) pipe_i2i.enable_model_cpu_offload() progress(0, desc="Start I2I Inference with ControlNet.") final_image = pipe_i2i( prompt=prompt_mash, control_image=images, control_mode=modes, image=image_input, strength=image_strength, num_inference_steps=steps, guidance_scale=cfg_scale, width=width, height=height, controlnet_conditioning_scale=scales, generator=generator, joint_attention_kwargs={"scale": 1.0}, output_type="pil", #sigmas=sigmas, ).images[0] return final_image except Exception as e: print(e) raise gr.Error(f"I2I Inference Error: {e}") from e def run_lora(prompt, image_input, image_strength, task_type, blur_mask, blur_factor, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2, randomize_seed, seed, width, height, sigmas_factor, loras_state, lora_json, cn_on, translate_on, progress=gr.Progress(track_tqdm=True)): global pipe, pipe_i2i, pipe_ip if not selected_indices and not is_valid_lora(lora_json): gr.Info("LoRA isn't selected.") # raise gr.Error("You must select a LoRA before proceeding.") progress(0, desc="Preparing Inference.") selected_loras = [loras_state[idx] for idx in selected_indices] if task_type == "Inpainting": is_inpaint = True is_i2i = True elif task_type == "Image-to-Image": is_inpaint = False is_i2i = True else: # "Text-to-Image" is_inpaint = False is_i2i = False if translate_on: prompt = translate_to_en(prompt) # Build the prompt with trigger words prepends = [] appends = [] for lora in selected_loras: trigger_word = lora.get('trigger_word', '') if trigger_word: if lora.get("trigger_position") == "prepend": prepends.append(trigger_word) else: appends.append(trigger_word) prompt_mash = " ".join(prepends + [prompt] + appends) print("Prompt Mash: ", prompt_mash) # # Unload previous LoRA weights with calculateDuration("Unloading LoRA"): unload_lora() print(pipe.get_active_adapters()) # print(pipe_i2i.get_active_adapters()) # print(pipe_ip.get_active_adapters()) # clear_cache() # # Build the prompt for External LoRAs prompt_mash = prompt_mash + get_model_trigger(last_model) lora_names = [] lora_weights = [] if is_valid_lora(lora_json): # Load External LoRA weights with calculateDuration("Loading External LoRA weights"): if is_inpaint: pipe_ip, lora_names, lora_weights = fuse_loras(pipe_ip, lora_json) elif is_i2i: pipe_i2i, lora_names, lora_weights = fuse_loras(pipe_i2i, lora_json) else: pipe, lora_names, lora_weights = fuse_loras(pipe, lora_json) trigger_word = get_trigger_word(lora_json) prompt_mash = f"{prompt_mash} {trigger_word}" print("Prompt Mash: ", prompt_mash) # # Load LoRA weights with respective scales if selected_indices: with calculateDuration("Loading LoRA weights"): for idx, lora in enumerate(selected_loras): lora_name = f"lora_{idx}" lora_names.append(lora_name) print(f"Lora Name: {lora_name}") lora_weights.append(lora_scale_1 if idx == 0 else lora_scale_2) lora_path = lora['repo'] weight_name = lora.get("weights") print(f"Lora Path: {lora_path}") if is_inpaint: pipe_ip.load_lora_weights( lora_path, weight_name=weight_name if weight_name else None, low_cpu_mem_usage=False, adapter_name=lora_name, token=HF_TOKEN ) elif is_i2i: pipe_i2i.load_lora_weights( lora_path, weight_name=weight_name if weight_name else None, low_cpu_mem_usage=False, adapter_name=lora_name, token=HF_TOKEN ) else: pipe.load_lora_weights( lora_path, weight_name=weight_name if weight_name else None, low_cpu_mem_usage=False, adapter_name=lora_name, token=HF_TOKEN ) print("Loaded LoRAs:", lora_names) if selected_indices or is_valid_lora(lora_json): if is_inpaint: pipe_ip.set_adapters(lora_names, adapter_weights=lora_weights) elif is_i2i: pipe_i2i.set_adapters(lora_names, adapter_weights=lora_weights) else: pipe.set_adapters(lora_names, adapter_weights=lora_weights) print(pipe.get_active_adapters()) # print(pipe_i2i.get_active_adapters()) # print(pipe_ip.get_active_adapters()) # # Set random seed for reproducibility with calculateDuration("Randomizing seed"): if randomize_seed: seed = random.randint(0, MAX_SEED) # Generate image progress(0, desc="Running Inference.") if is_i2i: final_image = generate_image_to_image(prompt_mash, image_input, image_strength, is_inpaint, blur_mask, blur_factor, steps, cfg_scale, width, height, sigmas_factor, seed, cn_on) yield save_image(final_image, None, last_model, prompt_mash, height, width, steps, cfg_scale, seed), seed, gr.update(visible=False) else: image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, sigmas_factor, cn_on) # Consume the generator to get the final image final_image = None step_counter = 0 for image in image_generator: step_counter+=1 final_image = image progress_bar = f'