|
import torch |
|
import comfy.utils |
|
import math |
|
|
|
class PatchModelAddDownscale_MK3:
    """Model patch that temporarily shrinks UNet activations during sampling.

    The node installs two patches on a clone of the model:

    * an input-block patch that resizes the tensor ``h`` at the input block
      whose index equals ``block_number``;
    * an output-block patch that resizes ``h`` back to the spatial size of the
      skip tensor ``hsp`` whenever the two differ.

    The resize applied by the input-block patch depends on the current sigma,
    compared against the sigmas obtained from
    ``model_sampling.percent_to_sigma`` for the three percentages below. The
    code relies on sigma decreasing as the sampling percent increases
    (``sigma_start >= sigma_end >= sigma_rescale``).

    1. Full downscale phase (start_percent -> end_percent):
       ``h`` is scaled by ``1 / factor``.
    2. Transition phase (end_percent -> gradual_percent):
       the scale is eased from ``1 / factor`` back to ``1.0``; progress ``p``
       runs from 0 to 1 and is measured in sigma space. Available modes:

       - LINEAR:      ``p``
       - COSINE:      ``(1 - cos(pi * p)) / 2`` (slow at both ends)
       - EXPONENTIAL: ``p ** 2`` (slow start that speeds up)
       - LOGARITHMIC: ``ln(1 + p * (e - 1))`` (quick start that slows down)
       - STEP:        ``p`` rounded to the nearest quarter
    3. Outside those ranges ``h`` is returned untouched.

    Parameters (arguments of :meth:`patch`):
        model: The model to patch (it is cloned, not modified in place).
        block_number: Input block index to patch (UI range 1-32).
        downscale_factor: Base shrink factor (UI range 0.1-9.0); a value of
            exactly 1 disables the input-block patch.
        start_percent: When to start downscaling (0.0-1.0).
        end_percent: When to begin transitioning back (0.0-1.0).
        gradual_percent: When the transition is complete (0.0-1.0).
        transition_mode: One of ``transition_modes``.
        downscale_after_skip: Install the input patch with
            ``set_model_input_block_patch_after_skip`` instead of
            ``set_model_input_block_patch``.
        downscale_method: Resampling method used when the scale is < 1.
        upscale_method: Resampling method used otherwise, and by the
            output-block patch.
        min_size: Smallest allowed height/width, expressed in the units of
            the patched tensor's own spatial dimensions (``h.shape[-2:]``).
        adaptive_scaling: Derive the factor per call with
            :meth:`calculate_adaptive_factor` instead of using
            ``downscale_factor`` directly.
        preserve_aspect: When True both dimensions are scaled by the same
            factor; when False each dimension is additionally clamped to
            ``min_size`` on its own, which may change the aspect ratio.

    Example usage:
        ```python
        (patched_model,) = PatchModelAddDownscale_MK3().patch(
            model,
            block_number=3,
            downscale_factor=2.0,
            start_percent=0.0,
            end_percent=0.35,
            gradual_percent=0.6,
            transition_mode="COSINE",
            downscale_after_skip=True,
            downscale_method="bicubic",
            upscale_method="bicubic",
            min_size=64,
            adaptive_scaling=True,
            preserve_aspect=True,
        )
        ```

    Code by:
        - Original: https://github.com/Jordach + comfyanon + kohya-ss
    """

    upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
    transition_modes = ["LINEAR", "COSINE", "EXPONENTIAL", "LOGARITHMIC", "STEP"]

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification (types, defaults, ranges) of the node."""
        return {"required": {
            "model": ("MODEL",),
            "block_number": ("INT", {"default": 3, "min": 1, "max": 32, "step": 1}),
            "downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}),
            "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
            "end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}),
            "gradual_percent": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 1.0, "step": 0.001}),
            "transition_mode": (s.transition_modes,),
            "downscale_after_skip": ("BOOLEAN", {"default": True}),
            "downscale_method": (s.upscale_methods,),
            "upscale_method": (s.upscale_methods,),
            "min_size": ("INT", {"default": 64, "min": 16, "max": 2048, "step": 8}),
            "adaptive_scaling": ("BOOLEAN", {"default": True}),
            "preserve_aspect": ("BOOLEAN", {"default": True}),
        }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "model_patches/unet"

    def calculate_transition_factor(self, current_percent, end_percent, gradual_percent,
                                    downscale_factor, mode="LINEAR"):
        """Return the scale to apply at ``current_percent``.

        The three position arguments only need to share one increasing axis:
        the result is ``1 / downscale_factor`` at or before ``end_percent``,
        ``1.0`` at or after ``gradual_percent``, and an eased blend of the two
        in between, shaped by ``mode``.

        Raises:
            ValueError: if the position lies inside the transition and
                ``mode`` is not one of ``transition_modes``.
        """
        shrunk_scale = 1.0 / downscale_factor
        if current_percent <= end_percent:
            return shrunk_scale
        if current_percent >= gradual_percent:
            return 1.0

        # The two guards above guarantee end_percent < gradual_percent here,
        # so the denominator is strictly positive.
        progress = (current_percent - end_percent) / (gradual_percent - end_percent)

        if mode == "LINEAR":
            factor = progress
        elif mode == "COSINE":
            factor = (1 - math.cos(progress * math.pi)) / 2
        elif mode == "EXPONENTIAL":
            factor = math.pow(progress, 2)
        elif mode == "LOGARITHMIC":
            factor = math.log(1 + progress * (math.e - 1))
        elif mode == "STEP":
            # Quantise the progress to quarters (0, 0.25, 0.5, 0.75, 1).
            factor = round(progress * 4) / 4
        else:
            raise ValueError(
                f"Unknown transition mode {mode!r}; expected one of {self.transition_modes}"
            )

        scale_diff = 1.0 - shrunk_scale
        return shrunk_scale + (scale_diff * factor)

    def calculate_adaptive_factor(self, h, base_factor, min_size):
        """Limit ``base_factor`` according to the spatial size of ``h``.

        The factor is capped so that the shorter side of ``h`` divided by the
        factor does not fall below ``min_size``, and is reduced further by
        ``sqrt(2 / aspect_ratio)`` when the aspect ratio exceeds 2.

        When a real downscale was requested (``base_factor >= 1``) the result
        is never allowed to drop below 1.0: otherwise a tensor that is already
        smaller than ``min_size`` would be enlarged instead of left alone.
        """
        height, width = h.shape[-2:]
        min_dim = min(height, width)
        max_dim = max(height, width)
        aspect_ratio = max_dim / min_dim

        max_allowed_factor = min_dim / min_size
        adjusted_factor = min(base_factor, max_allowed_factor)

        if aspect_ratio > 2:
            adjusted_factor *= math.sqrt(2 / aspect_ratio)

        if base_factor >= 1.0:
            # The constraints may cancel the downscale, never invert it.
            adjusted_factor = max(adjusted_factor, 1.0)

        return adjusted_factor

    def _scale_for_sigma(self, sigma, sigma_start, sigma_end, sigma_rescale,
                         effective_factor, transition_mode):
        """Return the scale for ``sigma``, or ``None`` when no resize applies."""
        if sigma_end <= sigma <= sigma_start:
            return 1.0 / effective_factor
        if sigma_rescale <= sigma < sigma_end:
            # Sigma falls while sampling advances, so measure progress as the
            # distance travelled below sigma_end: 0 right after the full
            # downscale phase, 1 once sigma_rescale is reached.
            return self.calculate_transition_factor(
                sigma_end - sigma, 0.0, sigma_end - sigma_rescale,
                effective_factor, transition_mode
            )
        return None

    @staticmethod
    def _target_size(height, width, scale_factor, min_size, preserve_aspect):
        """Return the ``(new_height, new_width)`` to resize to."""
        # A dimension must never round down to zero.
        new_h = max(1, round(height * scale_factor))
        new_w = max(1, round(width * scale_factor))
        if not preserve_aspect:
            # Enforce min_size per dimension, but never grow a dimension that
            # is already smaller than min_size.
            new_h = max(new_h, min(min_size, height))
            new_w = max(new_w, min(min_size, width))
        return new_h, new_w

    def patch(self, model, block_number, downscale_factor, start_percent, end_percent,
              gradual_percent, transition_mode, downscale_after_skip, downscale_method,
              upscale_method, min_size, adaptive_scaling, preserve_aspect):
        """Clone ``model`` and install the downscale/upscale patches on the clone.

        Returns:
            A one-element tuple containing the patched clone.
        """
        model_sampling = model.get_model_object("model_sampling")
        sigma_start = model_sampling.percent_to_sigma(start_percent)
        sigma_end = model_sampling.percent_to_sigma(end_percent)
        sigma_rescale = model_sampling.percent_to_sigma(gradual_percent)

        def input_block_patch(h, transformer_options):
            # A factor of exactly 1 means "no downscale": leave everything alone.
            if downscale_factor == 1:
                return h
            if transformer_options["block"][1] != block_number:
                return h

            sigma = transformer_options["sigmas"][0].item()

            if adaptive_scaling:
                effective_factor = self.calculate_adaptive_factor(h, downscale_factor, min_size)
            else:
                effective_factor = downscale_factor

            scale_factor = self._scale_for_sigma(
                sigma, sigma_start, sigma_end, sigma_rescale,
                effective_factor, transition_mode
            )
            if scale_factor is None:
                return h

            height, width = h.shape[-2], h.shape[-1]
            new_h, new_w = self._target_size(
                height, width, scale_factor, min_size, preserve_aspect
            )
            if (new_h, new_w) == (height, width):
                # Nothing to resize; skip the resampling call.
                return h

            return comfy.utils.common_upscale(
                h, new_w, new_h,
                downscale_method if scale_factor < 1 else upscale_method,
                "disabled"
            )

        def output_block_patch(h, hsp, transformer_options):
            # Bring h back to the skip tensor's spatial size when they differ.
            if h.shape[2:] != hsp.shape[2:]:
                h = comfy.utils.common_upscale(
                    h, hsp.shape[-1], hsp.shape[-2],
                    upscale_method, "disabled"
                )
            return h, hsp

        m = model.clone()
        if downscale_after_skip:
            m.set_model_input_block_patch_after_skip(input_block_patch)
        else:
            m.set_model_input_block_patch(input_block_patch)
        m.set_model_output_block_patch(output_block_patch)
        return (m, )
|
|
|
# Node identifier -> implementing class.
# NOTE(review): presumably read by ComfyUI's custom-node loader - confirm.
NODE_CLASS_MAPPINGS = {"PatchModelAddDownscale_MK3": PatchModelAddDownscale_MK3}

# Node identifier -> human-readable title.
NODE_DISPLAY_NAME_MAPPINGS = {"PatchModelAddDownscale_MK3": "PatchModelAddDownscale MK3"}