k4d3 committed
Commit 1a98ccf · 1 Parent(s): fbeebb5
caption/joy_single.py ADDED
@@ -0,0 +1,236 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+
+ """
+ Simplified JoyCaption - Generates captions for a single image input
+ """
+
+ import os
+ import argparse
+ from pathlib import Path
+ from PIL import Image
+ import torch
+ import torchvision.transforms.functional as TVF
+ from transformers import (
+     AutoModel,
+     AutoTokenizer,
+     AutoModelForCausalLM,
+ )
+ from torch import nn
+ import logging
+
+ CLIP_PATH = "google/siglip-so400m-patch14-384"
+ MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B"
+ CHECKPOINT_PATH = Path(__file__).resolve().parent / "cgrkzexw-599808"
+
+ class ImageAdapter(nn.Module):
+     def __init__(
+         self,
+         input_features: int,
+         output_features: int,
+         ln1: bool,
+         pos_emb: bool,
+         num_image_tokens: int,
+         deep_extract: bool,
+     ):
+         super().__init__()
+         self.deep_extract = deep_extract
+         if self.deep_extract:
+             input_features = input_features * 5
+
+         self.linear1 = nn.Linear(input_features, output_features)
+         self.activation = nn.GELU()
+         self.linear2 = nn.Linear(output_features, output_features)
+         self.ln1 = nn.Identity() if not ln1 else nn.LayerNorm(input_features)
+         self.pos_emb = (
+             None
+             if not pos_emb
+             else nn.Parameter(torch.zeros(num_image_tokens, input_features))
+         )
+         self.other_tokens = nn.Embedding(3, output_features)
+         self.other_tokens.weight.data.normal_(mean=0.0, std=0.02)
+
+     def forward(self, vision_outputs):
+         if self.deep_extract:
+             x = torch.concat(
+                 (
+                     vision_outputs[-2],
+                     vision_outputs[3],
+                     vision_outputs[7],
+                     vision_outputs[13],
+                     vision_outputs[20],
+                 ),
+                 dim=-1,
+             )
+         else:
+             x = vision_outputs[-2]
+
+         x = self.ln1(x)
+         if self.pos_emb is not None:
+             x = x + self.pos_emb
+
+         x = self.linear1(x)
+         x = self.activation(x)
+         x = self.linear2(x)
+
+         other_tokens = self.other_tokens(
+             torch.tensor([0, 1], device=self.other_tokens.weight.device).expand(
+                 x.shape[0], -1
+             )
+         )
+         x = torch.cat((other_tokens[:, 0:1], x, other_tokens[:, 1:2]), dim=1)
+         return x
+
+ class SimpleCaptioner:
+     def __init__(self):
+         self.clip_model = None
+         self.text_model = None
+         self.image_adapter = None
+         self.tokenizer = None
+
+     def load_models(self):
+         logging.info("Loading CLIP")
+         self.clip_model = AutoModel.from_pretrained(CLIP_PATH)
+         self.clip_model = self.clip_model.vision_model
+
+         if (CHECKPOINT_PATH / "clip_model.pt").exists():
+             checkpoint = torch.load(
+                 CHECKPOINT_PATH / "clip_model.pt", map_location="cpu"
+             )
+             checkpoint = {
+                 k.replace("_orig_mod.module.", ""): v for k, v in checkpoint.items()
+             }
+             self.clip_model.load_state_dict(checkpoint)
+
+         self.clip_model.eval()
+         self.clip_model.requires_grad_(False)
+         self.clip_model.to("cuda")
+
+         logging.info("Loading tokenizer and LLM")
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             CHECKPOINT_PATH / "text_model", use_fast=True
+         )
+
+         if (CHECKPOINT_PATH / "text_model").exists():
+             self.text_model = AutoModelForCausalLM.from_pretrained(
+                 CHECKPOINT_PATH / "text_model", device_map=0, torch_dtype=torch.bfloat16
+             )
+         else:
+             self.text_model = AutoModelForCausalLM.from_pretrained(
+                 MODEL_PATH, device_map="auto", torch_dtype=torch.bfloat16
+             )
+         self.text_model.eval()
+
+         logging.info("Loading image adapter")
+         self.image_adapter = ImageAdapter(
+             self.clip_model.config.hidden_size,
+             self.text_model.config.hidden_size,
+             False,
+             False,
+             38,
+             False,
+         )
+         self.image_adapter.load_state_dict(
+             torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu")
+         )
+         self.image_adapter.eval()
+         self.image_adapter.to("cuda")
+
+     @torch.no_grad()
+     def generate_caption(self, image_path: str) -> str:
+         # Load and preprocess image
+         input_image = Image.open(image_path).convert("RGB")
+         image = input_image.resize((384, 384), Image.LANCZOS)
+         pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
+         pixel_values = TVF.normalize(pixel_values, [0.5], [0.5]).to("cuda")
+
+         # Generate image embeddings
+         vision_outputs = self.clip_model(pixel_values=pixel_values, output_hidden_states=True)
+         embedded_images = self.image_adapter(vision_outputs.hidden_states)
+
+         # Prepare prompt
+         prompt = "Write a descriptive caption for this image in a formal tone."
+         convo = [
+             {"role": "system", "content": "You are a helpful image captioner."},
+             {"role": "user", "content": prompt},
+         ]
+         convo_string = self.tokenizer.apply_chat_template(
+             convo, tokenize=False, add_generation_prompt=True
+         )
+
+         # Tokenize and prepare inputs
+         convo_tokens = self.tokenizer.encode(
+             convo_string, return_tensors="pt", add_special_tokens=False
+         )
+         prompt_tokens = self.tokenizer.encode(
+             prompt, return_tensors="pt", add_special_tokens=False
+         )
+
+         eot_id_indices = (
+             (convo_tokens == self.tokenizer.convert_tokens_to_ids("<|eot_id|>"))
+             .nonzero(as_tuple=True)[1]  # convo_tokens has shape (1, seq_len); take column indices
+             .tolist()
+         )
+         preamble_len = eot_id_indices[1] - prompt_tokens.shape[1]
+
+         convo_embeds = self.text_model.model.embed_tokens(convo_tokens.to("cuda"))
+
+         input_embeds = torch.cat(
+             [
+                 convo_embeds[:, :preamble_len],
+                 embedded_images.to(dtype=convo_embeds.dtype),
+                 convo_embeds[:, preamble_len:],
+             ],
+             dim=1,
+         )
+
+         input_ids = torch.cat(
+             [
+                 convo_tokens[:, :preamble_len],
+                 # Zero placeholder ids for the inserted image tokens; assembled on CPU
+                 # alongside convo_tokens, then the whole sequence is moved to the GPU
+                 torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
+                 convo_tokens[:, preamble_len:],
+             ],
+             dim=1,
+         ).to("cuda")
+
+         attention_mask = torch.ones_like(input_ids)
+
+         # Generate caption
+         generate_ids = self.text_model.generate(
+             input_ids,
+             inputs_embeds=input_embeds,
+             attention_mask=attention_mask,
+             max_new_tokens=300,
+             do_sample=True,
+             repetition_penalty=1.2,
+         )
+
+         # Decode caption
+         generate_ids = generate_ids[:, input_ids.shape[1]:]
+         if generate_ids[0][-1] == self.tokenizer.eos_token_id or generate_ids[0][-1] == self.tokenizer.convert_tokens_to_ids("<|eot_id|>"):
+             generate_ids = generate_ids[:, :-1]
+
+         caption = self.tokenizer.batch_decode(
+             generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
+         )[0]
+
+         return caption.strip()
+
+ def main():
+     parser = argparse.ArgumentParser(description="Generate a caption for a single image")
+     parser.add_argument("image_path", type=str, help="Path to the input image")
+     args = parser.parse_args()
+
+     # Setup logging
+     logging.basicConfig(level=logging.INFO, format='%(asctime)s | %(levelname)s | %(message)s')
+
+     # Initialize and load the captioner
+     captioner = SimpleCaptioner()
+     captioner.load_models()
+
+     # Generate and print caption
+     caption = captioner.generate_caption(args.image_path)
+     print(f"\nGenerated caption:\n{caption}")
+
+ if __name__ == "__main__":
+     main()
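
A minimal usage sketch for the script above, assuming the cgrkzexw-599808 checkpoint directory sits next to joy_single.py, a CUDA GPU is available, and the module is importable from the working directory; "example.jpg" is a placeholder path. The command-line equivalent is: python caption/joy_single.py example.jpg

# Usage sketch; the paths and environment here are assumptions, not part of the commit.
from joy_single import SimpleCaptioner

captioner = SimpleCaptioner()
captioner.load_models()                           # loads SigLIP, the LLM, and the image adapter onto the GPU
print(captioner.generate_caption("example.jpg"))  # placeholder image path
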
comfy_nodes/deep_shrink_mk2.py ADDED
@@ -0,0 +1,129 @@
+ import torch
+ import comfy.utils
+
+ class PatchModelAddDownscale_v2:
+     """A UNet model patch that implements dynamic latent downscaling with gradual transition.
+
+     This node is an enhanced version of the original PatchModelAddDownscale that adds smooth
+     transition capabilities. It operates in three phases:
+
+     1. Full Downscale (start_percent → end_percent):
+        Latents are downscaled by the specified downscale_factor
+
+     2. Gradual Transition (end_percent → gradual_percent):
+        Latents smoothly transition from downscaled size back to original size
+
+     3. Original Size (after gradual_percent):
+        Latents remain at their original size
+
+     The gradual transition helps prevent abrupt changes in the generation process,
+     potentially leading to more consistent results.
+
+     Parameters:
+         model: The model to patch
+         block_number: Which UNet block to apply the patch to
+         downscale_factor: How much to shrink the latents by
+         start_percent: When to start downscaling (in terms of sampling progress)
+         end_percent: When to begin transitioning back to original size
+         gradual_percent: When to complete the transition to original size
+         downscale_after_skip: Whether to apply downscaling after skip connections
+         downscale_method: Algorithm to use for downscaling
+         upscale_method: Algorithm to use for upscaling
+
+     Code by:
+     - https://github.com/Jordach
+     """
+
+     upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
+
+     @classmethod
+     def INPUT_TYPES(s):
+         return {"required": {
+             "model": ("MODEL",),
+             "block_number": ("INT", {"default": 3, "min": 1, "max": 32, "step": 1}),
+             "downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}),
+             "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
+             "end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}),
+             "gradual_percent": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 1.0, "step": 0.001}),
+             "downscale_after_skip": ("BOOLEAN", {"default": True}),
+             "downscale_method": (s.upscale_methods,),
+             "upscale_method": (s.upscale_methods,),
+         }}
+
+     RETURN_TYPES = ("MODEL",)
+     FUNCTION = "patch"
+     CATEGORY = "model_patches/unet"
+
+     def calculate_upscale_factor(self, sigma, sigma_end, sigma_rescale, downscale_factor):
+         """Calculate the upscale factor during the gradual resize phase.
+
+         Sigmas shrink as sampling progresses, so the transition runs from sigma_end
+         (still fully downscaled) down to sigma_rescale (back to original size).
+         """
+         if sigma >= sigma_end:
+             return 1.0 / downscale_factor  # Still fully downscaled
+         elif sigma <= sigma_rescale:
+             return 1.0  # Fully back to original size
+         else:
+             # Linear interpolation between downscaled and original size
+             progress = (sigma_end - sigma) / (sigma_end - sigma_rescale)
+             scale_diff = 1.0 - (1.0 / downscale_factor)
+             return (1.0 / downscale_factor) + (scale_diff * progress)
+
+     def patch(self, model, block_number, downscale_factor, start_percent, end_percent,
+               gradual_percent, downscale_after_skip, downscale_method, upscale_method):
+         model_sampling = model.get_model_object("model_sampling")
+         sigma_start = model_sampling.percent_to_sigma(start_percent)
+         sigma_end = model_sampling.percent_to_sigma(end_percent)
+         sigma_rescale = model_sampling.percent_to_sigma(gradual_percent)
+
+         def input_block_patch(h, transformer_options):
+             if downscale_factor == 1:
+                 return h
+
+             if transformer_options["block"][1] == block_number:
+                 sigma = transformer_options["sigmas"][0].item()
+
+                 # Normal downscale behavior between start_percent and end_percent
+                 if sigma <= sigma_start and sigma >= sigma_end:
+                     h = comfy.utils.common_upscale(
+                         h,
+                         round(h.shape[-1] * (1.0 / downscale_factor)),
+                         round(h.shape[-2] * (1.0 / downscale_factor)),
+                         downscale_method,
+                         "disabled"
+                     )
+                 # Gradually upscale latent after end_percent until gradual_percent
+                 elif sigma < sigma_end and sigma >= sigma_rescale:
+                     scale_factor = self.calculate_upscale_factor(
+                         sigma, sigma_end, sigma_rescale, downscale_factor
+                     )
+                     h = comfy.utils.common_upscale(
+                         h,
+                         round(h.shape[-1] * scale_factor),
+                         round(h.shape[-2] * scale_factor),
+                         upscale_method,
+                         "disabled"
+                     )
+             return h
+
+         def output_block_patch(h, hsp, transformer_options):
+             if h.shape[2] != hsp.shape[2]:
+                 h = comfy.utils.common_upscale(
+                     h, hsp.shape[-1], hsp.shape[-2],
+                     upscale_method, "disabled"
+                 )
+             return h, hsp
+
+         m = model.clone()
+         if downscale_after_skip:
+             m.set_model_input_block_patch_after_skip(input_block_patch)
+         else:
+             m.set_model_input_block_patch(input_block_patch)
+         m.set_model_output_block_patch(output_block_patch)
+         return (m, )
+
+ NODE_CLASS_MAPPINGS = {
+     "PatchModelAddDownscale_v2": PatchModelAddDownscale_v2,
+ }
+
+ NODE_DISPLAY_NAME_MAPPINGS = {
+     # Sampling
+     "PatchModelAddDownscale_v2": "PatchModelAddDownscale v2",
+ }
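
The gradual-transition arithmetic can be checked in isolation. The sketch below repeats the same linear interpolation outside ComfyUI; the sigma values are made up for the example, and sigmas shrink as sampling progresses.

# Standalone sketch of the gradual-transition interpolation; sigma values are hypothetical.
def upscale_factor(sigma, sigma_end, sigma_rescale, downscale_factor):
    if sigma >= sigma_end:
        return 1.0 / downscale_factor  # still fully downscaled
    if sigma <= sigma_rescale:
        return 1.0  # back to original size
    progress = (sigma_end - sigma) / (sigma_end - sigma_rescale)
    return (1.0 / downscale_factor) + (1.0 - 1.0 / downscale_factor) * progress

# With downscale_factor=2 and a transition window running from sigma 5.0 down to 2.0:
for s in (5.0, 3.5, 2.0):
    print(s, upscale_factor(s, 5.0, 2.0, 2.0))  # 0.5, 0.75, 1.0
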
comfy_nodes/easy_aspects.py ADDED
@@ -0,0 +1,103 @@
+ import math
+
+ class AutoImageSize:
+     """A utility node that automatically calculates optimal image dimensions and parameters.
+
+     This node helps create properly scaled images while maintaining desired aspect ratios
+     and managing performance through compression factors. It also automatically adjusts
+     denoise strength based on the output resolution.
+
+     Features:
+     - Maintains exact aspect ratios while ensuring dimensions are divisible by the compression factor
+     - Automatically calculates appropriate denoise strength based on resolution scaling
+     - Supports both portrait and landscape orientations
+     - Prevents downscaling below base resolution
+
+     Parameters:
+         aspect_ratio: The desired width/height ratio (1.0 = square, >1 = wider/taller)
+         orientation: Whether the image should be portrait or landscape
+         target_resolution: The desired maximum dimension in pixels
+         base_resolution: The model's native resolution (usually 1024)
+         compression_factor: Ensures dimensions are divisible by this value (usually 8 for VAEs)
+
+     Returns:
+         WIDTH: The calculated image width
+         HEIGHT: The calculated image height
+         DOWNSCALE_FACTOR: The scaling factor relative to base_resolution
+         DENOISE_STRENGTH: Automatically adjusted denoise strength (0.1-0.65)
+
+     The denoise strength calculation uses an exponential decay curve fitted to known good values:
+     - 1.0x (1024px) → 0.75
+     - 1.5x (1536px) → 0.45
+     - 2.0x (2048px) → 0.2
+
+     Code by:
+     - https://github.com/Jordach
+     """
+
+     @classmethod
+     def INPUT_TYPES(s):
+         return {
+             "required": {
+                 "aspect_ratio": ("FLOAT", {"default": 1, "min": 1, "max": 8, "step": 0.01}),
+                 "orientation": (["portrait", "landscape"],),
+                 "target_resolution": ("INT", {"default": 1024, "min": 256, "max": 1024*8, "step": 1}),
+                 "base_resolution": ("INT", {"default": 1024, "min": 256, "max": 1024*8, "step": 1}),
+                 "compression_factor": ("INT", {"default": 8, "min": 1, "max": 64, "step": 1}),
+             }
+         }
+
+     RETURN_TYPES = ("INT", "INT", "FLOAT", "FLOAT")
+     RETURN_NAMES = ("WIDTH", "HEIGHT", "DOWNSCALE_FACTOR", "DENOISE_STRENGTH")
+     FUNCTION = "create_res"
+
+     CATEGORY = "utils"
+
+     def calculate_denoise_strength(self, scale_factor):
+         """
+         Calculate appropriate denoise strength based on resolution scale factor.
+         Uses exponential decay curve fitted to known good values:
+         - 1.0x (1024px) → 0.75
+         - 1.5x (1536px) → 0.45
+         - 2.0x (2048px) → 0.2
+         """
+         # Base denoise value for 1024px (scale_factor = 1.0)
+         base_denoise = 0.95
+
+         # Calculate denoise strength using exponential decay
+         # Formula: denoise = base_denoise * e^(-k * (scale_factor - 1))
+         # where k is calculated to fit our known points
+         # Decay constant fitted to match reference points
+         k = 1.55
+
+         denoise = base_denoise * math.exp(-k * (scale_factor - 1))
+         d_min = 0.1
+         d_max = 0.65
+         # Clamp the result between 0.1 and 0.65
+         return max(d_min, min(d_max, denoise))
+
+     def create_res(self, aspect_ratio, orientation, target_resolution, base_resolution, compression_factor):
+         # Prevent cases where DOWNSCALE_FACTOR can be < 1
+         if target_resolution < base_resolution:
+             target_resolution = base_resolution
+
+         w, h = target_resolution, target_resolution
+         if orientation == "portrait":
+             w = int((((target_resolution**2)/aspect_ratio)**0.5)//compression_factor)*compression_factor
+             h = int((((target_resolution**2)*aspect_ratio)**0.5)//compression_factor)*compression_factor
+         elif orientation == "landscape":
+             w = int((((target_resolution**2)*aspect_ratio)**0.5)//compression_factor)*compression_factor
+             h = int((((target_resolution**2)/aspect_ratio)**0.5)//compression_factor)*compression_factor
+
+         scale_factor = target_resolution/base_resolution
+         denoise_strength = self.calculate_denoise_strength(scale_factor)
+
+         return (w, h, scale_factor, denoise_strength)
+
+ NODE_CLASS_MAPPINGS = {
+     "JDC_AutoImageSize": AutoImageSize
+ }
+
+ NODE_DISPLAY_NAME_MAPPINGS = {
+     "JDC_AutoImageSize": "Easy Aspect Ratios"
+ }
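
A quick worked example of the node's arithmetic, as a standalone sketch with arbitrary inputs: at target_resolution=2048, base_resolution=1024, compression_factor=8 and aspect_ratio=1.5 in landscape, the node returns 2504x1672, a scale factor of 2.0, and a denoise strength of roughly 0.20.

import math

# Standalone sketch mirroring AutoImageSize.create_res and calculate_denoise_strength;
# the input values are arbitrary examples.
target, base, comp, ar = 2048, 1024, 8, 1.5           # landscape, aspect ratio 1.5
w = int((((target**2) * ar) ** 0.5) // comp) * comp   # -> 2504
h = int((((target**2) / ar) ** 0.5) // comp) * comp   # -> 1672
scale = target / base                                 # -> 2.0
denoise = max(0.1, min(0.65, 0.95 * math.exp(-1.55 * (scale - 1))))  # -> ~0.20
print(w, h, scale, round(denoise, 2))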