import torch

from ..log import log


class MTB_StackImages:
    """Stack the input images horizontally or vertically."""

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"vertical": ("BOOLEAN", {"default": False})}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "stack"
    CATEGORY = "mtb/image utils"

    def stack(self, vertical, **kwargs):
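        # Images are received as dynamic keyword inputs, one tensor
        # per connected input, so they are collected from **kwargs.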
        if not kwargs:
            raise ValueError("At least one tensor must be provided.")

        tensors = list(kwargs.values())
        log.debug(
            f"Stacking {len(tensors)} tensors "
            f"{'vertically' if vertical else 'horizontally'}"
        )
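
        # Bring every input to RGBA and repeat frames so that all
        # batches share the same length before concatenation.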
        normalized_tensors = [
            self.normalize_to_rgba(tensor) for tensor in tensors
        ]
        max_batch_size = max(tensor.shape[0] for tensor in normalized_tensors)
        normalized_tensors = [
            self.duplicate_frames(tensor, max_batch_size)
            for tensor in normalized_tensors
        ]
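
        # Tensors are laid out as [batch, height, width, channels]:
        # dim=1 concatenates along height, dim=2 along width.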
        if vertical:
            width = normalized_tensors[0].shape[2]
            if any(tensor.shape[2] != width for tensor in normalized_tensors):
                raise ValueError(
                    "All tensors must have the same width "
                    "for vertical stacking."
                )
            dim = 1
        else:
            height = normalized_tensors[0].shape[1]
            if any(tensor.shape[1] != height for tensor in normalized_tensors):
                raise ValueError(
                    "All tensors must have the same height "
                    "for horizontal stacking."
                )
            dim = 2

        stacked_tensor = torch.cat(normalized_tensors, dim=dim)

        return (stacked_tensor,)

    def normalize_to_rgba(self, tensor):
        """Normalize tensor to have 4 channels (RGBA)."""
        _, _, _, channels = tensor.shape
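        # The shape is assumed to be [batch, height, width, channels].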

        if channels == 4:
            return tensor
        elif channels == 3:
            # Append a fully opaque alpha channel, matching the input's
            # dtype and device so the concatenation cannot mismatch.
            alpha_channel = torch.ones(
                tensor.shape[:-1] + (1,),
                dtype=tensor.dtype,
                device=tensor.device,
            )
            return torch.cat((tensor, alpha_channel), dim=-1)
        else:
            raise ValueError(
                f"Tensor has an unsupported number of channels ({channels}): "
                "expected 3 (RGB) or 4 (RGBA)."
            )

    def duplicate_frames(self, tensor, target_batch_size):
        """Duplicate frames in tensor to match the target batch size."""
        current_batch_size = tensor.shape[0]
        if current_batch_size < target_batch_size:
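            # Tile the clip as many times as it fits, then append the
            # leading frames again to cover any remainder.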
            duplication_factor = target_batch_size // current_batch_size
            duplicated_tensor = tensor.repeat(duplication_factor, 1, 1, 1)
            remaining_frames = target_batch_size % current_batch_size
            if remaining_frames > 0:
                duplicated_tensor = torch.cat(
                    (duplicated_tensor, tensor[:remaining_frames]), dim=0
                )
            return duplicated_tensor
        else:
            return tensor


class MTB_PickFromBatch:
    """Pick a specific number of images from a batch,
    either from the start or the end.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "from_direction": (["end", "start"], {"default": "start"}),
                "count": ("INT", {"default": 1, "min": 1}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "pick_from_batch"
    CATEGORY = "mtb/image utils"

    def pick_from_batch(self, image, from_direction, count):
        batch_size = image.size(0)
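
        # Warn when more images are requested than the batch holds,
        # then clamp the count so slicing stays within bounds.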
        if count > batch_size:
            log.warning(
                f"Requested {count} images, "
                f"but only {batch_size} are available."
            )
            count = batch_size

        if from_direction == "end":
            selected_tensors = image[-count:]
        else:
            selected_tensors = image[:count]

        return (selected_tensors,)


__nodes__ = [MTB_StackImages, MTB_PickFromBatch]