stuff 2.0
Files changed:
- caption/joy_single.py +236 -0
- comfy_nodes/deep_shrink_mk2.py +129 -0
- comfy_nodes/easy_aspects.py +103 -0
caption/joy_single.py
ADDED
@@ -0,0 +1,236 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Simplified JoyCaption - Generates captions for a single image input
"""

import os
import argparse
from pathlib import Path
from PIL import Image
import torch
import torchvision.transforms.functional as TVF
from transformers import (
    AutoModel,
    AutoTokenizer,
    AutoModelForCausalLM,
)
from torch import nn
import logging

CLIP_PATH = "google/siglip-so400m-patch14-384"
MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B"
CHECKPOINT_PATH = Path(__file__).resolve().parent / "cgrkzexw-599808"


class ImageAdapter(nn.Module):
    def __init__(
        self,
        input_features: int,
        output_features: int,
        ln1: bool,
        pos_emb: bool,
        num_image_tokens: int,
        deep_extract: bool,
    ):
        super().__init__()
        self.deep_extract = deep_extract
        if self.deep_extract:
            input_features = input_features * 5

        self.linear1 = nn.Linear(input_features, output_features)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(output_features, output_features)
        self.ln1 = nn.Identity() if not ln1 else nn.LayerNorm(input_features)
        self.pos_emb = (
            None
            if not pos_emb
            else nn.Parameter(torch.zeros(num_image_tokens, input_features))
        )
        self.other_tokens = nn.Embedding(3, output_features)
        self.other_tokens.weight.data.normal_(mean=0.0, std=0.02)

    def forward(self, vision_outputs):
        if self.deep_extract:
            x = torch.concat(
                (
                    vision_outputs[-2],
                    vision_outputs[3],
                    vision_outputs[7],
                    vision_outputs[13],
                    vision_outputs[20],
                ),
                dim=-1,
            )
        else:
            x = vision_outputs[-2]

        x = self.ln1(x)
        if self.pos_emb is not None:
            x = x + self.pos_emb

        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)

        # Surround the projected image tokens with two learned boundary embeddings
        other_tokens = self.other_tokens(
            torch.tensor([0, 1], device=self.other_tokens.weight.device).expand(
                x.shape[0], -1
            )
        )
        x = torch.cat((other_tokens[:, 0:1], x, other_tokens[:, 1:2]), dim=1)
        return x


class SimpleCaptioner:
    def __init__(self):
        self.clip_model = None
        self.text_model = None
        self.image_adapter = None
        self.tokenizer = None

    def load_models(self):
        logging.info("Loading CLIP")
        self.clip_model = AutoModel.from_pretrained(CLIP_PATH)
        self.clip_model = self.clip_model.vision_model

        if (CHECKPOINT_PATH / "clip_model.pt").exists():
            checkpoint = torch.load(
                CHECKPOINT_PATH / "clip_model.pt", map_location="cpu"
            )
            checkpoint = {
                k.replace("_orig_mod.module.", ""): v for k, v in checkpoint.items()
            }
            self.clip_model.load_state_dict(checkpoint)

        self.clip_model.eval()
        self.clip_model.requires_grad_(False)
        self.clip_model.to("cuda")

        logging.info("Loading tokenizer and LLM")
        self.tokenizer = AutoTokenizer.from_pretrained(
            CHECKPOINT_PATH / "text_model", use_fast=True
        )

        if (CHECKPOINT_PATH / "text_model").exists():
            self.text_model = AutoModelForCausalLM.from_pretrained(
                CHECKPOINT_PATH / "text_model", device_map=0, torch_dtype=torch.bfloat16
            )
        else:
            self.text_model = AutoModelForCausalLM.from_pretrained(
                MODEL_PATH, device_map="auto", torch_dtype=torch.bfloat16
            )
        self.text_model.eval()

        logging.info("Loading image adapter")
        self.image_adapter = ImageAdapter(
            self.clip_model.config.hidden_size,
            self.text_model.config.hidden_size,
            False,
            False,
            38,
            False,
        )
        self.image_adapter.load_state_dict(
            torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu")
        )
        self.image_adapter.eval()
        self.image_adapter.to("cuda")

    @torch.no_grad()
    def generate_caption(self, image_path: str) -> str:
        # Load and preprocess image
        input_image = Image.open(image_path).convert("RGB")
        image = input_image.resize((384, 384), Image.LANCZOS)
        pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
        pixel_values = TVF.normalize(pixel_values, [0.5], [0.5]).to("cuda")

        # Generate image embeddings
        vision_outputs = self.clip_model(
            pixel_values=pixel_values, output_hidden_states=True
        )
        embedded_images = self.image_adapter(vision_outputs.hidden_states)

        # Prepare prompt
        prompt = "Write a descriptive caption for this image in a formal tone."
        convo = [
            {"role": "system", "content": "You are a helpful image captioner."},
            {"role": "user", "content": prompt},
        ]
        convo_string = self.tokenizer.apply_chat_template(
            convo, tokenize=False, add_generation_prompt=True
        )

        # Tokenize and prepare inputs
        convo_tokens = self.tokenizer.encode(
            convo_string, return_tensors="pt", add_special_tokens=False
        )
        prompt_tokens = self.tokenizer.encode(
            prompt, return_tensors="pt", add_special_tokens=False
        )

        # convo_tokens has shape (1, seq_len); take the column indices of the
        # <|eot_id|> tokens so the image embeddings can be spliced in right
        # after the system turn, i.e. just before the user prompt.
        eot_id_indices = (
            (convo_tokens == self.tokenizer.convert_tokens_to_ids("<|eot_id|>"))
            .nonzero(as_tuple=True)[1]
            .tolist()
        )
        preamble_len = eot_id_indices[1] - prompt_tokens.shape[1]

        convo_embeds = self.text_model.model.embed_tokens(convo_tokens.to("cuda"))

        input_embeds = torch.cat(
            [
                convo_embeds[:, :preamble_len],
                embedded_images.to(dtype=convo_embeds.dtype),
                convo_embeds[:, preamble_len:],
            ],
            dim=1,
        )

        input_ids = torch.cat(
            [
                convo_tokens[:, :preamble_len],
                # Placeholder ids for the image tokens; only their embeddings matter.
                torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
                convo_tokens[:, preamble_len:],
            ],
            dim=1,
        ).to("cuda")

        attention_mask = torch.ones_like(input_ids)

        # Generate caption
        generate_ids = self.text_model.generate(
            input_ids,
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            max_new_tokens=300,
            do_sample=True,
            repetition_penalty=1.2,
        )

        # Decode caption
        generate_ids = generate_ids[:, input_ids.shape[1]:]
        if generate_ids[0][-1] == self.tokenizer.eos_token_id or generate_ids[0][-1] == self.tokenizer.convert_tokens_to_ids("<|eot_id|>"):
            generate_ids = generate_ids[:, :-1]

        caption = self.tokenizer.batch_decode(
            generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )[0]

        return caption.strip()


def main():
    parser = argparse.ArgumentParser(description="Generate a caption for a single image")
    parser.add_argument("image_path", type=str, help="Path to the input image")
    args = parser.parse_args()

    # Set up logging
    logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")

    # Initialize and load the captioner
    captioner = SimpleCaptioner()
    captioner.load_models()

    # Generate and print caption
    caption = captioner.generate_caption(args.image_path)
    print(f"\nGenerated caption:\n{caption}")


if __name__ == "__main__":
    main()
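
The script is meant to be run directly, e.g. python caption/joy_single.py my_image.jpg, but the class can also be reused to caption a whole folder with a single model load. A rough sketch, assuming joy_single.py is importable from the working directory; the my_images/ folder is a placeholder:

# Rough batch-captioning sketch reusing SimpleCaptioner. The import path and
# the my_images/ folder are placeholders, not part of this commit.
from pathlib import Path

from joy_single import SimpleCaptioner

captioner = SimpleCaptioner()
captioner.load_models()  # loads SigLIP, the LLM and the image adapter onto CUDA once

for image_path in sorted(Path("my_images").glob("*.jpg")):
    caption = captioner.generate_caption(str(image_path))
    image_path.with_suffix(".txt").write_text(caption, encoding="utf-8")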
comfy_nodes/deep_shrink_mk2.py
ADDED
@@ -0,0 +1,129 @@
import torch
import comfy.utils


class PatchModelAddDownscale_v2:
    """A UNet model patch that implements dynamic latent downscaling with a gradual transition.

    This node is an enhanced version of the original PatchModelAddDownscale that adds smooth
    transition capabilities. It operates in three phases:

    1. Full Downscale (start_percent → end_percent):
       Latents are downscaled by the specified downscale_factor

    2. Gradual Transition (end_percent → gradual_percent):
       Latents smoothly transition from the downscaled size back to the original size

    3. Original Size (after gradual_percent):
       Latents remain at their original size

    The gradual transition helps prevent abrupt changes in the generation process,
    potentially leading to more consistent results.

    Parameters:
        model: The model to patch
        block_number: Which UNet block to apply the patch to
        downscale_factor: How much to shrink the latents by
        start_percent: When to start downscaling (in terms of sampling progress)
        end_percent: When to begin transitioning back to original size
        gradual_percent: When to complete the transition to original size
        downscale_after_skip: Whether to apply downscaling after skip connections
        downscale_method: Algorithm to use for downscaling
        upscale_method: Algorithm to use for upscaling

    Code by:
    - https://github.com/Jordach
    """

    upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model": ("MODEL",),
            "block_number": ("INT", {"default": 3, "min": 1, "max": 32, "step": 1}),
            "downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}),
            "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
            "end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}),
            "gradual_percent": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 1.0, "step": 0.001}),
            "downscale_after_skip": ("BOOLEAN", {"default": True}),
            "downscale_method": (s.upscale_methods,),
            "upscale_method": (s.upscale_methods,),
        }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "model_patches/unet"

    def calculate_upscale_factor(self, sigma, sigma_end, sigma_rescale, downscale_factor):
        """Calculate the scale factor during the gradual resize phase.

        Sigmas decrease as sampling progresses, so sigma_end (where the fully
        downscaled phase ends) is larger than sigma_rescale (where the latent
        is back at its original size).
        """
        if sigma >= sigma_end:
            return 1.0 / downscale_factor  # Still fully downscaled
        elif sigma <= sigma_rescale:
            return 1.0  # Fully back to original size
        else:
            # Linear interpolation between the downscaled and the original size
            progress = (sigma_end - sigma) / (sigma_end - sigma_rescale)
            scale_diff = 1.0 - (1.0 / downscale_factor)
            return (1.0 / downscale_factor) + (scale_diff * progress)

    def patch(self, model, block_number, downscale_factor, start_percent, end_percent,
              gradual_percent, downscale_after_skip, downscale_method, upscale_method):
        model_sampling = model.get_model_object("model_sampling")
        sigma_start = model_sampling.percent_to_sigma(start_percent)
        sigma_end = model_sampling.percent_to_sigma(end_percent)
        sigma_rescale = model_sampling.percent_to_sigma(gradual_percent)

        def input_block_patch(h, transformer_options):
            if downscale_factor == 1:
                return h

            if transformer_options["block"][1] == block_number:
                sigma = transformer_options["sigmas"][0].item()

                # Normal downscale behavior between start_percent and end_percent
                if sigma <= sigma_start and sigma >= sigma_end:
                    h = comfy.utils.common_upscale(
                        h,
                        round(h.shape[-1] * (1.0 / downscale_factor)),
                        round(h.shape[-2] * (1.0 / downscale_factor)),
                        downscale_method,
                        "disabled"
                    )
                # Gradually upscale the latent after end_percent until gradual_percent
                elif sigma < sigma_end and sigma >= sigma_rescale:
                    scale_factor = self.calculate_upscale_factor(
                        sigma, sigma_end, sigma_rescale, downscale_factor
                    )
                    h = comfy.utils.common_upscale(
                        h,
                        round(h.shape[-1] * scale_factor),
                        round(h.shape[-2] * scale_factor),
                        upscale_method,
                        "disabled"
                    )
            return h

        def output_block_patch(h, hsp, transformer_options):
            if h.shape[2] != hsp.shape[2]:
                h = comfy.utils.common_upscale(
                    h, hsp.shape[-1], hsp.shape[-2],
                    upscale_method, "disabled"
                )
            return h, hsp

        m = model.clone()
        if downscale_after_skip:
            m.set_model_input_block_patch_after_skip(input_block_patch)
        else:
            m.set_model_input_block_patch(input_block_patch)
        m.set_model_output_block_patch(output_block_patch)
        return (m, )


NODE_CLASS_MAPPINGS = {
    "PatchModelAddDownscale_v2": PatchModelAddDownscale_v2,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    # Sampling
    "PatchModelAddDownscale_v2": "PatchModelAddDownscale v2",
}
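
To make the three phases concrete, here is a small standalone sketch of the same rescale curve; the sigma boundaries and the 2x downscale factor below are made-up illustration values, not taken from any real model_sampling object:

# Standalone illustration of the gradual-rescale curve used by the patch.
# sigma_end / sigma_rescale stand in for percent_to_sigma(end_percent) and
# percent_to_sigma(gradual_percent); the numbers are arbitrary.
def scale_at(sigma, sigma_end, sigma_rescale, downscale_factor):
    if sigma >= sigma_end:          # phase 1: fully downscaled
        return 1.0 / downscale_factor
    if sigma <= sigma_rescale:      # phase 3: back to original size
        return 1.0
    progress = (sigma_end - sigma) / (sigma_end - sigma_rescale)
    return 1.0 / downscale_factor + (1.0 - 1.0 / downscale_factor) * progress

sigma_end, sigma_rescale = 6.0, 2.0
for sigma in (8.0, 6.0, 4.0, 2.0, 1.0):
    print(f"sigma={sigma:4.1f} -> scale={scale_at(sigma, sigma_end, sigma_rescale, 2.0):.2f}")
# prints 0.50, 0.50, 0.75, 1.00, 1.00 for a 2x downscale factor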
comfy_nodes/easy_aspects.py
ADDED
@@ -0,0 +1,103 @@
import math


class AutoImageSize:
    """A utility node that automatically calculates optimal image dimensions and parameters.

    This node helps create properly scaled images while maintaining desired aspect ratios
    and managing performance through compression factors. It also automatically adjusts
    denoise strength based on the output resolution.

    Features:
    - Maintains exact aspect ratios while ensuring dimensions are divisible by the compression factor
    - Automatically calculates appropriate denoise strength based on resolution scaling
    - Supports both portrait and landscape orientations
    - Prevents downscaling below base resolution

    Parameters:
        aspect_ratio: The desired width/height ratio (1.0 = square, >1 = wider/taller)
        orientation: Whether the image should be portrait or landscape
        target_resolution: The desired maximum dimension in pixels
        base_resolution: The model's native resolution (usually 1024)
        compression_factor: Ensures dimensions are divisible by this value (usually 8 for VAEs)

    Returns:
        WIDTH: The calculated image width
        HEIGHT: The calculated image height
        DOWNSCALE_FACTOR: The scaling factor relative to base_resolution
        DENOISE_STRENGTH: Automatically adjusted denoise strength (0.1-0.65)

    The denoise strength calculation uses an exponential decay curve fitted to known good values:
    - 1.0x (1024px) → 0.75
    - 1.5x (1536px) → 0.45
    - 2.0x (2048px) → 0.2

    Code by:
    - https://github.com/Jordach
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "aspect_ratio": ("FLOAT", {"default": 1, "min": 1, "max": 8, "step": 0.01}),
                "orientation": (["portrait", "landscape"],),
                "target_resolution": ("INT", {"default": 1024, "min": 256, "max": 1024*8, "step": 1}),
                "base_resolution": ("INT", {"default": 1024, "min": 256, "max": 1024*8, "step": 1}),
                "compression_factor": ("INT", {"default": 8, "min": 1, "max": 64, "step": 1}),
            }
        }

    RETURN_TYPES = ("INT", "INT", "FLOAT", "FLOAT")
    RETURN_NAMES = ("WIDTH", "HEIGHT", "DOWNSCALE_FACTOR", "DENOISE_STRENGTH")
    FUNCTION = "create_res"

    CATEGORY = "utils"

    def calculate_denoise_strength(self, scale_factor):
        """
        Calculate appropriate denoise strength based on the resolution scale factor.
        Uses an exponential decay curve fitted to known good values:
        - 1.0x (1024px) → 0.75
        - 1.5x (1536px) → 0.45
        - 2.0x (2048px) → 0.2
        """
        # Base denoise value for 1024px (scale_factor = 1.0)
        base_denoise = 0.95

        # Calculate denoise strength using exponential decay
        # Formula: denoise = base_denoise * e^(-k * (scale_factor - 1))
        # where k is the decay constant fitted to match the reference points
        k = 1.55

        denoise = base_denoise * math.exp(-k * (scale_factor - 1))
        d_min = 0.1
        d_max = 0.65
        # Clamp the result between 0.1 and 0.65
        return max(d_min, min(d_max, denoise))

    def create_res(self, aspect_ratio, orientation, target_resolution, base_resolution, compression_factor):
        # Prevent cases where DOWNSCALE_FACTOR can be < 1
        if target_resolution < base_resolution:
            target_resolution = base_resolution

        w, h = target_resolution, target_resolution
        if orientation == "portrait":
            w = int((((target_resolution**2)/aspect_ratio)**0.5)//compression_factor)*compression_factor
            h = int((((target_resolution**2)*aspect_ratio)**0.5)//compression_factor)*compression_factor
        elif orientation == "landscape":
            w = int((((target_resolution**2)*aspect_ratio)**0.5)//compression_factor)*compression_factor
            h = int((((target_resolution**2)/aspect_ratio)**0.5)//compression_factor)*compression_factor

        scale_factor = target_resolution/base_resolution
        denoise_strength = self.calculate_denoise_strength(scale_factor)

        return (w, h, scale_factor, denoise_strength)


NODE_CLASS_MAPPINGS = {
    "JDC_AutoImageSize": AutoImageSize
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "JDC_AutoImageSize": "Easy Aspect Ratios"
}
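
The node has no hard ComfyUI dependency, so the dimension and denoise math can be sanity-checked outside of a workflow. A quick check; the import path and the 3:2 example values are assumptions for illustration:

# Quick check of the dimension / denoise math; the import path is an assumption.
from easy_aspects import AutoImageSize

node = AutoImageSize()
w, h, factor, denoise = node.create_res(
    aspect_ratio=1.5,            # 3:2
    orientation="landscape",
    target_resolution=1536,
    base_resolution=1024,
    compression_factor=8,
)
print(w, h, factor, round(denoise, 2))  # 1880 1248 1.5 0.44 -- total pixels stay near 1536^2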