import spaces
import gradio as gr
import json
import torch
from diffusers import (DiffusionPipeline, AutoencoderTiny, AutoencoderKL, AutoPipelineForImage2Image,
                       AutoPipelineForInpainting, GGUFQuantizationConfig)
from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
from diffusers.utils import load_image
from diffusers import (FluxControlNetPipeline, FluxControlNetModel, FluxMultiControlNetModel,
                       FluxControlNetImg2ImgPipeline, FluxTransformer2DModel, FluxControlNetInpaintPipeline,
                       FluxImg2ImgPipeline, FluxInpaintPipeline, FluxFillPipeline, FluxControlPipeline)
from transformers import T5EncoderModel
from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download, HfApi
import os
import copy
import random
import time
import requests
import pandas as pd
import numpy as np
from pathlib import Path
from env import (models, models_dev, models_schnell, models_fill, models_canny, models_depth, models_edit,
                 num_loras, num_cns, HF_TOKEN, single_file_base_models)
from mod import (clear_cache, get_repo_safetensors, is_repo_name, is_repo_exists, get_model_trigger,
                 description_ui, compose_lora_json, is_valid_lora, fuse_loras, turbo_loras, save_image,
                 preprocess_i2i_image, get_trigger_word, enhance_prompt, set_control_union_image,
                 get_canny_image, get_depth_image, get_control_union_mode, set_control_union_mode,
                 get_control_params, translate_to_en)
from modutils import (search_civitai_lora, select_civitai_lora, search_civitai_lora_json,
                      download_my_lora_flux, get_all_lora_tupled_list, apply_lora_prompt_flux,
                      update_loras_flux, update_civitai_selection, get_civitai_tag, CIVITAI_SORT,
                      CIVITAI_PERIOD, get_t2i_model_info, download_hf_file, save_image_history)
from tagger.tagger import predict_tags_wd, compose_prompt_to_copy
from tagger.fl2flux import predict_tags_fl2_flux

# Load prompts for randomization
df = pd.read_csv('prompts.csv', header=None)
prompt_values = df.values.flatten()

# Load LoRAs from JSON file
with open('loras.json', 'r') as f:
    loras = json.load(f)

# Initialize the base model
base_model = models[0]
controlnet_model_union_repo = 'InstantX/FLUX.1-dev-Controlnet-Union'
#controlnet_model_union_repo = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro'
dtype = torch.bfloat16
#dtype = torch.float8_e4m3fn
device = "cuda" if torch.cuda.is_available() else "cpu"
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype, token=HF_TOKEN)
good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype, token=HF_TOKEN)
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1, token=HF_TOKEN)
pipe_i2i = AutoPipelineForImage2Image.from_pretrained(base_model, vae=good_vae, transformer=pipe.transformer,
                                                      text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer,
                                                      text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2,
                                                      torch_dtype=dtype, token=HF_TOKEN)
controlnet_union = None
controlnet = None
last_model = models[0]
last_cn_on = False
last_task = "Text-to-Image"
last_dtype_str = "BF16"
#controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union_repo, torch_dtype=dtype)
#controlnet = FluxMultiControlNetModel([controlnet_union])
#controlnet.config = controlnet_union.config

MAX_SEED = 2**32 - 1
TASK_TYPE_T2I = ["Text-to-Image"]
TASK_TYPE_CONTROL = ["Canny", "Depth", "Edit"]
TASK_TYPE_I2I = ["Image-to-Image", "Inpainting", "Flux Fill"] + TASK_TYPE_CONTROL
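# Shape of a loras.json entry, inferred from how the fields are read in the
# functions below; "weights" and "trigger_position" are optional and the
# values here are purely illustrative:
#   {
#     "image": "https://example.com/thumb.png",
#     "title": "My LoRA",
#     "repo": "user/flux-lora",
#     "weights": "lora.safetensors",
#     "trigger_word": "xyz style",
#     "trigger_position": "prepend"
#   }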
def unload_lora():
    global pipe, pipe_i2i
    try:
        #pipe.unfuse_lora()
        pipe.unload_lora_weights()
        #pipe_i2i.unfuse_lora()
        pipe_i2i.unload_lora_weights()
    except Exception as e:
        print(e)

def download_file_mod(url, directory=os.getcwd()):
    path = download_hf_file(directory, url, hf_token=HF_TOKEN)
    if not path:
        raise Exception(f"Download error: {url}")
    return path

def print_progress(desc: str, proceed: float=0.0, progress=gr.Progress(track_tqdm=True)):
    progress(proceed, desc=desc)
    print(desc)

#@spaces.GPU(duration=30)
def load_quantized_control(control_repo: str, dtype, hf_token):
    transformer = FluxTransformer2DModel.from_pretrained(control_repo, subfolder="transformer", torch_dtype=dtype, token=hf_token).to("cpu")
    text_encoder_2 = T5EncoderModel.from_pretrained(control_repo, subfolder="text_encoder_2", torch_dtype=dtype, token=hf_token).to("cpu")
    return transformer, text_encoder_2

def load_pipeline(pipe, pipe_i2i, repo_id: str, cn_on: bool, model_type: str, task: str, dtype_str: str, hf_token: str, progress=gr.Progress(track_tqdm=True)):
    try:
        controlnet_model_union_repo = 'InstantX/FLUX.1-dev-Controlnet-Union'
        if task == "Flux Fill" or repo_id in models_fill:
            model_type = "fill"
            if repo_id in set(models_dev + models_schnell): repo_id = models_fill[0]
        if dtype_str == "BF16": dtype = torch.bfloat16
        else: dtype = torch.bfloat16 # NOTE: both branches currently resolve to BF16
        single_file_base_model = single_file_base_models.get(model_type, models[0])
        kwargs = {}
        transformer_model = FluxTransformer2DModel
        t5_model = T5EncoderModel
        if task == "Flux Fill":
            pipeline = FluxFillPipeline
            pipeline_i2i = FluxFillPipeline
        elif task in TASK_TYPE_CONTROL:
            pipeline = DiffusionPipeline
            pipeline_i2i = FluxControlPipeline
        elif cn_on: # with ControlNet
            print_progress(f"Loading model: {repo_id} / Loading ControlNet: {controlnet_model_union_repo}", 0, progress)
            controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union_repo, torch_dtype=dtype, token=hf_token)
            controlnet = FluxMultiControlNetModel([controlnet_union])
            controlnet.config = controlnet_union.config
            pipeline = FluxControlNetPipeline
            pipeline_i2i = FluxControlNetInpaintPipeline if task == "Inpainting" else FluxControlNetImg2ImgPipeline
            kwargs["controlnet"] = controlnet
        else: # without ControlNet
            print_progress(f"Loading model: {repo_id}", 0, progress)
            pipeline = DiffusionPipeline
            pipeline_i2i = AutoPipelineForInpainting if task == "Inpainting" else AutoPipelineForImage2Image
        if task in TASK_TYPE_CONTROL: # FluxControlPipeline
            if task == "Canny": control_repo = models_canny[0]
            elif task == "Depth": control_repo = models_depth[0]
            elif task == "Edit": control_repo = models_edit[0]
            if task == "Edit":
                transformer = transformer_model.from_pretrained(control_repo, torch_dtype=dtype, token=hf_token)
                text_encoder_2 = t5_model.from_pretrained(models_dev[0], subfolder="text_encoder_2", torch_dtype=dtype, token=hf_token)
            else:
                transformer = transformer_model.from_pretrained(control_repo, subfolder="transformer", torch_dtype=dtype, token=hf_token)
                text_encoder_2 = t5_model.from_pretrained(control_repo, subfolder="text_encoder_2", torch_dtype=dtype, token=hf_token)
            #transformer, text_encoder_2 = load_quantized_control(control_repo, dtype, hf_token)
            pipe = pipeline.from_pretrained(models_dev[0], transformer=transformer, text_encoder_2=text_encoder_2, torch_dtype=dtype, token=hf_token)
            pipe_i2i = pipeline_i2i.from_pipe(pipe, transformer=transformer, text_encoder_2=text_encoder_2, torch_dtype=dtype)
        elif ".safetensors" in repo_id or ".gguf" in repo_id: # from single file
            file_url = repo_id.replace("/resolve/main/", "/blob/main/").replace("?download=true", "")
            if ".gguf" in file_url:
                transformer = transformer_model.from_single_file(file_url, subfolder="transformer", quantization_config=GGUFQuantizationConfig(compute_dtype=dtype), torch_dtype=dtype, config=single_file_base_model)
            else:
                transformer = transformer_model.from_single_file(file_url, subfolder="transformer", torch_dtype=dtype, config=single_file_base_model)
            if not transformer:
                transformer = transformer_model.from_pretrained(single_file_base_model, subfolder="transformer", torch_dtype=dtype, token=hf_token)
            pipe = pipeline.from_pretrained(single_file_base_model, transformer=transformer, torch_dtype=dtype, token=hf_token, **kwargs)
            pipe_i2i = pipeline_i2i.from_pretrained(single_file_base_model, vae=pipe.vae, transformer=pipe.transformer,
                                                    text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer,
                                                    text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2,
                                                    torch_dtype=dtype, token=hf_token, **kwargs)
        else: # from diffusers repo
            pipe = pipeline.from_pretrained(repo_id, torch_dtype=dtype, token=hf_token, **kwargs)
            pipe_i2i = pipeline_i2i.from_pretrained(repo_id, vae=pipe.vae, transformer=pipe.transformer,
                                                    text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer,
                                                    text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2,
                                                    torch_dtype=dtype, token=hf_token, **kwargs)
        if cn_on: print_progress(f"Model loaded: {repo_id} / ControlNet Loaded: {controlnet_model_union_repo}", 1, progress)
        else: print_progress(f"Model loaded: {repo_id}", 1, progress)
    except Exception as e:
        print(e)
        gr.Warning(f"Failed to load pipeline: {e}")
    finally:
        return pipe, pipe_i2i
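# Example of the single-file URL normalization performed above (the URL is
# illustrative): a download link such as
#   https://huggingface.co/org/model/resolve/main/flux-q4.gguf?download=true
# becomes
#   https://huggingface.co/org/model/blob/main/flux-q4.gguf
# before being handed to FluxTransformer2DModel.from_single_file().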
#load_pipeline.zerogpu = True
# https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union
# https://huggingface.co/spaces/jiuface/FLUX.1-dev-Controlnet-Union
# https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux

#@spaces.GPU()
def change_base_model(repo_id: str, cn_on: bool, disable_model_cache: bool, model_type: str, task: str, dtype_str: str, progress=gr.Progress(track_tqdm=True)):
    global pipe, pipe_i2i, taef1, good_vae, controlnet_union, controlnet, last_model, last_cn_on, last_task, last_dtype_str, dtype
    try:
        if (not disable_model_cache and (repo_id == last_model and cn_on is last_cn_on and task == last_task and dtype_str == last_dtype_str)) \
            or ((not is_repo_name(repo_id) or not is_repo_exists(repo_id)) and ".safetensors" not in repo_id): # and not ".gguf" in repo_id
            return gr.update()
        unload_lora()
        if pipe is not None: pipe.to("cpu")
        if pipe_i2i is not None: pipe_i2i.to("cpu")
        if good_vae is not None: good_vae.to("cpu")
        if taef1 is not None: taef1.to("cpu")
        if controlnet is not None: controlnet.to("cpu")
        if controlnet_union is not None: controlnet_union.to("cpu")
        pipe, pipe_i2i = load_pipeline(pipe, pipe_i2i, repo_id, cn_on, model_type, task, dtype_str, HF_TOKEN, progress)
        clear_cache()
        last_model = repo_id
        last_cn_on = cn_on
        last_task = task
        last_dtype_str = dtype_str
    except Exception as e:
        print(f"Model load Error: {repo_id} {e}")
        raise gr.Error(f"Model load Error: {repo_id} {e}") from e
    return gr.update()
change_base_model.zerogpu = True

def is_repo_public(repo_id: str):
    api = HfApi()
    try:
        return api.repo_exists(repo_id=repo_id, token=False)
    except Exception as e:
        print(f"Error: Failed to connect {repo_id}. {e}")
        return False

def calc_sigmas(num_inference_steps: int, sigmas_factor: float):
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
    sigmas = sigmas * sigmas_factor
    return sigmas
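# Worked example (illustrative values): for 4 steps the base schedule is a
# linear ramp from 1.0 down to 1/steps, and sigmas_factor scales it uniformly:
#   calc_sigmas(4, 0.95) -> array([0.95, 0.7125, 0.475, 0.2375])
# A factor below 1.0 lowers every sigma, effectively shortening the
# denoising trajectory.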
class calculateDuration:
    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.time()
        self.elapsed_time = self.end_time - self.start_time
        if self.activity_name:
            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
        else:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
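# Usage sketch (the block being timed is hypothetical):
#     with calculateDuration("Loading LoRA weights"):
#         pipe.load_lora_weights(...)
# On exit this prints, e.g.:
#     Elapsed time for Loading LoRA weights: 1.234567 seconds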
def download_file(url, directory=None):
    if directory is None:
        directory = os.getcwd() # Use current working directory if not specified
    # Get the filename from the URL
    filename = url.split('/')[-1]
    # Full path for the downloaded file
    filepath = os.path.join(directory, filename)
    # Download the file
    response = requests.get(url)
    response.raise_for_status() # Raise an exception for bad status codes
    # Write the content to the file
    with open(filepath, 'wb') as file:
        file.write(response.content)
    return filepath

def update_selection(evt: gr.SelectData, selected_indices, loras_state, width, height):
    selected_index = evt.index
    selected_indices = selected_indices or []
    if selected_index in selected_indices:
        selected_indices.remove(selected_index)
    else:
        if len(selected_indices) < 2:
            selected_indices.append(selected_index)
        else:
            gr.Warning("You can select up to 2 LoRAs, remove one to select a new one.")
            return gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), width, height, gr.update(), gr.update()
    selected_info_1 = "Select a LoRA 1"
    selected_info_2 = "Select a LoRA 2"
    lora_scale_1 = 1.15
    lora_scale_2 = 1.15
    lora_image_1 = None
    lora_image_2 = None
    if len(selected_indices) >= 1:
        lora1 = loras_state[selected_indices[0]]
        selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}](https://huggingface.co/{lora1['repo']}) ✨"
        lora_image_1 = lora1['image']
    if len(selected_indices) >= 2:
        lora2 = loras_state[selected_indices[1]]
        selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}](https://huggingface.co/{lora2['repo']}) ✨"
        lora_image_2 = lora2['image']
    if selected_indices:
        last_selected_lora = loras_state[selected_indices[-1]]
        new_placeholder = f"Type a prompt for {last_selected_lora['title']}"
    else:
        new_placeholder = "Type a prompt"
    return gr.update(placeholder=new_placeholder), selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, width, height, lora_image_1, lora_image_2

def remove_lora_1(selected_indices, loras_state):
    if len(selected_indices) >= 1: selected_indices.pop(0)
    selected_info_1 = "Select LoRA 1"
    selected_info_2 = "Select LoRA 2"
    lora_scale_1 = 1.15
    lora_scale_2 = 1.15
    lora_image_1 = None
    lora_image_2 = None
    if len(selected_indices) >= 1:
        lora1 = loras_state[selected_indices[0]]
        selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}]({lora1['repo']}) ✨"
        lora_image_1 = lora1['image']
    if len(selected_indices) >= 2:
        lora2 = loras_state[selected_indices[1]]
        selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}]({lora2['repo']}) ✨"
        lora_image_2 = lora2['image']
    return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2

def remove_lora_2(selected_indices, loras_state):
    if len(selected_indices) >= 2: selected_indices.pop(1)
    selected_info_1 = "Select LoRA 1"
    selected_info_2 = "Select LoRA 2"
    lora_scale_1 = 1.15
    lora_scale_2 = 1.15
    lora_image_1 = None
    lora_image_2 = None
    if len(selected_indices) >= 1:
        lora1 = loras_state[selected_indices[0]]
        selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}]({lora1['repo']}) ✨"
        lora_image_1 = lora1['image']
    if len(selected_indices) >= 2:
        lora2 = loras_state[selected_indices[1]]
        selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}]({lora2['repo']}) ✨"
        lora_image_2 = lora2['image']
    return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2

def randomize_loras(selected_indices, loras_state):
    if len(loras_state) < 2: raise gr.Error("Not enough LoRAs to randomize.")
    selected_indices = random.sample(range(len(loras_state)), 2)
    lora1 = loras_state[selected_indices[0]]
    lora2 = loras_state[selected_indices[1]]
    selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}](https://huggingface.co/{lora1['repo']}) ✨"
    selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}](https://huggingface.co/{lora2['repo']}) ✨"
    lora_scale_1 = 1.15
    lora_scale_2 = 1.15
    lora_image_1 = lora1['image']
    lora_image_2 = lora2['image']
    random_prompt = random.choice(prompt_values)
    return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2, random_prompt

def download_loras_images(loras_json_orig: list[dict]):
    api = HfApi(token=HF_TOKEN)
    loras_json = []
    for lora in loras_json_orig:
        repo = lora.get("repo", None)
        if repo is None or not api.repo_exists(repo_id=repo, token=HF_TOKEN):
            print(f"LoRA '{repo}' does not exist.")
            continue
        if "title" not in lora.keys() or "trigger_word" not in lora.keys() or "image" not in lora.keys():
            title, _repo, _path, trigger_word, image_def = check_custom_model(repo)
            if "title" not in lora.keys(): lora["title"] = title
            if "trigger_word" not in lora.keys(): lora["trigger_word"] = trigger_word
            if "image" not in lora.keys(): lora["image"] = image_def
        image = lora.get("image", None)
        try:
            if not is_repo_public(repo) and image is not None and "http" in image and repo in image:
                image = download_file_mod(image)
            lora["image"] = image if image else "/home/user/app/custom.png"
        except Exception as e:
            print(f"Failed to download LoRA '{repo}'s image '{image if image else ''}'. {e}")
            lora["image"] = "/home/user/app/custom.png"
        loras_json.append(lora)
    return loras_json
def add_custom_lora(custom_lora, selected_indices, current_loras, gallery):
    if custom_lora:
        try:
            title, repo, path, trigger_word, image = check_custom_model(custom_lora)
            if image is not None and "http" in image and not is_repo_public(repo) and repo in image:
                try:
                    image = download_file_mod(image)
                except Exception as e:
                    print(e)
                    image = None
            print(f"Loaded custom LoRA: {repo}")
            existing_item_index = next((index for (index, item) in enumerate(current_loras) if item['repo'] == repo), None)
            if existing_item_index is None:
                if repo.endswith(".safetensors") and repo.startswith("http"):
                    #repo = download_file(repo)
                    repo = download_file_mod(repo)
                new_item = {
                    "image": image if image else "/home/user/app/custom.png",
                    "title": title,
                    "repo": repo,
                    "weights": path,
                    "trigger_word": trigger_word
                }
                print(f"New LoRA: {new_item}")
                existing_item_index = len(current_loras)
                current_loras.append(new_item)

            # Update gallery
            gallery_items = [(item["image"], item["title"]) for item in current_loras]
            # Update selected_indices if there's room
            if len(selected_indices) < 2:
                selected_indices.append(existing_item_index)
            else:
                gr.Warning("You can select up to 2 LoRAs, remove one to select a new one.")

            # Update selected_info and images
            selected_info_1 = "Select a LoRA 1"
            selected_info_2 = "Select a LoRA 2"
            lora_scale_1 = 1.15
            lora_scale_2 = 1.15
            lora_image_1 = None
            lora_image_2 = None
            if len(selected_indices) >= 1:
                lora1 = current_loras[selected_indices[0]]
                selected_info_1 = f"### LoRA 1 Selected: {lora1['title']} ✨"
                lora_image_1 = lora1['image'] if lora1['image'] else None
            if len(selected_indices) >= 2:
                lora2 = current_loras[selected_indices[1]]
                selected_info_2 = f"### LoRA 2 Selected: {lora2['title']} ✨"
                lora_image_2 = lora2['image'] if lora2['image'] else None
            print("Finished adding custom LoRA")
            return (
                current_loras,
                gr.update(value=gallery_items),
                selected_info_1,
                selected_info_2,
                selected_indices,
                lora_scale_1,
                lora_scale_2,
                lora_image_1,
                lora_image_2
            )
        except Exception as e:
            print(e)
            gr.Warning(str(e))
            return current_loras, gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), gr.update(), gr.update()
    else:
        return current_loras, gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), gr.update(), gr.update()

def remove_custom_lora(selected_indices, current_loras, gallery):
    if current_loras:
        custom_lora_repo = current_loras[-1]['repo']
        # Remove from loras list
        current_loras = current_loras[:-1]
        # Remove from selected_indices if selected
        custom_lora_index = len(current_loras)
        if custom_lora_index in selected_indices:
            selected_indices.remove(custom_lora_index)
    # Update gallery
    gallery_items = [(item["image"], item["title"]) for item in current_loras]
    # Update selected_info and images
    selected_info_1 = "Select a LoRA 1"
    selected_info_2 = "Select a LoRA 2"
    lora_scale_1 = 1.15
    lora_scale_2 = 1.15
    lora_image_1 = None
    lora_image_2 = None
    if len(selected_indices) >= 1:
        lora1 = current_loras[selected_indices[0]]
        selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}]({lora1['repo']}) ✨"
        lora_image_1 = lora1['image']
    if len(selected_indices) >= 2:
        lora2 = current_loras[selected_indices[1]]
        selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}]({lora2['repo']}) ✨"
        lora_image_2 = lora2['image']
    return (
        current_loras,
        gr.update(value=gallery_items),
        selected_info_1,
        selected_info_2,
        selected_indices,
        lora_scale_1,
        lora_scale_2,
        lora_image_1,
        lora_image_2
    )
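# Note on the two VAEs used below (behavior inferred from the code):
# generate_image() swaps in taef1 (AutoencoderTiny) so intermediate latents
# can be decoded cheaply for live previews, while good_vae (the full
# AutoencoderKL) decodes the final image; the ControlNet and image-to-image
# paths use good_vae throughout.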
@spaces.GPU(duration=70)
@torch.inference_mode()
def generate_image(prompt_mash: str, steps: int, seed: int, cfg_scale: float, width: int, height: int, sigmas_factor: float, cn_on: bool, progress=gr.Progress(track_tqdm=True)):
    global pipe, taef1, good_vae, controlnet, controlnet_union
    try:
        good_vae.to(device)
        taef1.to(device)
        generator = torch.Generator(device=device).manual_seed(int(float(seed)))
        sigmas = calc_sigmas(steps, sigmas_factor)

        with calculateDuration("Generating image"):
            # Generate image
            modes, images, scales = get_control_params()
            if not cn_on or len(modes) == 0: # without ControlNet
                pipe.to(device)
                pipe.vae = taef1
                pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
                print_progress("Start Inference.")
                for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
                    prompt=prompt_mash,
                    num_inference_steps=steps,
                    guidance_scale=cfg_scale,
                    width=width,
                    height=height,
                    generator=generator,
                    joint_attention_kwargs={"scale": 1.0},
                    output_type="pil",
                    good_vae=good_vae,
                    sigmas=sigmas,
                ):
                    yield img
            else: # with ControlNet
                pipe.to(device)
                pipe.vae = good_vae
                if controlnet_union is not None: controlnet_union.to(device)
                if controlnet is not None: controlnet.to(device)
                pipe.enable_model_cpu_offload()
                print_progress("Start Inference with ControlNet.")
                for img in pipe(
                    prompt=prompt_mash,
                    control_image=images,
                    control_mode=modes,
                    num_inference_steps=steps,
                    guidance_scale=cfg_scale,
                    width=width,
                    height=height,
                    controlnet_conditioning_scale=scales,
                    generator=generator,
                    joint_attention_kwargs={"scale": 1.0},
                    sigmas=sigmas,
                ).images:
                    yield img
    except Exception as e:
        print(e)
        raise gr.Error(f"Inference Error: {e}") from e
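# The image_input_path_dict consumed below is a gr.ImageEditor value
# (assumption based on the keys accessed): 'background' holds the source
# image path and 'layers'[0] holds the painted mask layer, which is only
# used for the "Inpainting" and "Flux Fill" tasks.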
@spaces.GPU(duration=70)
@torch.inference_mode()
def generate_image_to_image(prompt_mash: str, image_input_path_dict: dict, image_strength: float, task_type: str, blur_mask: bool, blur_factor: float,
                            steps: int, cfg_scale: float, width: int, height: int, sigmas_factor: float, seed: int, cn_on: bool, progress=gr.Progress(track_tqdm=True)):
    global pipe_i2i, good_vae, controlnet, controlnet_union
    try:
        good_vae.to(device)
        generator = torch.Generator(device=device).manual_seed(int(float(seed)))
        image_input_path = image_input_path_dict['background']
        mask_path = image_input_path_dict['layers'][0]
        is_mask = task_type in ("Inpainting", "Flux Fill")
        is_fill = task_type == "Flux Fill"
        is_depth = task_type == "Depth"
        is_canny = task_type == "Canny"
        is_edit = task_type == "Edit"
        kwargs = {}
        if task_type in ["Image-to-Image", "Inpainting"]: kwargs["strength"] = image_strength
        if sigmas_factor < 1.0 and task_type != "Image-to-Image": kwargs["sigmas"] = calc_sigmas(steps, sigmas_factor)

        with calculateDuration("Generating image"):
            # Generate image
            modes, images, scales = get_control_params()
            if not cn_on or len(modes) == 0: # without ControlNet
                pipe_i2i.to(device)
                pipe_i2i.vae = good_vae
                image_input = load_image(image_input_path)
                if task_type in TASK_TYPE_CONTROL: kwargs["control_image"] = image_input
                else: kwargs["image"] = image_input
                if is_mask:
                    mask_input = load_image(mask_path)
                    if blur_mask: mask_input = pipe_i2i.mask_processor.blur(mask_input, blur_factor=blur_factor)
                    kwargs["mask_image"] = mask_input
                    if is_fill: print_progress("Start Flux Fill Inference.")
                    else: print_progress("Start Inpainting Inference.")
                elif is_canny:
                    image_input = get_canny_image(image_input, height, width)
                    print_progress("Start Canny Inference.")
                elif is_depth:
                    image_input = get_depth_image(image_input, height, width)
                    print_progress("Start Depth Inference.")
                elif is_edit:
                    print_progress("Start Edit Inference.")
                else:
                    print_progress("Start I2I Inference.")
                final_image = pipe_i2i(
                    prompt=prompt_mash,
                    #image=image_input,
                    num_inference_steps=steps,
                    guidance_scale=cfg_scale,
                    #width=width,
                    #height=height,
                    generator=generator,
                    joint_attention_kwargs={"scale": 1.0},
                    output_type="pil",
                    **kwargs,
                ).images[0]
                return final_image
            else: # with ControlNet
                pipe_i2i.to(device)
                pipe_i2i.vae = good_vae
                image_input = load_image(image_input_path)
                kwargs["image"] = image_input
                if controlnet_union is not None: controlnet_union.to(device)
                if controlnet is not None: controlnet.to(device)
                if is_mask:
                    mask_input = load_image(mask_path)
                    if blur_mask: mask_input = pipe_i2i.mask_processor.blur(mask_input, blur_factor=blur_factor)
                    kwargs["mask_image"] = mask_input
                    if is_fill: print_progress("Start Flux Fill Inference with ControlNet.")
                    else: print_progress("Start Inpainting Inference with ControlNet.")
                else:
                    print_progress("Start I2I Inference with ControlNet.")
                pipe_i2i.enable_model_cpu_offload()
                final_image = pipe_i2i(
                    prompt=prompt_mash,
                    control_image=images,
                    control_mode=modes,
                    #image=image_input, # already set in kwargs; passing it twice would raise a duplicate-keyword TypeError
                    num_inference_steps=steps,
                    guidance_scale=cfg_scale,
                    #width=width,
                    #height=height,
                    controlnet_conditioning_scale=scales,
                    generator=generator,
                    joint_attention_kwargs={"scale": 1.0},
                    output_type="pil",
                    **kwargs,
                ).images[0]
                return final_image
    except Exception as e:
        print(e)
        raise gr.Error(f"I2I Inference Error: {e}") from e
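# Worked example of the trigger-word merge performed in run_lora() below
# (illustrative values): with LoRA A {trigger_word: "xyz style",
# trigger_position: "prepend"} and LoRA B {trigger_word: "abc"} selected,
#   prompt_mash == "xyz style <user prompt> abc"
# before any external-LoRA trigger words are appended.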
def run_lora(prompt: str, image_input: dict, image_strength: float, task_type: str, turbo_mode: str, blur_mask: bool, blur_factor: float,
             cfg_scale: float, steps: int, selected_indices, lora_scale_1: float, lora_scale_2: float, randomize_seed: bool, seed: int,
             width: int, height: int, sigmas_factor: float, loras_state, lora_json, cn_on: bool, translate_on: bool, progress=gr.Progress(track_tqdm=True)):
    global pipe, pipe_i2i
    if not selected_indices and not is_valid_lora(lora_json):
        gr.Info("LoRA isn't selected.")
        # raise gr.Error("You must select a LoRA before proceeding.")
    progress(0, desc="Preparing Inference.")

    selected_loras = [loras_state[idx] for idx in selected_indices]

    is_i2i = task_type in set(TASK_TYPE_I2I)

    if translate_on: prompt = translate_to_en(prompt)

    # Build the prompt with trigger words
    prepends = []
    appends = []
    for lora in selected_loras:
        trigger_word = lora.get('trigger_word', '')
        if trigger_word:
            if lora.get("trigger_position") == "prepend":
                prepends.append(trigger_word)
            else:
                appends.append(trigger_word)
    prompt_mash = " ".join(prepends + [prompt] + appends)
    print("Prompt Mash: ", prompt_mash)

    # Unload previous LoRA weights
    with calculateDuration("Unloading LoRA"):
        unload_lora()
        print(pipe.get_active_adapters())
        #print(pipe_i2i.get_active_adapters())
        #clear_cache()

    # Build the prompt for External LoRAs
    prompt_mash = prompt_mash + get_model_trigger(last_model)
    lora_names = []
    lora_weights = []

    # Load Turbo LoRA weights
    if turbo_mode != "None":
        if is_i2i: pipe_i2i, lora_names, lora_weights, steps = turbo_loras(pipe_i2i, turbo_mode, lora_names, lora_weights)
        else: pipe, lora_names, lora_weights, steps = turbo_loras(pipe, turbo_mode, lora_names, lora_weights)

    # Load External LoRA weights
    if is_valid_lora(lora_json):
        with calculateDuration("Loading External LoRA weights"):
            if is_i2i: pipe_i2i, lora_names, lora_weights = fuse_loras(pipe_i2i, lora_json, lora_names, lora_weights)
            else: pipe, lora_names, lora_weights = fuse_loras(pipe, lora_json, lora_names, lora_weights)
        trigger_word = get_trigger_word(lora_json)
        prompt_mash = f"{prompt_mash} {trigger_word}"
        print("Prompt Mash: ", prompt_mash)

    # Load LoRA weights with respective scales
    if selected_indices:
        with calculateDuration("Loading LoRA weights"):
            for idx, lora in enumerate(selected_loras):
                lora_name = f"lora_{idx}"
                lora_names.append(lora_name)
                print(f"Lora Name: {lora_name}")
                lora_weights.append(lora_scale_1 if idx == 0 else lora_scale_2)
                lora_path = lora['repo']
                weight_name = lora.get("weights")
                print(f"Lora Path: {lora_path}")
                if is_i2i:
                    pipe_i2i.load_lora_weights(
                        lora_path,
                        weight_name=weight_name if weight_name else None,
                        low_cpu_mem_usage=False,
                        adapter_name=lora_name,
                        token=HF_TOKEN
                    )
                else:
                    pipe.load_lora_weights(
                        lora_path,
                        weight_name=weight_name if weight_name else None,
                        low_cpu_mem_usage=False,
                        adapter_name=lora_name,
                        token=HF_TOKEN
                    )
    print("Loaded LoRAs:", lora_names)
    if selected_indices or is_valid_lora(lora_json):
        if is_i2i: pipe_i2i.set_adapters(lora_names, adapter_weights=lora_weights)
        else: pipe.set_adapters(lora_names, adapter_weights=lora_weights)
    print(pipe.get_active_adapters())
    #print(pipe_i2i.get_active_adapters())

    # Set random seed for reproducibility
    with calculateDuration("Randomizing seed"):
        if randomize_seed: seed = random.randint(0, MAX_SEED)

    # Generate image
    progress(0, desc="Running Inference.")
    if is_i2i:
        final_image = generate_image_to_image(prompt_mash, image_input, image_strength, task_type, blur_mask, blur_factor,
                                              steps, cfg_scale, width, height, sigmas_factor, seed, cn_on)
        yield save_image(final_image, None, last_model, prompt_mash, height, width, steps, cfg_scale, seed), seed, gr.update(visible=False)
    else:
        image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, sigmas_factor, cn_on)
        # Consume the generator to get the final image
        final_image = None
        step_counter = 0
        for image in image_generator:
            step_counter += 1
            final_image = image
            progress_bar = f'