# toolkit/comfy_nodes/deep_shrink_mk3.py
# Author: k4d3 (commit 524ee6e)
import torch
import comfy.utils
import math
class PatchModelAddDownscale_MK3:
    """A UNet model patch that dynamically shrinks one input block's features and restores them.

    This is an extended PatchModelAddDownscale_v2. It adds several transition curves and an
    optional adaptive adjustment of the shrink factor. Sampling progress is measured in sigma:
    the percent bounds below are converted with ``model_sampling.percent_to_sigma`` once, in
    ``patch``.

    1. Full downscale phase (start_percent -> end_percent):
       - the input to the chosen block is resized by ``1 / downscale_factor``;
       - with ``adaptive_scaling``, the factor is first adjusted by
         ``calculate_adaptive_factor`` (min_size floor and aspect-ratio softening).
    2. Transition phase (end_percent -> gradual_percent):
       the scale moves from ``1 / factor`` back to ``1.0`` along a curve chosen by
       ``transition_mode``. For linear progress p in [0, 1]:
       - LINEAR:      p (the original v2 behaviour)
       - COSINE:      (1 - cos(pi * p)) / 2, which is slow at both ends
       - EXPONENTIAL: p ** 2, which starts slowly and speeds up towards the end
       - LOGARITHMIC: ln(1 + p * (e - 1)), which starts quickly and slows down towards the end
       - STEP:        p rounded to the nearest quarter (levels 0, .25, .5, .75, 1)
    3. Final phase (after gradual_percent):
       the block input is passed through unchanged.

    Output side: any output-block feature map whose spatial size differs from its skip
    connection (``hsp``) is resized back to the skip connection's size with ``upscale_method``.

    Parameters (of ``patch``):
        model: The model to patch. It is cloned; the input is not modified.
        block_number: Input block index the patch applies to (1-32), compared against
            ``transformer_options["block"][1]``.
        downscale_factor: Base shrink factor (0.1-9.0). Values below 1 enlarge instead.
        start_percent: When downscaling starts (0.0-1.0).
        end_percent: When the transition back starts (0.0-1.0).
        gradual_percent: When the transition completes (0.0-1.0).
        transition_mode: One of ``transition_modes``.
        downscale_after_skip: True registers the patch with
            ``set_model_input_block_patch_after_skip``; False uses ``set_model_input_block_patch``.
        downscale_method: Resampling algorithm used when shrinking.
        upscale_method: Resampling algorithm used when enlarging or restoring.
        min_size: Minimum spatial size, measured in the units of the patched tensor's last two
            dimensions (not image pixels). It is enforced by adaptive scaling and, when
            ``preserve_aspect`` is False, per side.
        adaptive_scaling: Adjust the factor from the tensor's size (see
            ``calculate_adaptive_factor``).
        preserve_aspect: True scales both sides by the same factor with no per-side floor.
            False additionally floors each side at ``min_size``, which may change the aspect ratio.

    Example:
        ```python
        (patched_model,) = PatchModelAddDownscale_MK3().patch(
            model=model,
            block_number=3,
            downscale_factor=2.0,
            start_percent=0.0,
            end_percent=0.35,
            gradual_percent=0.6,
            transition_mode="COSINE",
            downscale_after_skip=True,
            downscale_method="bicubic",
            upscale_method="bicubic",
            min_size=64,
            adaptive_scaling=True,
            preserve_aspect=True,
        )
        ```

    Code by:
    - Original: https://github.com/Jordach + comfyanon + kohya-ss
    """

    upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
    transition_modes = ["LINEAR", "COSINE", "EXPONENTIAL", "LOGARITHMIC", "STEP"]

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's widgets and inputs for ComfyUI."""
        return {"required": {
            "model": ("MODEL",),
            "block_number": ("INT", {"default": 3, "min": 1, "max": 32, "step": 1}),
            "downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}),
            "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
            "end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}),
            "gradual_percent": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 1.0, "step": 0.001}),
            "transition_mode": (cls.transition_modes,),
            "downscale_after_skip": ("BOOLEAN", {"default": True}),
            "downscale_method": (cls.upscale_methods,),
            "upscale_method": (cls.upscale_methods,),
            "min_size": ("INT", {"default": 64, "min": 16, "max": 2048, "step": 8}),
            "adaptive_scaling": ("BOOLEAN", {"default": True}),
            "preserve_aspect": ("BOOLEAN", {"default": True}),
        }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "model_patches/unet"

    @staticmethod
    def _transition_curve(progress, mode):
        """Map linear progress in [0, 1] onto the easing curve named by *mode*.

        Raises:
            ValueError: If *mode* is not one of ``transition_modes``.
        """
        if mode == "LINEAR":
            return progress
        if mode == "COSINE":
            return (1 - math.cos(progress * math.pi)) / 2
        if mode == "EXPONENTIAL":
            return math.pow(progress, 2)
        if mode == "LOGARITHMIC":
            return math.log(1 + progress * (math.e - 1))
        if mode == "STEP":
            return round(progress * 4) / 4  # quantize to quarters
        raise ValueError(f"Unknown transition_mode: {mode!r}")

    def calculate_transition_factor(self, current_percent, end_percent, gradual_percent,
                                    downscale_factor, mode="LINEAR"):
        """Return the resize multiplier at *current_percent* on an increasing timeline.

        The result is ``1 / downscale_factor`` at or before *end_percent* and ``1.0`` at or after
        *gradual_percent*. In between, it is interpolated along the curve selected by *mode*.

        Raises:
            ValueError: If *mode* is unknown and interpolation is needed.
        """
        start_scale = 1.0 / downscale_factor
        if current_percent <= end_percent:
            return start_scale
        if current_percent >= gradual_percent:
            return 1.0
        # The guards above make gradual_percent > end_percent here, so no division by zero.
        progress = (current_percent - end_percent) / (gradual_percent - end_percent)
        factor = self._transition_curve(progress, mode)
        return start_scale + (1.0 - start_scale) * factor

    def calculate_adaptive_factor(self, h, base_factor, min_size):
        """Adjust a shrink factor to the spatial size of *h*.

        Only downscale requests (``base_factor > 1``) are adjusted:
          * the factor is capped so the smaller of ``h.shape[-2:]`` does not drop below
            *min_size*;
          * for aspect ratios above 2 it is reduced further by ``sqrt(2 / aspect_ratio)``.
        Neither adjustment goes below 1.0, so a downscale never turns into an upscale; at worst
        it becomes a no-op. Factors of 1 or less are returned unchanged.
        """
        if base_factor <= 1.0:
            return base_factor
        min_dim = min(h.shape[-2:])
        max_dim = max(h.shape[-2:])
        aspect_ratio = max_dim / min_dim

        # Prevent over-shrinking. If the map is already at or below min_size, do not shrink.
        max_allowed_factor = max(1.0, min_dim / min_size)
        adjusted_factor = min(base_factor, max_allowed_factor)

        # Soften the shrink for extreme aspect ratios, without crossing into upscaling.
        if aspect_ratio > 2:
            adjusted_factor = max(1.0, adjusted_factor * math.sqrt(2 / aspect_ratio))
        return adjusted_factor

    def _scale_for_sigma(self, sigma, sigma_start, sigma_end, sigma_rescale,
                         effective_factor, mode):
        """Return the resize multiplier for the current *sigma*, or None outside the patch window.

        Sigma decreases as sampling progresses, so sigma_start >= sigma_end >= sigma_rescale for
        increasing percents.
        """
        if sigma_end <= sigma <= sigma_start:
            return 1.0 / effective_factor
        if sigma_rescale <= sigma < sigma_end:
            # Measure how far sampling has moved past sigma_end (this grows over time), so the
            # transition starts at full downscale and finishes at the original size.
            return self.calculate_transition_factor(
                sigma_end - sigma, 0.0, sigma_end - sigma_rescale,
                effective_factor, mode,
            )
        return None

    @staticmethod
    def _target_size(height, width, scale_factor, min_size, preserve_aspect):
        """Compute the (height, width) for a resize by *scale_factor*.

        Each side is at least 1. Without *preserve_aspect*, each side is also floored at
        *min_size*, but never above its current size, so the floor cannot enlarge a map during
        a downscale.
        """
        new_h = max(1, round(height * scale_factor))
        new_w = max(1, round(width * scale_factor))
        if not preserve_aspect:
            new_h = max(new_h, min(min_size, height))
            new_w = max(new_w, min(min_size, width))
        return new_h, new_w

    def patch(self, model, block_number, downscale_factor, start_percent, end_percent,
              gradual_percent, transition_mode, downscale_after_skip, downscale_method,
              upscale_method, min_size, adaptive_scaling, preserve_aspect):
        """Return a clone of *model* with the input-downscale and output-restore patches attached.

        Returns:
            A 1-tuple ``(patched_model,)`` matching ``RETURN_TYPES``.

        Raises:
            ValueError: If *transition_mode* is not one of ``transition_modes``.
        """
        if transition_mode not in self.transition_modes:
            raise ValueError(f"Unknown transition_mode: {transition_mode!r}")

        model_sampling = model.get_model_object("model_sampling")
        sigma_start = model_sampling.percent_to_sigma(start_percent)
        sigma_end = model_sampling.percent_to_sigma(end_percent)
        sigma_rescale = model_sampling.percent_to_sigma(gradual_percent)

        def input_block_patch(h, transformer_options):
            # Resize the input of the selected block according to the current sampling phase.
            if downscale_factor == 1:
                return h
            if transformer_options["block"][1] != block_number:
                return h
            sigma = transformer_options["sigmas"][0].item()

            if adaptive_scaling:
                effective_factor = self.calculate_adaptive_factor(h, downscale_factor, min_size)
            else:
                effective_factor = downscale_factor

            scale_factor = self._scale_for_sigma(
                sigma, sigma_start, sigma_end, sigma_rescale,
                effective_factor, transition_mode,
            )
            if scale_factor is None:
                return h

            new_h, new_w = self._target_size(
                h.shape[-2], h.shape[-1], scale_factor, min_size, preserve_aspect
            )
            if (new_h, new_w) == (h.shape[-2], h.shape[-1]):
                return h  # nothing to resize
            method = downscale_method if scale_factor < 1 else upscale_method
            return comfy.utils.common_upscale(h, new_w, new_h, method, "disabled")

        def output_block_patch(h, hsp, transformer_options):
            # Bring the decoder features back to the skip connection's spatial size.
            if h.shape[2:] != hsp.shape[2:]:
                h = comfy.utils.common_upscale(
                    h, hsp.shape[-1], hsp.shape[-2],
                    upscale_method, "disabled"
                )
            return h, hsp

        m = model.clone()
        if downscale_after_skip:
            m.set_model_input_block_patch_after_skip(input_block_patch)
        else:
            m.set_model_input_block_patch(input_block_patch)
        m.set_model_output_block_patch(output_block_patch)
        return (m, )
# Node registration: map the node's type name to its implementing class
# and to its human-readable display name.
NODE_CLASS_MAPPINGS = dict(
    [("PatchModelAddDownscale_MK3", PatchModelAddDownscale_MK3)]
)
NODE_DISPLAY_NAME_MAPPINGS = dict(
    [("PatchModelAddDownscale_MK3", "PatchModelAddDownscale MK3")]
)