|
import io |
|
import json |
|
import urllib.parse |
|
import urllib.request |
|
from math import pi |
|
|
|
import comfy.model_management as model_management |
|
import comfy.utils |
|
import numpy as np |
|
import torch |
|
from PIL import Image |
|
|
|
from ..log import log |
|
from ..utils import ( |
|
EASINGS, |
|
apply_easing, |
|
get_server_info, |
|
numpy_NFOV, |
|
pil2tensor, |
|
tensor2np, |
|
) |
|
|
|
|
|
def get_image(filename, subfolder, folder_type):
    """Download one image from the server's ``/view`` endpoint.

    Args:
        filename: Sent as the ``filename`` query parameter.
        subfolder: Sent as the ``subfolder`` query parameter (may be empty).
        folder_type: Sent as the ``type`` query parameter.

    Returns:
        An ``io.BytesIO`` wrapping the raw bytes of the HTTP response body.
    """
    location = f"in subfolder: {subfolder}" if subfolder else ""
    log.debug(
        f"Getting image {filename} from foldertype {folder_type} {location}"
    )

    host, port = get_server_info()
    query = urllib.parse.urlencode(
        {"filename": filename, "subfolder": subfolder, "type": folder_type}
    )
    url = f"http://{host}:{port}/view?{query}"
    log.debug(f"Fetching image from {url}")

    # Read the whole body while the connection is open, then hand back an
    # in-memory buffer that outlives the response object.
    with urllib.request.urlopen(url) as response:
        payload = response.read()
    return io.BytesIO(payload)
|
|
|
|
|
class MTB_ToDevice: |
|
"""Send a image or mask tensor to the given device.""" |
|
|
|
@classmethod |
|
def INPUT_TYPES(cls): |
|
devices = ["cpu"] |
|
if torch.backends.mps.is_available(): |
|
devices.append("mps") |
|
if torch.cuda.is_available(): |
|
devices.append("cuda") |
|
for i in range(torch.cuda.device_count()): |
|
devices.append(f"cuda{i}") |
|
|
|
return { |
|
"required": { |
|
"ignore_errors": ("BOOLEAN", {"default": False}), |
|
"device": (devices, {"default": "cpu"}), |
|
}, |
|
"optional": { |
|
"image": ("IMAGE",), |
|
"mask": ("MASK",), |
|
}, |
|
} |
|
|
|
RETURN_TYPES = ("IMAGE", "MASK") |
|
RETURN_NAMES = ("images", "masks") |
|
CATEGORY = "mtb/utils" |
|
FUNCTION = "to_device" |
|
|
|
def to_device( |
|
self, |
|
*, |
|
ignore_errors=False, |
|
device="cuda", |
|
image: torch.Tensor | None = None, |
|
mask: torch.Tensor | None = None, |
|
): |
|
if not ignore_errors and image is None and mask is None: |
|
raise ValueError( |
|
"You must either provide an image or a mask," |
|
" use ignore_error to passthrough" |
|
) |
|
if image is not None: |
|
image = image.to(device) |
|
if mask is not None: |
|
mask = mask.to(device) |
|
return (image, mask) |
|
|
|
|
|
|
|
class MTB_ApplyTextTemplate:
    """
    Experimental node to interpolate strings from inputs.

    Interpolation just requires {}, for instance:

    Some string {var_1} and {var_2}
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the single required multiline ``template`` input."""
        template_spec = ("STRING", {"default": "", "multiline": True})
        return {"required": {"template": template_spec}}

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("string",)
    CATEGORY = "mtb/utils"
    FUNCTION = "execute"

    def execute(self, *, template: str, **kwargs):
        """Replace each ``{name}`` in ``template`` with the keyword value ``name``.

        Substitutions are applied one after another in keyword order, so text
        inserted by an earlier value can itself be replaced by a later one.
        Placeholders without a matching keyword are left untouched.

        Returns:
            A one-element tuple holding the rendered string.
        """
        rendered = format(template)
        for name, value in kwargs.items():
            placeholder = "{" + name + "}"
            rendered = rendered.replace(placeholder, format(value))
        return (rendered,)
