|
|
|
|
|
|
|
|
|
|
|
import comfy.model_patcher |
|
import comfy.samplers |
|
|
|
class PerturbedAttentionGuidance:
    """Node that attaches Perturbed-Attention Guidance (PAG) to a model.

    ``patch`` returns a clone of the input model with a post-CFG hook
    registered on it. When that hook runs with a non-zero ``scale``, it:
    - re-evaluates the model on the positive conditioning, with the "attn1"
      attention of block ("middle", 0) swapped for a function that returns
      its value input unchanged;
    - moves the CFG result by ``(cond_pred - perturbed_pred) * scale``.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node inputs: a MODEL and a FLOAT guidance ``scale``."""
        scale_options = {
            "default": 3.0,
            "min": 0.0,
            "max": 100.0,
            "step": 0.01,
            "round": 0.01,
        }
        required_inputs = {
            "model": ("MODEL",),
            "scale": ("FLOAT", scale_options),
        }
        return {"required": required_inputs}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "model_patches/unet"

    def patch(self, model, scale):
        """Return ``(clone,)``, where ``clone`` is ``model.clone()`` with the PAG hook.

        The hook is registered on the clone only; ``model`` is not modified here.
        """
        target_block = "middle"
        target_index = 0
        patched = model.clone()

        def identity_attention(q, k, v, extra_options, mask=None):
            # The perturbation: q, k, extra_options and mask are ignored and
            # the value input is handed straight back as the attention output.
            return v

        def pag_post_cfg(args):
            # ``args`` is supplied by the sampler; only these keys are read.
            # Presumably the *_pred values are tensors -- TODO confirm.
            sampler_model = args["model"]
            positive_pred = args["cond_denoised"]
            positive_cond = args["cond"]
            guided = args["denoised"]
            sigma = args["sigma"]
            options = args["model_options"].copy()
            latent = args["input"]

            if scale == 0:
                # Zero scale contributes nothing, so skip the extra model pass.
                return guided

            # Swap in the identity attention for "attn1" of the target block.
            options = comfy.model_patcher.set_model_options_patch_replace(
                options, identity_attention, "attn1", target_block, target_index
            )
            # One extra evaluation on the positive conditioning only.
            (perturbed_pred,) = comfy.samplers.calc_cond_batch(
                sampler_model, [positive_cond], latent, sigma, options
            )
            return guided + (positive_pred - perturbed_pred) * scale

        patched.set_model_sampler_post_cfg_function(pag_post_cfg)
        return (patched,)
|
|
|
# Maps the node's type name to its implementing class; presumably read by the
# ComfyUI node loader when this module is imported -- confirm against loader.
NODE_CLASS_MAPPINGS = {"PerturbedAttentionGuidance": PerturbedAttentionGuidance}
|
|