# Module-level model state (base / refiner / prompt expansion) and the
# diffusion sampling pipeline built on top of it.
import modules.core as core
import os
import torch
import modules.patch
import modules.config
import ldm_patched.modules.model_management
import ldm_patched.modules.latent_formats
import modules.inpaint_worker
import extras.vae_interpose as vae_interpose
from extras.expansion import FooocusExpansion

from ldm_patched.modules.model_base import SDXL, SDXLRefiner
from modules.sample_hijack import clip_separate


model_base = core.StableDiffusionModel()
model_refiner = core.StableDiffusionModel()

# Objects actually used for sampling; (re)assigned by refresh_everything().
final_expansion = None
final_unet = None
final_clip = None
final_vae = None
final_refiner_unet = None
final_refiner_vae = None

# Maps ControlNet model path -> loaded ControlNet object.
loaded_ControlNets = {}

# Exponent used to turn the refiner switch step into a denoise threshold
# (original author's note: "karras 0.834").
_PARTIAL_DENOISE_EXPONENT = 0.834


def _base_only_for_partial_denoise(steps, switch, denoise):
    """Return True when `denoise` exceeds the threshold derived from `switch`.

    When the refiner has its own VAE and a partial denoise is requested,
    True means only the base model should sample; False means only the
    refiner should. Shared by `get_candidate_vae` and `process_diffusion`
    so both make the same choice.
    """
    return denoise > (float(steps - switch) / float(steps)) ** _PARTIAL_DENOISE_EXPONENT


@torch.no_grad()
@torch.inference_mode()
def refresh_controlnets(model_paths):
    """Load the ControlNets at `model_paths`, reusing already-loaded ones.

    `None` entries are skipped. The module-level cache is replaced by one
    that holds only the requested paths.
    """
    global loaded_ControlNets
    cache = {}
    for p in model_paths:
        if p is None:
            continue
        if p in loaded_ControlNets:
            cache[p] = loaded_ControlNets[p]
        else:
            cache[p] = core.load_controlnet(p)
    loaded_ControlNets = cache
    return


@torch.no_grad()
@torch.inference_mode()
def assert_model_integrity():
    """Raise NotImplementedError unless the base model's UNet is SDXL.

    Returns:
        True when the check passes.
    """
    error_message = None

    if not isinstance(model_base.unet_with_lora.model, SDXL):
        error_message = 'You have selected base model other than SDXL. This is not supported yet.'

    if error_message is not None:
        raise NotImplementedError(error_message)

    return True


@torch.no_grad()
@torch.inference_mode()
def refresh_base_model(name):
    """Load checkpoint `name` (relative to config.path_checkpoints) as the base model.

    Does nothing if that file is already the loaded base model.
    """
    global model_base

    filename = os.path.abspath(os.path.realpath(os.path.join(modules.config.path_checkpoints, name)))

    if model_base.filename == filename:
        return

    # Drop the reference to the previous model before loading the new one
    # (presumably so its weights can be released first).
    model_base = core.StableDiffusionModel()
    model_base = core.load_model(filename)
    print(f'Base model loaded: {model_base.filename}')
    return


@torch.no_grad()
@torch.inference_mode()
def refresh_refiner_model(name):
    """Load checkpoint `name` as the refiner, or unload it when `name` is 'None'.

    Does nothing if that file is already the loaded refiner. The refiner's
    CLIP is always discarded. SDXL / SDXLRefiner refiners also discard
    their VAE. Other architectures keep their VAE, which later becomes
    `final_refiner_vae`.
    """
    global model_refiner

    filename = os.path.abspath(os.path.realpath(os.path.join(modules.config.path_checkpoints, name)))

    if model_refiner.filename == filename:
        return

    model_refiner = core.StableDiffusionModel()

    if name == 'None':
        print('Refiner unloaded.')
        return

    model_refiner = core.load_model(filename)
    print(f'Refiner model loaded: {model_refiner.filename}')

    if isinstance(model_refiner.unet.model, (SDXL, SDXLRefiner)):
        # With no refiner VAE, process_diffusion decodes with the base VAE.
        model_refiner.clip = None
        model_refiner.vae = None
    else:
        model_refiner.clip = None

    return


