Spaces:
Runtime error
Runtime error
File size: 6,391 Bytes
68b7092 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import torch
import fcbh.samplers
import fcbh.model_management
from fcbh.model_base import SDXLRefiner, SDXL
from fcbh.conds import CONDRegular
from fcbh.sample import get_additional_models, get_models_from_cond, cleanup_additional_models
from fcbh.samplers import resolve_areas_and_cond_masks, wrap_model, calculate_start_end_timesteps, \
create_cond_with_same_area_if_none, pre_run_control, apply_empty_x_to_equal_area, encode_model_conds
# Module-level refiner state read by sample_hacked.
# current_refiner: while None, sample_hacked never switches models. When set,
# sample_hacked reads its `.model`, `.model_dtype()` and `.memory_required()`
# and passes it to fcbh.model_management.load_models_gpu at the switch step.
current_refiner = None
# Step index (as reported to the sampler callback) at which sample_hacked swaps
# the wrapped model for the refiner.
# NOTE(review): -1 presumably means "never switch" because it should not match
# a real step index -- confirm against the sampler's callback numbering.
refiner_switch_step = -1
@torch.no_grad()
@torch.inference_mode()
def clip_separate_inner(c, p, target_model=None, target_clip=None):
    """Reduce a combined text conditioning to the part ``target_model`` uses.

    Args:
        c: Conditioning tensor whose last dimension is sliced.
            NOTE(review): presumably (batch, tokens, channels) with the
            CLIP-L features in the first 768 channels and the features the
            refiner uses in the last 1280 -- confirm against callers.
        p: Pooled output; returned unchanged for the SDXL / refiner cases.
        target_model: Model the conditioning is prepared for. ``None`` is
            treated the same as an ``SDXLRefiner`` instance.
        target_clip: Object exposing
            ``cond_stage_model.clip_l.transformer.text_model.final_layer_norm``.
            Only read when ``target_model`` is neither ``SDXLRefiner`` nor
            ``SDXL``; it must not be ``None`` in that case.

    Returns:
        A ``(c, p)`` tuple:
        * ``None`` / ``SDXLRefiner``: a clone of the last 1280 channels of
          ``c`` and the original ``p``.
        * ``SDXL``: a clone of ``c`` and the original ``p``.
        * anything else: the first 768 channels of ``c`` passed through the
          CLIP-L final layer norm (computed on CPU in float32, then moved back
          to the original device/dtype of ``c``) and ``None`` for ``p``.
    """
    if target_model is None or isinstance(target_model, SDXLRefiner):
        return c[..., -1280:].clone(), p

    if isinstance(target_model, SDXL):
        return c.clone(), p

    # Remaining case: only the first 768 channels are kept and the pooled
    # output is dropped.
    p = None
    c = c[..., :768].clone()

    final_layer_norm = target_clip.cond_stage_model.clip_l.transformer.text_model.final_layer_norm

    final_layer_norm_origin_device = final_layer_norm.weight.device
    final_layer_norm_origin_dtype = final_layer_norm.weight.dtype

    c_origin_device = c.device
    c_origin_dtype = c.dtype

    try:
        # The norm is evaluated on CPU in float32; the layer is moved there
        # temporarily and both operands are restored afterwards.
        final_layer_norm.to(device='cpu', dtype=torch.float32)
        c = c.to(device='cpu', dtype=torch.float32)

        # Process the sequence in pieces sized by c.size(1) // 77.
        # NOTE(review): 77 is presumably the CLIP context length -- confirm.
        # torch.chunk rejects a chunk count of 0, so sequences shorter than 77
        # are handled as a single chunk instead of raising.
        chunk_count = max(1, int(c.size(1)) // 77)
        pieces = torch.chunk(c, chunk_count, 1)
        c = torch.cat([final_layer_norm(ci) for ci in pieces], dim=1)
    finally:
        # final_layer_norm belongs to the shared CLIP model: always put it back
        # on its original device/dtype, even if the computation above failed.
        final_layer_norm.to(device=final_layer_norm_origin_device, dtype=final_layer_norm_origin_dtype)

    c = c.to(device=c_origin_device, dtype=c_origin_dtype)
    return c, p
@torch.no_grad()
@torch.inference_mode()
def clip_separate(cond, target_model=None, target_clip=None):
    """Apply ``clip_separate_inner`` to every entry of a conditioning list.

    Args:
        cond: Iterable of ``(tensor, options)`` pairs, where ``options`` is a
            dict that may carry a ``'pooled_output'`` entry.
        target_model: Forwarded to ``clip_separate_inner``.
        target_clip: Forwarded to ``clip_separate_inner``.

    Returns:
        A new list of ``[tensor, options]`` pairs. The new ``options`` dict
        holds only a cloned ``'pooled_output'`` when the inner call returned
        one, and is empty otherwise; other keys of the input dict are not
        carried over.
    """
    separated = []
    for embedding, options in cond:
        pooled = options.get('pooled_output', None)
        embedding, pooled = clip_separate_inner(
            embedding, pooled, target_model=target_model, target_clip=target_clip)
        if pooled is None:
            new_options = {}
        else:
            new_options = {'pooled_output': pooled.clone()}
        separated.append([embedding, new_options])
    return separated
@torch.no_grad()
@torch.inference_mode()
def clip_separate_after_preparation(cond, target_model=None, target_clip=None):
    """Apply ``clip_separate_inner`` to conds that already carry ``model_conds``.

    Args:
        cond: Iterable of dicts; each must provide
            ``['model_conds']['c_crossattn'].cond`` and may provide
            ``'pooled_output'``.
        target_model: Forwarded to ``clip_separate_inner``.
        target_clip: Forwarded to ``clip_separate_inner``.

    Returns:
        A new list of dicts, each containing only
        ``{'model_conds': {'c_crossattn': CONDRegular(...)}}`` plus a cloned
        ``'pooled_output'`` when the inner call returned one. Other keys of
        the input dicts are not carried over.
    """
    def _separate_one(entry):
        pooled = entry.get('pooled_output', None)
        crossattn = entry['model_conds']['c_crossattn'].cond
        crossattn, pooled = clip_separate_inner(
            crossattn, pooled, target_model=target_model, target_clip=target_clip)
        separated = {'model_conds': {'c_crossattn': CONDRegular(crossattn)}}
        if pooled is not None:
            separated['pooled_output'] = pooled.clone()
        return separated

    return [_separate_one(entry) for entry in cond]
@torch.no_grad()
@torch.inference_mode()
def sample_hacked(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options=None, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
    """Sampling entry point that can swap to a refiner model mid-run.

    Prepares the positive/negative conds, runs ``sampler.sample`` on the
    wrapped ``model`` and, when the module-level ``current_refiner`` is set
    and its conds could be prepared, replaces the wrapped model and the conds
    with the refiner's at callback step ``refiner_switch_step``.

    Args:
        model: Base model; wrapped with ``wrap_model`` and used for
            ``process_latent_in`` / ``process_latent_out``.
        noise: Noise tensor; ``noise.shape[2]`` and ``noise.shape[3]`` are
            passed to ``resolve_areas_and_cond_masks`` and ``noise.shape`` to
            the refiner's ``memory_required``.
        positive: List of positive conds (shallow-copied before use).
        negative: List of negative conds (shallow-copied before use).
        cfg: Passed to the sampler as ``cond_scale``.
        device: Device forwarded to the cond preparation helpers.
        sampler: Object whose ``sample`` method performs the actual sampling.
        sigmas: Sigma schedule forwarded to ``sampler.sample``.
        model_options: Options dict forwarded to the sampler; ``None`` (the
            default) means an empty dict.
        latent_image: Optional latent; run through ``model.process_latent_in``
            when given.
        denoise_mask: Optional mask forwarded to cond encoding and the sampler.
        callback: Optional ``callback(step, x0, x, total_steps)`` invoked on
            every sampler callback, after any refiner switch.
        disable_pbar: Forwarded to ``sampler.sample``.
        seed: Passed to the sampler via ``extra_args``.

    Returns:
        ``model.process_latent_out`` applied to the float32 samples.
    """
    global current_refiner

    # A fresh dict per call avoids sharing one default object between calls.
    if model_options is None:
        model_options = {}

    positive = positive[:]
    negative = negative[:]

    resolve_areas_and_cond_masks(positive, noise.shape[2], noise.shape[3], device)
    resolve_areas_and_cond_masks(negative, noise.shape[2], noise.shape[3], device)

    model_wrap = wrap_model(model)

    calculate_start_end_timesteps(model, negative)
    calculate_start_end_timesteps(model, positive)

    # make sure each cond area has an opposite one with the same area
    for c in positive:
        create_cond_with_same_area_if_none(negative, c)
    for c in negative:
        create_cond_with_same_area_if_none(positive, c)

    # Only positive is passed (rather than negative + positive):
    # negative is not necessary in Fooocus, 0.5s faster.
    pre_run_control(model, positive)

    # The strict `== True` comparison of the original is kept on purpose so
    # the set of selected conds is unchanged.
    control_for_uncond = [c for c in positive if c.get('control_apply_to_uncond', False) == True]
    apply_empty_x_to_equal_area(control_for_uncond, negative, 'control', lambda cond_cnets, x: cond_cnets[x])
    apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])

    if latent_image is not None:
        latent_image = model.process_latent_in(latent_image)

    if hasattr(model, 'extra_conds'):
        positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
        negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)

    extra_args = {"cond": positive, "uncond": negative, "cond_scale": cfg, "model_options": model_options, "seed": seed}

    # Refiner conds are only prepared when the refiner exposes `extra_conds`.
    # They are pre-bound to None so the switch can be skipped cleanly instead
    # of failing on an unbound name when that preparation did not happen.
    positive_refiner = None
    negative_refiner = None
    refiner_conds_ready = False

    if current_refiner is not None and hasattr(current_refiner.model, 'extra_conds'):
        positive_refiner = clip_separate_after_preparation(positive, target_model=current_refiner.model)
        negative_refiner = clip_separate_after_preparation(negative, target_model=current_refiner.model)

        positive_refiner = encode_model_conds(current_refiner.model.extra_conds, positive_refiner, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
        negative_refiner = encode_model_conds(current_refiner.model.extra_conds, negative_refiner, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)
        refiner_conds_ready = True

    def refiner_switch():
        """Replace the wrapped model and the conds with the refiner's."""
        # Release the control models referenced by the base-model conds.
        cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")))

        extra_args["cond"] = positive_refiner
        extra_args["uncond"] = negative_refiner

        # clear ip-adapter for refiner
        extra_args['model_options'] = {k: {} if k == 'transformer_options' else v for k, v in extra_args['model_options'].items()}

        models, inference_memory = get_additional_models(positive_refiner, negative_refiner, current_refiner.model_dtype())
        fcbh.model_management.load_models_gpu([current_refiner] + models, current_refiner.memory_required(noise.shape) + inference_memory)

        model_wrap.inner_model = current_refiner.model
        print('Refiner Swapped')

    def callback_wrap(step, x0, x, total_steps):
        """Sampler callback: switch to the refiner if due, then notify the caller."""
        if step == refiner_switch_step and current_refiner is not None and refiner_conds_ready:
            refiner_switch()
        if callback is not None:
            callback(step, x0, x, total_steps)

    samples = sampler.sample(model_wrap, sigmas, extra_args, callback_wrap, noise, latent_image, denoise_mask, disable_pbar)
    return model.process_latent_out(samples.to(torch.float32))
# Monkey-patch, applied as an import-time side effect of this module:
# fcbh.samplers.sample is replaced by sample_hacked, so code that looks up
# fcbh.samplers.sample afterwards gets the refiner-aware implementation.
fcbh.samplers.sample = sample_hacked
|