|
|
|
|
|
class MTB_MatchDimensions:
    """Match images dimensions along the given dimension, preserving aspect ratio."""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the source/reference images and which side to match."""
        required = {
            "source": ("IMAGE",),
            "reference": ("IMAGE",),
            "match": (["height", "width"], {"default": "height"}),
        }
        return {"required": required}

    RETURN_TYPES = ("IMAGE", "INT", "INT")
    RETURN_NAMES = ("image", "new_width", "new_height")
    CATEGORY = "mtb/utils"
    FUNCTION = "execute"

    def execute(
        self, source: torch.Tensor, reference: torch.Tensor, match: str
    ):
        """Resize ``source`` so one side equals the matching side of ``reference``.

        Both tensors are unpacked as 4-D ``(batch, height, width, channels)``.
        The other side is derived from the source aspect ratio and truncated
        to an integer.

        Returns:
            ``(resized_images, new_width, new_height)``.
        """
        import torchvision.transforms.functional as VF

        _, src_height, src_width, _ = source.shape
        _, ref_height, ref_width, _ = reference.shape
        aspect = src_width / src_height

        if match == "height":
            new_height, new_width = ref_height, int(ref_height * aspect)
        else:
            new_width, new_height = ref_width, int(ref_width / aspect)

        target_size = (new_height, new_width)

        # Move channels in front of the spatial axes for the resize, one image
        # at a time, then restore the channels-last layout.
        channels_first = source.permute(0, 3, 1, 2)
        resized = torch.stack(
            [
                VF.resize(
                    frame,
                    target_size,
                    antialias=True,
                    interpolation=Image.BICUBIC,
                )
                for frame in channels_first
            ],
            dim=0,
        )

        return (resized.permute(0, 2, 3, 1), new_width, new_height)
|
|
|
|
|
class MTB_FloatToFloats:
    """Conversion utility for compatibility with other extensions (AD, IPA, Fitz are using FLOAT to represent list of floats.)"""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare a single forced ``float`` input."""
        float_spec = ("FLOAT", {"default": 0.0, "forceInput": True})
        return {"required": {"float": float_spec}}

    RETURN_TYPES = ("FLOATS",)
    RETURN_NAMES = ("floats",)
    CATEGORY = "mtb/utils"
    FUNCTION = "convert"

    # The parameter name mirrors the declared input key, which is why it
    # shadows the builtin.
    def convert(self, float: float):
        """Return the incoming value unchanged, re-labelled as ``FLOATS``."""
        passthrough = float
        return (passthrough,)
|
|
|
|
|
class MTB_FloatsToInts:
    """Conversion utility for compatibility with frame interpolation."""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare a single forced ``floats`` input."""
        floats_spec = ("FLOATS", {"forceInput": True})
        return {"required": {"floats": floats_spec}}

    RETURN_TYPES = ("INTS", "INT")
    CATEGORY = "mtb/utils"
    FUNCTION = "convert"

    def convert(self, floats: list[float]):
        """Truncate every value with ``int()`` and expose the list on both outputs.

        Returns:
            A pair whose two items are the same list of integers.
        """
        as_ints = list(map(int, floats))
        return (as_ints, as_ints)
|
|
|
|
|
class MTB_FloatsToFloat:
    """Conversion utility for compatibility with other extensions (AD, IPA, Fitz are using FLOAT to represent list of floats.)"""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare a single ``floats`` input."""
        return {"required": {"floats": ("FLOATS",)}}

    RETURN_TYPES = ("FLOAT",)
    RETURN_NAMES = ("float",)
    CATEGORY = "mtb/utils"
    FUNCTION = "convert"

    def convert(self, floats):
        """Return the incoming value unchanged, re-labelled as ``FLOAT``."""
        passthrough = floats
        return (passthrough,)