@torch.no_grad()
@torch.inference_mode()
def synthesize_refiner_model():
    """Make the refiner reuse the base model's UNet (no separate CLIP/VAE/vision)."""
    global model_base, model_refiner

    print('Synthetic Refiner Activated')
    model_refiner = core.StableDiffusionModel(
        unet=model_base.unet,
        vae=model_base.vae,
        clip=model_base.clip,
        clip_vision=model_base.clip_vision,
        filename=model_base.filename
    )
    model_refiner.vae = None
    model_refiner.clip = None
    model_refiner.clip_vision = None

    return


@torch.no_grad()
@torch.inference_mode()
def refresh_loras(loras, base_model_additional_loras=None):
    """Apply `loras` to base and refiner; `base_model_additional_loras` to the base only.

    A non-list `base_model_additional_loras` is treated as an empty list.
    """
    global model_base, model_refiner

    if not isinstance(base_model_additional_loras, list):
        base_model_additional_loras = []

    model_base.refresh_loras(loras + base_model_additional_loras)
    model_refiner.refresh_loras(loras)

    return


@torch.no_grad()
@torch.inference_mode()
def clip_encode_single(clip, text, verbose=False):
    """Encode `text` with `clip`, memoizing the result in `clip.fcs_cond_cache`.

    Returns:
        Whatever `clip.encode_from_tokens(..., return_pooled=True)` returns.
        Callers unpack it as a (cond, pooled) pair.
    """
    cached = clip.fcs_cond_cache.get(text, None)
    if cached is not None:
        if verbose:
            print(f'[CLIP Cached] {text}')
        return cached
    tokens = clip.tokenize(text)
    result = clip.encode_from_tokens(tokens, return_pooled=True)
    clip.fcs_cond_cache[text] = result
    if verbose:
        print(f'[CLIP Encoded] {text}')
    return result


@torch.no_grad()
@torch.inference_mode()
def clone_cond(conds):
    """Copy a conditioning list, cloning tensors.

    Only the 'pooled_output' entry of each options dict is kept.
    """
    results = []

    for c, p in conds:
        p = p["pooled_output"]

        if isinstance(c, torch.Tensor):
            c = c.clone()

        if isinstance(p, torch.Tensor):
            p = p.clone()

        results.append([c, {"pooled_output": p}])

    return results


@torch.no_grad()
@torch.inference_mode()
def clip_encode(texts, pool_top_k=1):
    """Encode `texts` with `final_clip` into a single conditioning entry.

    Conds of all texts are concatenated along dim 1. The pooled outputs of
    the first `pool_top_k` texts are summed.

    Returns:
        `[[cond, {"pooled_output": pooled}]]`, or None when `final_clip` is
        unset or `texts` is not a non-empty list.
    """
    global final_clip

    if final_clip is None:
        return None
    if not isinstance(texts, list):
        return None
    if len(texts) == 0:
        return None

    cond_list = []
    pooled_acc = 0

    for i, text in enumerate(texts):
        cond, pooled = clip_encode_single(final_clip, text)
        cond_list.append(cond)
        if i < pool_top_k:
            pooled_acc += pooled

    return [[torch.cat(cond_list, dim=1), {"pooled_output": pooled_acc}]]


@torch.no_grad()
@torch.inference_mode()
def clear_all_caches():
    """Clear the text-encoding cache of `final_clip` (no-op if no CLIP is loaded)."""
    if final_clip is not None:
        final_clip.fcs_cond_cache = {}


@torch.no_grad()
@torch.inference_mode()
def prepare_text_encoder(async_call=True):
    """Check model integrity and load the CLIP and prompt-expansion models onto the GPU.

    `async_call` is currently unused.
    """
    if async_call:
        # TODO: make sure that this is always called in an async way so that users cannot feel it.
        pass
    assert_model_integrity()
    ldm_patched.modules.model_management.load_models_gpu([final_clip.patcher, final_expansion.patcher])
    return


@torch.no_grad()
@torch.inference_mode()
def refresh_everything(refiner_model_name, base_model_name, loras,
                       base_model_additional_loras=None, use_synthetic_refiner=False):
    """(Re)load base/refiner models and LoRAs, then rebind the `final_*` globals.

    With `use_synthetic_refiner` and no refiner ('None'), the refiner reuses
    the base UNet. Also creates the prompt expansion model once, loads the
    text encoders, and clears the CLIP cache.
    """
    global final_unet, final_clip, final_vae, final_refiner_unet, final_refiner_vae, final_expansion

    final_unet = None
    final_clip = None
    final_vae = None
    final_refiner_unet = None
    final_refiner_vae = None

    if use_synthetic_refiner and refiner_model_name == 'None':
        print('Synthetic Refiner Activated')
        refresh_base_model(base_model_name)
        synthesize_refiner_model()
    else:
        refresh_refiner_model(refiner_model_name)
        refresh_base_model(base_model_name)

    refresh_loras(loras, base_model_additional_loras=base_model_additional_loras)
    assert_model_integrity()

    final_unet = model_base.unet_with_lora
    final_clip = model_base.clip_with_lora
    final_vae = model_base.vae

    final_refiner_unet = model_refiner.unet_with_lora
    final_refiner_vae = model_refiner.vae

    if final_expansion is None:
        final_expansion = FooocusExpansion()

    prepare_text_encoder(async_call=True)
    clear_all_caches()
    return


# Load the configured default models at import time.
refresh_everything(
    refiner_model_name=modules.config.default_refiner_model_name,
    base_model_name=modules.config.default_base_model_name,
    loras=modules.config.default_loras
)


@torch.no_grad()
@torch.inference_mode()
def vae_parse(latent):
    """Convert `latent` via `vae_interpose.parse` when a refiner VAE exists.

    Returns `latent` unchanged when there is no refiner VAE.
    """
    if final_refiner_vae is None:
        return latent

    result = vae_interpose.parse(latent["samples"])
    return {'samples': result}


@torch.no_grad()
@torch.inference_mode()
def calculate_sigmas_all(sampler, model, scheduler, steps):
    """Return the full sigma schedule for `steps` steps.

    For 'dpm_2' / 'dpm_2_ancestral', one extra step is computed and the
    penultimate sigma is then removed.
    """
    from ldm_patched.modules.samplers import calculate_sigmas_scheduler

    discard_penultimate_sigma = False
    if sampler in ['dpm_2', 'dpm_2_ancestral']:
        steps += 1
        discard_penultimate_sigma = True

    sigmas = calculate_sigmas_scheduler(model, scheduler, steps)

    if discard_penultimate_sigma:
        sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
    return sigmas


@torch.no_grad()
@torch.inference_mode()
def calculate_sigmas(sampler, model, scheduler, steps, denoise):
    """Return the sigma schedule, truncated for partial denoise.

    For `denoise` <= 0.9999 the schedule is computed for `int(steps / denoise)`
    steps and only its last `steps + 1` sigmas are kept. `denoise` must be
    > 0 in that case, otherwise the division raises ZeroDivisionError.
    """
    if denoise is None or denoise > 0.9999:
        sigmas = calculate_sigmas_all(sampler, model, scheduler, steps)
    else:
        new_steps = int(steps / denoise)
        sigmas = calculate_sigmas_all(sampler, model, scheduler, new_steps)
        sigmas = sigmas[-(steps + 1):]
    return sigmas


@torch.no_grad()
@torch.inference_mode()
def get_candidate_vae(steps, switch, denoise=1.0, refiner_swap_method='joint'):
    """Return the (vae, refiner_vae) pair that `process_diffusion` would decode with.

    When the refiner has its own VAE, the selection follows the same
    denoise rules as `process_diffusion`. Otherwise the result is
    `(final_vae, final_refiner_vae)`.
    """
    assert refiner_swap_method in ['joint', 'separate', 'vae']

    if final_refiner_vae is not None and final_refiner_unet is not None:
        if denoise > 0.9:
            return final_vae, final_refiner_vae
        if _base_only_for_partial_denoise(steps, switch, denoise):
            return final_vae, None
        return final_refiner_vae, None

    return final_vae, final_refiner_vae