|
|
|
|
|
class MTB_AutoPanEquilateral: |
|
"""Generate a 360 panning video from an equilateral image.""" |
|
|
|
@classmethod |
|
def INPUT_TYPES(cls): |
|
return { |
|
"required": { |
|
"equilateral_image": ("IMAGE",), |
|
"fovX": ("FLOAT", {"default": 45.0}), |
|
"fovY": ("FLOAT", {"default": 45.0}), |
|
"elevation": ("FLOAT", {"default": 0.5}), |
|
"frame_count": ("INT", {"default": 100}), |
|
"width": ("INT", {"default": 768}), |
|
"height": ("INT", {"default": 512}), |
|
}, |
|
"optional": { |
|
"floats_fovX": ("FLOATS",), |
|
"floats_fovY": ("FLOATS",), |
|
"floats_elevation": ("FLOATS",), |
|
}, |
|
} |
|
|
|
RETURN_TYPES = ("IMAGE",) |
|
RETURN_NAMES = ("image",) |
|
CATEGORY = "mtb/utils" |
|
FUNCTION = "generate_frames" |
|
|
|
def check_floats(self, f: list[float] | None, expected_count: int): |
|
if f: |
|
if len(f) == expected_count: |
|
return True |
|
return False |
|
return True |
|
|
|
def generate_frames( |
|
self, |
|
equilateral_image: torch.Tensor, |
|
fovX: float, |
|
fovY: float, |
|
elevation: float, |
|
frame_count: int, |
|
width: int, |
|
height: int, |
|
floats_fovX: list[float] | None = None, |
|
floats_fovY: list[float] | None = None, |
|
floats_elevation: list[float] | None = None, |
|
): |
|
source = tensor2np(equilateral_image) |
|
|
|
if len(source) > 1: |
|
log.warn( |
|
"You provided more than one image in the equilateral_image input, only the first will be used." |
|
) |
|
if not all( |
|
[ |
|
self.check_floats(x, frame_count) |
|
for x in [floats_fovX, floats_fovY, floats_elevation] |
|
] |
|
): |
|
raise ValueError( |
|
"You provided less than the expected number of fovX, fovY, or elevation values." |
|
) |
|
|
|
source = source[0] |
|
frames = [] |
|
|
|
pbar = comfy.utils.ProgressBar(frame_count) |
|
for i in range(frame_count): |
|
rotation_angle = (i / frame_count) * 2 * pi |
|
|
|
if floats_elevation: |
|
elevation = floats_elevation[i] |
|
|
|
if floats_fovX: |
|
fovX = floats_fovX[i] |
|
|
|
if floats_fovY: |
|
fovY = floats_fovY[i] |
|
|
|
fov = [fovX / 100, fovY / 100] |
|
center_point = [rotation_angle / (2 * pi), elevation] |
|
|
|
nfov = numpy_NFOV(fov, height, width) |
|
frame = nfov.to_nfov(source, center_point=center_point) |
|
|
|
frames.append(frame) |
|
|
|
model_management.throw_exception_if_processing_interrupted() |
|
pbar.update(1) |
|
|
|
return (pil2tensor(frames),) |
|
|
|
|
|
class MTB_GetBatchFromHistory:
    """Very experimental node to load images from the history of the server.

    Queue items without output are ignored in the count.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node inputs."""
        return {
            "required": {
                "enable": ("BOOLEAN", {"default": True}),
                "count": ("INT", {"default": 1, "min": 0}),
                "offset": ("INT", {"default": 0, "min": -1e9, "max": 1e9}),
                "internal_count": ("INT", {"default": 0}),
            },
            "optional": {
                "passthrough_image": ("IMAGE",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    CATEGORY = "mtb/animation"
    FUNCTION = "load_from_history"

    @staticmethod
    def _history_window(total, offset, count):
        """Return ``(start, end)`` slice bounds, both clamped to ``[0, total]``.

        The window covers the last ``count`` items after skipping the newest
        ``offset`` ones. An ``offset`` beyond ``total`` gives an empty window;
        without clamping, the negative end index would instead select items
        from the start of the list.
        """
        end = min(max(total - offset, 0), total)
        start = min(max(total - offset - count, 0), end)
        return start, end

    def load_from_history(
        self,
        *,
        enable=True,
        count=0,
        offset=0,
        internal_count=0,
        passthrough_image=None,
    ):
        """Fetch up to ``count`` images from the server history.

        Args:
            enable: When false, nothing is fetched.
            count: Number of images wanted; ``0`` also disables fetching.
            offset: How many of the most recent images to skip.
            internal_count: Not read by this method.
            passthrough_image: Returned as is when fetching is disabled.

        Returns:
            A one-element tuple with the image batch, the passthrough image,
            or an empty tensor when nothing is available.
        """
        if not enable or count == 0:
            if passthrough_image is not None:
                log.debug("Using passthrough image")
                return (passthrough_image,)
            log.debug("Load from history is disabled for this iteration")
            return (torch.zeros(0),)

        base_url, port = get_server_info()
        history_url = f"http://{base_url}:{port}/history"
        log.debug(f"Fetching history from {history_url}")

        with urllib.request.urlopen(history_url) as response:
            output = self.load_batch_frames(response, offset, count, [])

        if output.size(0) == 0:
            log.warning("No output found in history")

        return (output,)

    def load_batch_frames(self, response, offset, count, frames):
        """Build an image batch from a ``/history`` HTTP response.

        Args:
            response: Readable object whose body is the history JSON.
            offset: How many of the most recent images to skip.
            count: Maximum number of images to load.
            frames: Unused; kept so existing callers keep working.

        Returns:
            The selected images converted by ``pil2tensor``, or an empty
            tensor when the selection is empty.
        """
        history = json.loads(response.read())

        # Collect image descriptors only; runs without outputs add nothing.
        image_refs = [
            image
            for run in history.values()
            for node_output in run.get("outputs", {}).values()
            if "images" in node_output
            for image in node_output["images"]
        ]
        if not image_refs:
            return torch.zeros(0)

        start_index, end_index = self._history_window(
            len(image_refs), offset, count
        )

        # Download only the images that are part of the selection.
        frames = [
            Image.open(
                get_image(ref["filename"], ref["subfolder"], ref["type"])
            )
            for ref in image_refs[start_index:end_index]
        ]

        if not frames:
            return torch.zeros(0)
        if len(frames) != count:
            log.warning(f"Expected {count} images, got {len(frames)} instead")

        return pil2tensor(frames)
|
|
|
|
|
class MTB_AnyToString:
    """Tries to take any input and convert it to a string."""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare a single wildcard ``input``."""
        return {"required": {"input": ("*",)}}

    RETURN_TYPES = ("STRING",)
    FUNCTION = "do_str"
    CATEGORY = "mtb/converters"

    # The parameter name mirrors the declared input key, which is why it
    # shadows the builtin.
    def do_str(self, input):
        """Describe ``input`` as text.

        Strings pass through unchanged. Tensors, PIL images, numpy arrays and
        dictionaries get a short summary; anything else goes through ``str()``.

        Returns:
            A one-element tuple holding the resulting string.
        """
        if isinstance(input, str):
            return (input,)

        if isinstance(input, torch.Tensor):
            summary = f"Tensor of shape {input.shape} and dtype {input.dtype}"
            return (summary,)

        if isinstance(input, Image.Image):
            summary = f"PIL Image of size {input.size} and mode {input.mode}"
            return (summary,)

        if isinstance(input, np.ndarray):
            summary = (
                f"Numpy array of shape {input.shape} and dtype {input.dtype}"
            )
            return (summary,)

        if isinstance(input, dict):
            summary = (
                f"Dictionary of {len(input)} items, with keys {input.keys()}"
            )
            return (summary,)

        log.debug(f"Falling back to string conversion of {input}")
        return (str(input),)
|
|
|
|
|
class MTB_StringReplace:
    """Basic string replacement."""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the input string plus the text to find and its replacement."""
        required = {
            "string": ("STRING", {"forceInput": True}),
            "old": ("STRING", {"default": ""}),
            "new": ("STRING", {"default": ""}),
        }
        return {"required": required}

    FUNCTION = "replace_str"
    RETURN_TYPES = ("STRING",)
    CATEGORY = "mtb/string"

    def replace_str(self, string: str, old: str, new: str):
        """Replace every occurrence of ``old`` in ``string`` with ``new``.

        The three inputs and the result are written to the debug log.

        Returns:
            A one-element tuple holding the updated string.
        """
        for label, text in (
            ("Current", string),
            ("Find", old),
            ("Replace", new),
        ):
            log.debug(f"{label} string: {text}")

        replaced = string.replace(old, new)
        log.debug(f"New string: {replaced}")

        return (replaced,)
|
|
|
|
|
class MTB_MathExpression:
    """Node to evaluate a simple math expression string"""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the multiline ``expression`` input."""
        return {
            "required": {
                "expression": ("STRING", {"default": "", "multiline": True}),
            }
        }

    FUNCTION = "eval_expression"
    RETURN_TYPES = ("FLOAT", "INT")
    RETURN_NAMES = ("result (float)", "result (int)")
    CATEGORY = "mtb/math"
    DESCRIPTION = (
        "evaluate a simple math expression string, only supports literal_eval"
    )

    def eval_expression(self, expression: str, **kwargs):
        """Evaluate ``expression`` with ``ast.literal_eval``.

        Every ``<name>`` placeholder is first replaced by ``str()`` of the
        keyword argument ``name``.

        Args:
            expression: Text to evaluate after placeholder substitution.
            **kwargs: Placeholder values keyed by placeholder name.

        Returns:
            ``(result, int(result))``.

        Raises:
            ValueError: If the text is not valid syntax, is not a literal, or
                the evaluated value cannot be converted to an integer.
        """
        from ast import literal_eval

        for key, value in kwargs.items():
            log.debug(f"Replacing placeholder <{key}> with value {value}")
            expression = expression.replace(f"<{key}>", str(value))

        try:
            result = literal_eval(expression)
        except SyntaxError as e:
            raise ValueError(
                f"The expression syntax is wrong '{expression}': {e}"
            ) from e
        except Exception as e:
            raise ValueError(
                f"Math expression only support literal_eval now: {e}"
            ) from e

        # Literals such as lists or complex numbers evaluate fine but cannot
        # feed the INT output; report that with the same exception type the
        # other failure paths use.
        try:
            as_int = int(result)
        except (TypeError, ValueError, OverflowError) as e:
            raise ValueError(
                f"The expression '{expression}' did not evaluate to a"
                f" number: {result!r}"
            ) from e

        return (result, as_int)
|
|
|
|
|
class MTB_FitNumber:
    """Fit the input float using a source and target range"""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the value, the clamp switch, both ranges and the easing."""

        def bound(default):
            # All four range bounds share the same widget options.
            return ("FLOAT", {"default": default, "step": 0.01, "min": -1e5})

        return {
            "required": {
                "value": ("FLOAT", {"default": 0, "forceInput": True}),
                "clamp": ("BOOLEAN", {"default": False}),
                "source_min": bound(0.0),
                "source_max": bound(1.0),
                "target_min": bound(0.0),
                "target_max": bound(1.0),
                "easing": (EASINGS, {"default": "Linear"}),
            }
        }

    FUNCTION = "set_range"
    RETURN_TYPES = ("FLOAT",)
    CATEGORY = "mtb/math"
    DESCRIPTION = "Fit the input float using a source and target range"

    def set_range(
        self,
        value: float,
        clamp: bool,
        source_min: float,
        source_max: float,
        target_min: float,
        target_max: float,
        easing: str,
    ):
        """Remap ``value`` from the source range onto the target range.

        The value is normalised against the source range (``0`` when that
        range is empty), optionally clamped to ``[0, 1]``, passed through
        ``apply_easing`` and scaled onto the target range.

        Returns:
            A one-element tuple holding the remapped value.
        """
        if source_min != source_max:
            position = (value - source_min) / (source_max - source_min)
        else:
            position = 0

        if clamp:
            position = max(min(position, 1), 0)

        eased = apply_easing(position, easing)
        target_span = target_max - target_min

        return (target_min + target_span * eased,)
|
|
|
|
|
class MTB_ConcatImages:
    """Add images to batch."""

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "concatenate_tensors"
    CATEGORY = "mtb/image"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the static inputs; image inputs arrive as extra keywords."""
        return {
            "required": {"reverse": ("BOOLEAN", {"default": False})},
            "optional": {
                "on_mismatch": (
                    ["Error", "Smallest", "Largest"],
                    {"default": "Smallest"},
                )
            },
        }

    def concatenate_tensors(
        self,
        reverse: bool,
        on_mismatch: str = "Smallest",
        **kwargs: torch.Tensor,
    ) -> tuple[torch.Tensor]:
        """Concatenate all image inputs along the batch dimension.

        Tensors are indexed as ``(batch, height, width, channels)``.

        Args:
            reverse: Concatenate the inputs in reverse order.
            on_mismatch: ``"Error"`` rejects inputs whose per-image shape
                differs. ``"Smallest"`` and ``"Largest"`` resize every input
                to the smallest or largest ``(height, width)`` among the
                inputs, using bilinear interpolation.
            **kwargs: The image tensors, concatenated in keyword order.

        Returns:
            A one-element tuple holding the concatenated batch.

        Raises:
            ValueError: If no image is provided, or per-image shapes differ
                while ``on_mismatch`` is ``"Error"``.
        """
        tensors = [tensor for tensor in kwargs.values() if tensor is not None]
        if not tensors:
            raise ValueError(
                "At least one image input is required to concatenate."
            )

        if reverse:
            tensors.reverse()

        if on_mismatch == "Error":
            # Only the per-image shape has to agree; batch sizes may differ
            # because the inputs are joined along dim 0.
            expected = tensors[0].shape[1:]
            if any(tensor.shape[1:] != expected for tensor in tensors):
                raise ValueError(
                    "All input tensors must have the same shape when on_mismatch is 'Error'."
                )
        else:
            import torch.nn.functional as F

            pick = min if on_mismatch == "Smallest" else max
            target_shape = pick(
                (tensor.shape for tensor in tensors),
                key=lambda s: (s[1], s[2]),
            )
            target_size = (target_shape[1], target_shape[2])

            def fit(tensor):
                """Resize ``tensor`` to ``target_size`` unless it already matches."""
                if tuple(tensor.shape[1:3]) == target_size:
                    return tensor
                # interpolate expects channels before the spatial axes.
                resized = F.interpolate(
                    tensor.permute(0, 3, 1, 2),
                    size=target_size,
                    mode="bilinear",
                    align_corners=False,
                )
                return resized.permute(0, 2, 3, 1)

            tensors = [fit(tensor) for tensor in tensors]

        return (torch.cat(tensors, dim=0),)
|
|
|
|
|
# Node classes exported by this module, in their original order.
__nodes__ = [
    MTB_StringReplace,
    MTB_FitNumber,
    MTB_GetBatchFromHistory,
    MTB_AnyToString,
    MTB_ConcatImages,
    MTB_MathExpression,
    MTB_ToDevice,
    MTB_ApplyTextTemplate,
    MTB_MatchDimensions,
    MTB_AutoPanEquilateral,
    MTB_FloatsToFloat,
    MTB_FloatToFloats,
    MTB_FloatsToInts,
]
|
|