@torch.no_grad()
@torch.inference_mode()
def process_diffusion(positive_cond, negative_cond, steps, switch, width, height, image_seed, callback,
                      sampler_name, scheduler_name, latent=None, denoise=1.0, tiled=False, cfg_scale=7.0,
                      refiner_swap_method='joint'):
    """Sample with the current base/refiner models and return the decoded image(s).

    Args:
        positive_cond, negative_cond: conditioning lists (e.g. from `clip_encode`).
        steps: total number of sampling steps.
        switch: step at which sampling hands over from base to refiner.
        width, height: size for `core.generate_empty_latent` when `latent` is None.
        image_seed: noise seed; the VAE-swap refine pass uses `image_seed + 1`.
        callback: progress callback forwarded to `core.ksampler`.
        sampler_name, scheduler_name: forwarded to sigma calculation and `core.ksampler`.
        latent: optional starting latent dict with a 'samples' entry.
        denoise: denoise strength. Values <= 0.9999 use a truncated schedule.
        tiled: forwarded to `core.decode_vae`.
        cfg_scale: guidance scale forwarded to `core.ksampler`.
        refiner_swap_method: 'joint', 'separate' or 'vae'. It is overridden
            when the refiner has its own VAE.

    Returns:
        `core.pytorch_to_numpy` applied to the decoded latent.
    """
    target_unet, target_vae, target_refiner_unet, target_refiner_vae, target_clip \
        = final_unet, final_vae, final_refiner_unet, final_refiner_vae, final_clip

    assert refiner_swap_method in ['joint', 'separate', 'vae']

    if final_refiner_vae is not None and final_refiner_unet is not None:
        # Refiner Use Different VAE (then it is SD15)
        if denoise > 0.9:
            refiner_swap_method = 'vae'
        else:
            refiner_swap_method = 'joint'
            if _base_only_for_partial_denoise(steps, switch, denoise):
                target_unet, target_vae, target_refiner_unet, target_refiner_vae \
                    = final_unet, final_vae, None, None
                print('[Sampler] only use Base because of partial denoise.')
            else:
                positive_cond = clip_separate(positive_cond, target_model=final_refiner_unet.model, target_clip=final_clip)
                negative_cond = clip_separate(negative_cond, target_model=final_refiner_unet.model, target_clip=final_clip)
                target_unet, target_vae, target_refiner_unet, target_refiner_vae \
                    = final_refiner_unet, final_refiner_vae, None, None
                print('[Sampler] only use Refiner because of partial denoise.')

    print(f'[Sampler] refiner_swap_method = {refiner_swap_method}')

    if latent is None:
        initial_latent = core.generate_empty_latent(width=width, height=height, batch_size=1)
    else:
        initial_latent = latent

    minmax_sigmas = calculate_sigmas(sampler=sampler_name, scheduler=scheduler_name, model=final_unet.model, steps=steps, denoise=denoise)
    sigma_min, sigma_max = minmax_sigmas[minmax_sigmas > 0].min(), minmax_sigmas.max()
    sigma_min = float(sigma_min.cpu().numpy())
    sigma_max = float(sigma_max.cpu().numpy())
    print(f'[Sampler] sigma_min = {sigma_min}, sigma_max = {sigma_max}')

    modules.patch.BrownianTreeNoiseSamplerPatched.global_init(
        initial_latent['samples'].to(ldm_patched.modules.model_management.get_torch_device()),
        sigma_min, sigma_max, seed=image_seed, cpu=False)

    try:
        decoded_latent = None

        if refiner_swap_method == 'joint':
            # A single ksampler call that switches to the refiner at `switch`.
            sampled_latent = core.ksampler(
                model=target_unet,
                refiner=target_refiner_unet,
                positive=positive_cond,
                negative=negative_cond,
                latent=initial_latent,
                steps=steps, start_step=0, last_step=steps, disable_noise=False, force_full_denoise=True,
                seed=image_seed,
                denoise=denoise,
                callback_function=callback,
                cfg=cfg_scale,
                sampler_name=sampler_name,
                scheduler=scheduler_name,
                refiner_switch=switch,
                previewer_start=0,
                previewer_end=steps,
            )
            decoded_latent = core.decode_vae(vae=target_vae, latent_image=sampled_latent, tiled=tiled)

        if refiner_swap_method == 'separate':
            # Base for steps [0, switch), then a second ksampler call for
            # [switch, steps) that continues without adding new noise.
            sampled_latent = core.ksampler(
                model=target_unet,
                positive=positive_cond,
                negative=negative_cond,
                latent=initial_latent,
                steps=steps, start_step=0, last_step=switch, disable_noise=False, force_full_denoise=False,
                seed=image_seed,
                denoise=denoise,
                callback_function=callback,
                cfg=cfg_scale,
                sampler_name=sampler_name,
                scheduler=scheduler_name,
                previewer_start=0,
                previewer_end=steps,
            )
            print('Refiner swapped by changing ksampler. Noise preserved.')

            refine_unet = target_refiner_unet
            if refine_unet is None:
                refine_unet = target_unet
                print('Use base model to refine itself - this may because of developer mode.')

            sampled_latent = core.ksampler(
                model=refine_unet,
                positive=clip_separate(positive_cond, target_model=refine_unet.model, target_clip=target_clip),
                negative=clip_separate(negative_cond, target_model=refine_unet.model, target_clip=target_clip),
                latent=sampled_latent,
                steps=steps, start_step=switch, last_step=steps, disable_noise=True, force_full_denoise=True,
                seed=image_seed,
                denoise=denoise,
                callback_function=callback,
                cfg=cfg_scale,
                sampler_name=sampler_name,
                scheduler=scheduler_name,
                previewer_start=switch,
                previewer_end=steps,
            )

            output_vae = target_refiner_vae
            if output_vae is None:
                output_vae = target_vae
            decoded_latent = core.decode_vae(vae=output_vae, latent_image=sampled_latent, tiled=tiled)

        if refiner_swap_method == 'vae':
            # Sentinel read by modules.patch during sampling. By the time of the
            # torch.mean() below it is expected to hold a tensor -- TODO confirm
            # against modules.patch.
            modules.patch.eps_record = 'vae'

            if modules.inpaint_worker.current_task is not None:
                modules.inpaint_worker.current_task.unswap()

            sampled_latent = core.ksampler(
                model=target_unet,
                positive=positive_cond,
                negative=negative_cond,
                latent=initial_latent,
                steps=steps, start_step=0, last_step=switch, disable_noise=False, force_full_denoise=True,
                seed=image_seed,
                denoise=denoise,
                callback_function=callback,
                cfg=cfg_scale,
                sampler_name=sampler_name,
                scheduler=scheduler_name,
                previewer_start=0,
                previewer_end=steps
            )
            print('Fooocus VAE-based swap.')

            refine_unet = target_refiner_unet
            if refine_unet is None:
                refine_unet = target_unet
                print('Use base model to refine itself - this may because of developer mode.')

            sampled_latent = vae_parse(sampled_latent)

            # Remaining schedule from `switch` onward, scaled by k_sigmas.
            k_sigmas = 1.4
            sigmas = calculate_sigmas(sampler=sampler_name,
                                      scheduler=scheduler_name,
                                      model=refine_unet.model,
                                      steps=steps,
                                      denoise=denoise)[switch:] * k_sigmas
            len_sigmas = len(sigmas) - 1

            noise_mean = torch.mean(modules.patch.eps_record, dim=1, keepdim=True)

            if modules.inpaint_worker.current_task is not None:
                modules.inpaint_worker.current_task.swap()

            sampled_latent = core.ksampler(
                model=refine_unet,
                positive=clip_separate(positive_cond, target_model=refine_unet.model, target_clip=target_clip),
                negative=clip_separate(negative_cond, target_model=refine_unet.model, target_clip=target_clip),
                latent=sampled_latent,
                steps=len_sigmas, start_step=0, last_step=len_sigmas, disable_noise=False, force_full_denoise=True,
                seed=image_seed + 1,
                denoise=denoise,
                callback_function=callback,
                cfg=cfg_scale,
                sampler_name=sampler_name,
                scheduler=scheduler_name,
                previewer_start=switch,
                previewer_end=steps,
                sigmas=sigmas,
                noise_mean=noise_mean
            )

            output_vae = target_refiner_vae
            if output_vae is None:
                output_vae = target_vae
            decoded_latent = core.decode_vae(vae=output_vae, latent_image=sampled_latent, tiled=tiled)

        images = core.pytorch_to_numpy(decoded_latent)
    finally:
        # Always reset, even when sampling raises (e.g. an interrupted
        # generation), so a stale sentinel/tensor is not seen by later calls.
        modules.patch.eps_record = None

    return images