|
import itertools |
|
import json |
|
import math |
|
import os |
|
|
|
import comfy.model_management as model_management |
|
import folder_paths |
|
import numpy as np
import requests
|
import torch |
|
import torch.nn.functional as F |
|
from PIL import Image, ImageOps |
|
from PIL.PngImagePlugin import PngInfo |
|
from skimage.filters import gaussian |
|
from skimage.util import compare_images |
|
|
|
from ..log import log |
|
from ..utils import np2tensor, pil2tensor, tensor2pil |
|
|
|
|
|
|
def gaussian_kernel(
    kernel_size: int, sigma_x: float, sigma_y: float, device=None
) -> torch.Tensor:
    """Build a normalized 2D Gaussian kernel.

    The kernel is sampled on a fixed [-1, 1] grid, so the sigmas are in
    normalized units rather than pixels.
    """
|
x, y = torch.meshgrid( |
|
torch.linspace(-1, 1, kernel_size, device=device), |
|
torch.linspace(-1, 1, kernel_size, device=device), |
|
indexing="ij", |
|
) |
|
d_x = x * x / (2.0 * sigma_x * sigma_x) |
|
d_y = y * y / (2.0 * sigma_y * sigma_y) |
|
g = torch.exp(-(d_x + d_y)) |
|
return g / g.sum() |
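

# A quick sanity check (hypothetical, not used by the nodes themselves):
#
#   k = gaussian_kernel(5, 1.0, 1.0)
#   assert k.shape == (5, 5) and torch.isclose(k.sum(), torch.tensor(1.0))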
|
|
|
|
|
class MTB_CoordinatesToString:
    """Convert one frame of batch coordinates to a JSON string,
    e.g. [(12, 34)] -> '[{"x": 12, "y": 34}]'.
    """
|
RETURN_TYPES = ("STRING",) |
|
FUNCTION = "convert" |
|
CATEGORY = "mtb/coordinates" |
|
|
|
@classmethod |
|
def INPUT_TYPES(cls): |
|
return { |
|
"required": { |
|
"coordinates": ("BATCH_COORDINATES",), |
|
"frame": ("INT",), |
|
} |
|
} |
|
|
|
def convert( |
|
self, coordinates: list[list[tuple[int, int]]], frame: int |
|
) -> tuple[str]: |
|
        frame = min(max(frame, 0), len(coordinates) - 1)
|
coords = coordinates[frame] |
|
output: list[dict[str, int]] = [] |
|
|
|
for x, y in coords: |
|
output.append({"x": x, "y": y}) |
|
|
|
return (json.dumps(output),) |
|
|
|
|
|
class MTB_ExtractCoordinatesFromImage: |
|
"""Extract 2D points from a batch of images based on a threshold.""" |
|
|
|
RETURN_TYPES = ("BATCH_COORDINATES", "IMAGE") |
|
FUNCTION = "extract" |
|
CATEGORY = "mtb/coordinates" |
|
|
|
@classmethod |
|
def INPUT_TYPES(cls): |
|
return { |
|
"required": { |
|
"threshold": ("FLOAT",), |
|
"max_points": ("INT", {"default": 50, "min": 0}), |
|
}, |
|
"optional": {"image": ("IMAGE",), "mask": ("MASK",)}, |
|
} |
|
|
|
def extract( |
|
self, |
|
threshold: float, |
|
max_points: int, |
|
image: torch.Tensor | None = None, |
|
mask: torch.Tensor | None = None, |
|
) -> tuple[list[list[tuple[int, int]]], torch.Tensor]: |
|
if image is not None: |
|
batch_count, height, width, channel_count = image.shape |
|
imgs = image |
|
else: |
|
if mask is None: |
|
raise ValueError("Must provide either image or mask") |
|
batch_count, height, width = mask.shape |
|
channel_count = 1 |
|
imgs = mask |
|
|
|
if channel_count not in [1, 2, 3, 4]: |
|
raise ValueError(f"Incorrect channel count: {channel_count}") |
|
|
|
all_points: list[list[tuple[int, int]]] = [] |
|
debug_images = torch.zeros( |
|
(batch_count, height, width, 3), |
|
dtype=torch.uint8, |
|
device=imgs.device, |
|
) |
|
|
|
for i, img in enumerate(imgs): |
|
if channel_count == 1: |
|
alpha_channel = img if len(img.shape) == 2 else img[:, :, 0] |
|
elif channel_count == 2: |
|
alpha_channel = img[:, :, 1] |
|
elif channel_count == 4: |
|
alpha_channel = img[:, :, 3] |
|
else: |
|
|
|
alpha_channel = img[:, :, :3].max(dim=2)[0] |
|
|
|
points = (alpha_channel > threshold).nonzero(as_tuple=False) |
|
|
|
if len(points) > max_points: |
|
indices = torch.randperm(points.size(0), device=img.device)[ |
|
:max_points |
|
] |
|
points = points[indices] |
|
|
|
            # nonzero() yields (row, col) pairs; store them as (x, y)
            points = [
                (int(col.item()), int(row.item())) for row, col in points
            ]
|
all_points.append(points) |
|
|
|
for x, y in points: |
|
self._draw_circle(debug_images[i], (x, y), 5) |
|
|
|
return (all_points, debug_images) |
|
|
|
@staticmethod |
|
def _draw_circle( |
|
image: torch.Tensor, center: tuple[int, int], radius: int |
|
    ):
        """Draw a filled white circle of the given radius on the image."""
|
x0, y0 = center |
|
for x in range(-radius, radius + 1): |
|
for y in range(-radius, radius + 1): |
|
in_radius = x**2 + y**2 <= radius**2 |
|
in_bounds = ( |
|
0 <= x0 + x < image.shape[1] |
|
and 0 <= y0 + y < image.shape[0] |
|
) |
|
if in_radius and in_bounds: |
|
image[y0 + y, x0 + x] = torch.tensor( |
|
[255, 255, 255], |
|
dtype=torch.uint8, |
|
device=image.device, |
|
) |
|
|
|
|
|
class MTB_ColorCorrectGPU: |
|
"""Various color correction methods using only Torch.""" |
|
|
|
@classmethod |
|
def INPUT_TYPES(cls): |
|
return { |
|
"required": { |
|
"image": ("IMAGE",), |
|
"force_gpu": ("BOOLEAN", {"default": True}), |
|
"clamp": ([True, False], {"default": True}), |
|
"gamma": ( |
|
"FLOAT", |
|
{"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.01}, |
|
), |
|
"contrast": ( |
|
"FLOAT", |
|
{"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.01}, |
|
), |
|
"exposure": ( |
|
"FLOAT", |
|
{"default": 0.0, "min": -5.0, "max": 5.0, "step": 0.01}, |
|
), |
|
"offset": ( |
|
"FLOAT", |
|
{"default": 0.0, "min": -5.0, "max": 5.0, "step": 0.01}, |
|
), |
|
"hue": ( |
|
"FLOAT", |
|
{"default": 0.0, "min": -0.5, "max": 0.5, "step": 0.01}, |
|
), |
|
"saturation": ( |
|
"FLOAT", |
|
{"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.01}, |
|
), |
|
"value": ( |
|
"FLOAT", |
|
{"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.01}, |
|
), |
|
}, |
|
"optional": {"mask": ("MASK",)}, |
|
} |
|
|
|
RETURN_TYPES = ("IMAGE",) |
|
FUNCTION = "correct" |
|
CATEGORY = "mtb/image processing" |
|
|
|
@staticmethod |
|
def get_device(tensor: torch.Tensor, force_gpu: bool): |
|
if force_gpu: |
|
if torch.cuda.is_available(): |
|
return torch.device("cuda") |
|
elif ( |
|
hasattr(torch.backends, "mps") |
|
and torch.backends.mps.is_available() |
|
): |
|
return torch.device("mps") |
|
            elif getattr(torch.version, "hip", None):
                # ROCm builds expose HIP GPUs through the CUDA device API
                return torch.device("cuda")
        return tensor.device
|
|
|
@staticmethod |
|
def rgb_to_hsv(image: torch.Tensor): |
|
r, g, b = image.unbind(-1) |
|
max_rgb, argmax_rgb = image.max(-1) |
|
min_rgb, _ = image.min(-1) |
|
|
|
diff = max_rgb - min_rgb |
|
|
|
h = torch.empty_like(max_rgb) |
|
s = diff / (max_rgb + 1e-7) |
|
v = max_rgb |
|
|
|
h[argmax_rgb == 0] = (g - b)[argmax_rgb == 0] / (diff + 1e-7)[ |
|
argmax_rgb == 0 |
|
] |
|
h[argmax_rgb == 1] = ( |
|
2.0 + (b - r)[argmax_rgb == 1] / (diff + 1e-7)[argmax_rgb == 1] |
|
) |
|
h[argmax_rgb == 2] = ( |
|
4.0 + (r - g)[argmax_rgb == 2] / (diff + 1e-7)[argmax_rgb == 2] |
|
) |
|
h = (h / 6.0) % 1.0 |
|
|
|
h = h.unsqueeze(-1) |
|
s = s.unsqueeze(-1) |
|
v = v.unsqueeze(-1) |
|
|
|
return torch.cat((h, s, v), dim=-1) |
|
|
|
@staticmethod |
|
def hsv_to_rgb(hsv: torch.Tensor): |
|
h, s, v = hsv.unbind(-1) |
|
h = h * 6.0 |
|
|
|
i = torch.floor(h) |
|
f = h - i |
|
p = v * (1.0 - s) |
|
q = v * (1.0 - s * f) |
|
t = v * (1.0 - s * (1.0 - f)) |
|
|
|
i = i.long() % 6 |
|
|
|
mask = torch.stack( |
|
(i == 0, i == 1, i == 2, i == 3, i == 4, i == 5), -1 |
|
) |
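
        # per-sextant RGB table (i = floor(h * 6)):
        #   i: 0 -> (v, t, p), 1 -> (q, v, p), 2 -> (p, v, t),
        #      3 -> (p, q, v), 4 -> (t, p, v), 5 -> (v, p, q)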
|
|
|
rgb = torch.stack( |
|
( |
|
torch.where( |
|
mask[..., 0], |
|
v, |
|
torch.where( |
|
mask[..., 1], |
|
q, |
|
torch.where( |
|
mask[..., 2], |
|
p, |
|
torch.where( |
|
mask[..., 3], |
|
p, |
|
torch.where(mask[..., 4], t, v), |
|
), |
|
), |
|
), |
|
), |
|
torch.where( |
|
mask[..., 0], |
|
t, |
|
torch.where( |
|
mask[..., 1], |
|
v, |
|
torch.where( |
|
mask[..., 2], |
|
v, |
|
torch.where( |
|
mask[..., 3], |
|
q, |
|
torch.where(mask[..., 4], p, p), |
|
), |
|
), |
|
), |
|
), |
|
torch.where( |
|
mask[..., 0], |
|
p, |
|
torch.where( |
|
mask[..., 1], |
|
p, |
|
torch.where( |
|
mask[..., 2], |
|
t, |
|
torch.where( |
|
mask[..., 3], |
|
v, |
|
torch.where(mask[..., 4], v, q), |
|
), |
|
), |
|
), |
|
), |
|
), |
|
dim=-1, |
|
) |
|
|
|
return rgb |
|
|
|
def correct( |
|
self, |
|
image: torch.Tensor, |
|
force_gpu: bool, |
|
clamp: bool, |
|
gamma: float = 1.0, |
|
contrast: float = 1.0, |
|
exposure: float = 0.0, |
|
offset: float = 0.0, |
|
hue: float = 0.0, |
|
saturation: float = 1.0, |
|
value: float = 1.0, |
|
mask: torch.Tensor | None = None, |
|
): |
|
device = self.get_device(image, force_gpu) |
|
image = image.to(device) |
|
|
|
if mask is not None: |
|
if mask.shape[0] != image.shape[0]: |
|
mask = mask.expand(image.shape[0], -1, -1) |
|
|
|
mask = mask.unsqueeze(-1).expand(-1, -1, -1, 3) |
|
mask = mask.to(device) |
|
|
|
model_management.throw_exception_if_processing_interrupted() |
|
        # gamma's UI minimum is 0.0, so guard the exponent against division by zero
        adjusted = (
            image.pow(1.0 / max(gamma, 1e-5)) * (2.0**exposure) * contrast
            + offset
        )
|
|
|
model_management.throw_exception_if_processing_interrupted() |
|
hsv = self.rgb_to_hsv(adjusted) |
|
hsv[..., 0] = (hsv[..., 0] + hue) % 1.0 |
|
hsv[..., 1] = hsv[..., 1] * saturation |
|
hsv[..., 2] = hsv[..., 2] * value |
|
adjusted = self.hsv_to_rgb(hsv) |
|
|
|
model_management.throw_exception_if_processing_interrupted() |
|
if clamp: |
|
adjusted = torch.clamp(adjusted, 0.0, 1.0) |
|
|
|
|
|
result = ( |
|
adjusted |
|
if mask is None |
|
else torch.where(mask > 0, adjusted, image) |
|
) |
|
|
|
if not force_gpu: |
|
result = result.cpu() |
|
|
|
return (result,) |
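

# Hypothetical direct usage of the node above (outside a ComfyUI graph):
# one stop brighter and slightly desaturated.
#
#   node = MTB_ColorCorrectGPU()
#   (out,) = node.correct(img, force_gpu=False, clamp=True,
#                         exposure=1.0, saturation=0.8)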
|
|
|
|
|
class MTB_ColorCorrect: |
|
"""Various color correction methods""" |
|
|
|
@classmethod |
|
def INPUT_TYPES(cls): |
|
return { |
|
"required": { |
|
"image": ("IMAGE",), |
|
"clamp": ([True, False], {"default": True}), |
|
"gamma": ( |
|
"FLOAT", |
|
{"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.01}, |
|
), |
|
"contrast": ( |
|
"FLOAT", |
|
{"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.01}, |
|
), |
|
"exposure": ( |
|
"FLOAT", |
|
{"default": 0.0, "min": -5.0, "max": 5.0, "step": 0.01}, |
|
), |
|
"offset": ( |
|
"FLOAT", |
|
{"default": 0.0, "min": -5.0, "max": 5.0, "step": 0.01}, |
|
), |
|
"hue": ( |
|
"FLOAT", |
|
{"default": 0.0, "min": -0.5, "max": 0.5, "step": 0.01}, |
|
), |
|
"saturation": ( |
|
"FLOAT", |
|
{"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.01}, |
|
), |
|
"value": ( |
|
"FLOAT", |
|
{"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.01}, |
|
), |
|
}, |
|
"optional": {"mask": ("MASK",)}, |
|
} |
|
|
|
RETURN_TYPES = ("IMAGE",) |
|
FUNCTION = "correct" |
|
CATEGORY = "mtb/image processing" |
|
|
|
@staticmethod |
|
def gamma_correction_tensor(image, gamma): |
|
        gamma_inv = 1.0 / max(gamma, 1e-5)
|
return image.pow(gamma_inv) |
|
|
|
@staticmethod |
|
def contrast_adjustment_tensor(image, contrast): |
|
r, g, b = image.unbind(-1) |
|
|
|
|
|
        # Rec. 709 luma weights
        luminance_image = 0.2126 * r + 0.7152 * g + 0.0722 * b
        luminance_mean = torch.mean(luminance_image)
|
|
|
|
|
contrasted = image * contrast + (1.0 - contrast) * luminance_mean |
|
return torch.clamp(contrasted, 0.0, 1.0) |
|
|
|
@staticmethod |
|
def exposure_adjustment_tensor(image, exposure): |
|
return image * (2.0**exposure) |
|
|
|
@staticmethod |
|
def offset_adjustment_tensor(image, offset): |
|
return image + offset |
|
|
|
@staticmethod |
|
def hsv_adjustment(image: torch.Tensor, hue, saturation, value): |
|
images = tensor2pil(image) |
|
out = [] |
|
for img in images: |
|
hsv_image = img.convert("HSV") |
|
|
|
h, s, v = hsv_image.split() |
|
|
|
h = h.point(lambda x: (x + hue * 255) % 256) |
|
            # clamp to the 8-bit range to avoid wrap-around
            s = s.point(lambda x: min(int(x * saturation), 255))
            v = v.point(lambda x: min(int(x * value), 255))
|
|
|
hsv_image = Image.merge("HSV", (h, s, v)) |
|
rgb_image = hsv_image.convert("RGB") |
|
out.append(rgb_image) |
|
return pil2tensor(out) |
|
|
|
@staticmethod |
|
def hsv_adjustment_tensor_not_working( |
|
image: torch.Tensor, hue, saturation, value |
|
    ):
        """Abandoned for now; kept for reference but not used."""
|
image = image.squeeze(0).permute(2, 0, 1) |
|
|
|
max_val, _ = image.max(dim=0, keepdim=True) |
|
min_val, _ = image.min(dim=0, keepdim=True) |
|
delta = max_val - min_val |
|
|
|
hue_image = torch.zeros_like(max_val) |
|
mask = delta != 0.0 |
|
|
|
r, g, b = image[0], image[1], image[2] |
|
hue_image[mask & (max_val == r)] = ((g - b) / delta)[ |
|
mask & (max_val == r) |
|
] % 6.0 |
|
hue_image[mask & (max_val == g)] = ((b - r) / delta)[ |
|
mask & (max_val == g) |
|
] + 2.0 |
|
hue_image[mask & (max_val == b)] = ((r - g) / delta)[ |
|
mask & (max_val == b) |
|
] + 4.0 |
|
|
|
saturation_image = delta / (max_val + 1e-7) |
|
value_image = max_val |
|
|
|
hue_image = (hue_image + hue) % 1.0 |
|
saturation_image = torch.where( |
|
mask, saturation * saturation_image, saturation_image |
|
) |
|
value_image = value * value_image |
|
|
|
c = value_image * saturation_image |
|
x = c * (1 - torch.abs((hue_image % 2) - 1)) |
|
m = value_image - c |
|
|
|
prime_image = torch.zeros_like(image) |
|
prime_image[0] = torch.where( |
|
max_val == r, c, torch.where(max_val == g, x, prime_image[0]) |
|
) |
|
prime_image[1] = torch.where( |
|
max_val == r, x, torch.where(max_val == g, c, prime_image[1]) |
|
) |
|
prime_image[2] = torch.where( |
|
max_val == g, x, torch.where(max_val == b, c, prime_image[2]) |
|
) |
|
|
|
rgb_image = prime_image + m |
|
|
|
rgb_image = rgb_image.permute(1, 2, 0).unsqueeze(0) |
|
|
|
return rgb_image |
|
|
|
def correct( |
|
self, |
|
image: torch.Tensor, |
|
clamp: bool, |
|
gamma: float = 1.0, |
|
contrast: float = 1.0, |
|
exposure: float = 0.0, |
|
offset: float = 0.0, |
|
hue: float = 0.0, |
|
saturation: float = 1.0, |
|
value: float = 1.0, |
|
mask: torch.Tensor | None = None, |
|
): |
|
if mask is not None: |
|
if mask.shape[0] != image.shape[0]: |
|
mask = mask.expand(image.shape[0], -1, -1) |
|
|
|
mask = mask.unsqueeze(-1).expand(-1, -1, -1, 3) |
|
|
|
|
|
adjusted = self.gamma_correction_tensor(image, gamma) |
|
adjusted = self.contrast_adjustment_tensor(adjusted, contrast) |
|
adjusted = self.exposure_adjustment_tensor(adjusted, exposure) |
|
adjusted = self.offset_adjustment_tensor(adjusted, offset) |
|
adjusted = self.hsv_adjustment(adjusted, hue, saturation, value) |
|
|
|
if clamp: |
|
            adjusted = torch.clamp(adjusted, 0.0, 1.0)
|
|
|
result = ( |
|
adjusted |
|
if mask is None |
|
else torch.where(mask > 0, adjusted, image) |
|
) |
|
|
|
return (result,) |
|
|
|
|
|
class MTB_ImageCompare: |
|
"""Compare two images and return a difference image""" |
|
|
|
@classmethod |
|
def INPUT_TYPES(cls): |
|
return { |
|
"required": { |
|
"imageA": ("IMAGE",), |
|
"imageB": ("IMAGE",), |
|
"mode": ( |
|
["checkerboard", "diff", "blend"], |
|
{"default": "checkerboard"}, |
|
), |
|
} |
|
} |
|
|
|
RETURN_TYPES = ("IMAGE",) |
|
FUNCTION = "compare" |
|
CATEGORY = "mtb/image" |
|
|
|
def compare(self, imageA: torch.Tensor, imageB: torch.Tensor, mode): |
|
if imageA.dim() == 4: |
|
batch_count = imageA.size(0) |
|
return ( |
|
torch.cat( |
|
tuple( |
|
self.compare(imageA[i], imageB[i], mode)[0] |
|
for i in range(batch_count) |
|
), |
|
dim=0, |
|
), |
|
) |
|
|
|
num_channels_A = imageA.size(2) |
|
num_channels_B = imageB.size(2) |
|
|
|
|
|
if num_channels_A == 3 and num_channels_B == 4: |
|
imageA = torch.cat( |
|
(imageA, torch.ones_like(imageA[:, :, 0:1])), dim=2 |
|
) |
|
elif num_channels_B == 3 and num_channels_A == 4: |
|
imageB = torch.cat( |
|
(imageB, torch.ones_like(imageB[:, :, 0:1])), dim=2 |
|
) |
|
match mode: |
|
case "diff": |
|
compare_image = torch.abs(imageA - imageB) |
|
case "blend": |
|
compare_image = 0.5 * (imageA + imageB) |
|
case "checkerboard": |
|
                imageA = imageA.cpu().numpy()
                imageB = imageB.cpu().numpy()
|
compared_channels = [ |
|
torch.from_numpy( |
|
compare_images( |
|
imageA[:, :, i], imageB[:, :, i], method=mode |
|
) |
|
) |
|
for i in range(imageA.shape[2]) |
|
] |
|
|
|
compare_image = torch.stack(compared_channels, dim=2) |
|
            case _:
                raise ValueError(f"Unknown mode {mode}")
|
|
|
compare_image = compare_image.unsqueeze(0) |
|
|
|
return (compare_image,) |
|
|
|
|
|
|
|
|
|
|
class MTB_LoadImageFromUrl: |
|
"""Load an image from the given URL""" |
|
|
|
@classmethod |
|
def INPUT_TYPES(cls): |
|
return { |
|
"required": { |
|
"url": ( |
|
"STRING", |
|
{ |
|
"default": "https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Example.jpg/800px-Example.jpg" |
|
}, |
|
), |
|
} |
|
} |
|
|
|
RETURN_TYPES = ("IMAGE",) |
|
FUNCTION = "load" |
|
CATEGORY = "mtb/IO" |
|
|
|
    def load(self, url):
        response = requests.get(url, stream=True, timeout=30)
        response.raise_for_status()
        image = Image.open(response.raw)
        image = ImageOps.exif_transpose(image)
        return (pil2tensor(image),)
|
|
|
|
|
class MTB_Blur: |
|
"""Blur an image using a Gaussian filter.""" |
|
|
|
@classmethod |
|
def INPUT_TYPES(cls): |
|
return { |
|
"required": { |
|
"image": ("IMAGE",), |
|
"sigmaX": ( |
|
"FLOAT", |
|
{"default": 3.0, "min": 0.0, "max": 200.0, "step": 0.01}, |
|
), |
|
"sigmaY": ( |
|
"FLOAT", |
|
{"default": 3.0, "min": 0.0, "max": 200.0, "step": 0.01}, |
|
), |
|
}, |
|
"optional": {"sigmasX": ("FLOATS",), "sigmasY": ("FLOATS",)}, |
|
} |
|
|
|
RETURN_TYPES = ("IMAGE",) |
|
FUNCTION = "blur" |
|
CATEGORY = "mtb/image processing" |
|
|
|
def blur( |
|
self, image: torch.Tensor, sigmaX, sigmaY, sigmasX=None, sigmasY=None |
|
): |
|
        image_np = image.cpu().numpy() * 255
|
|
|
blurred_images = [] |
|
if sigmasX is not None: |
|
if sigmasY is None: |
|
sigmasY = sigmasX |
|
if len(sigmasX) != image.size(0): |
|
                raise ValueError(
                    f"sigmasX must have the same length as the image batch: got {len(sigmasX)} sigmas for {image.size(0)} images"
                )
|
|
|
for i in range(image.size(0)): |
|
blurred = gaussian( |
|
image_np[i], |
|
sigma=(sigmasX[i], sigmasY[i], 0), |
|
channel_axis=2, |
|
) |
|
blurred_images.append(blurred) |
|
|
|
image_np = np.array(blurred_images) |
|
else: |
|
for i in range(image.size(0)): |
|
blurred = gaussian( |
|
image_np[i], sigma=(sigmaX, sigmaY, 0), channel_axis=2 |
|
) |
|
blurred_images.append(blurred) |
|
|
|
image_np = np.array(blurred_images) |
|
return (np2tensor(image_np).squeeze(0),) |
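

# Hypothetical usage of the node above with per-frame sigmas (one value per
# batch item, so a batch of three here):
#
#   (out,) = MTB_Blur().blur(img, sigmaX=0, sigmaY=0,
#                            sigmasX=[1.0, 2.0, 4.0])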
|
|
|
|
|
class MTB_Sharpen: |
|
"""Sharpens an image using a Gaussian kernel.""" |
|
|
|
@classmethod |
|
def INPUT_TYPES(cls): |
|
return { |
|
"required": { |
|
"image": ("IMAGE",), |
|
"sharpen_radius": ( |
|
"INT", |
|
{"default": 1, "min": 1, "max": 31, "step": 1}, |
|
), |
|
"sigma_x": ( |
|
"FLOAT", |
|
{"default": 1.0, "min": 0.1, "max": 10.0, "step": 0.1}, |
|
), |
|
"sigma_y": ( |
|
"FLOAT", |
|
{"default": 1.0, "min": 0.1, "max": 10.0, "step": 0.1}, |
|
), |
|
"alpha": ( |
|
"FLOAT", |
|
{"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.1}, |
|
), |
|
}, |
|
} |
|
|
|
RETURN_TYPES = ("IMAGE",) |
|
FUNCTION = "do_sharp" |
|
CATEGORY = "mtb/image processing" |
|
|
|
def do_sharp( |
|
self, |
|
image: torch.Tensor, |
|
sharpen_radius: int, |
|
sigma_x: float, |
|
sigma_y: float, |
|
alpha: float, |
|
): |
|
if sharpen_radius == 0: |
|
return (image,) |
|
|
|
channels = image.shape[3] |
|
|
|
kernel_size = 2 * sharpen_radius + 1 |
|
kernel = gaussian_kernel(kernel_size, sigma_x, sigma_y) * -(alpha * 10) |
|
|
|
|
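        # negated Gaussian surround; the center weight below is raised so the
        # kernel sums to 1: identity plus an edge-enhancing high-pass term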
|
center = kernel_size // 2 |
|
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0 |
|
|
|
kernel = kernel.repeat(channels, 1, 1).unsqueeze(1) |
|
tensor_image = image.permute(0, 3, 1, 2) |
|
|
|
tensor_image = F.pad( |
|
tensor_image, |
|
(sharpen_radius, sharpen_radius, sharpen_radius, sharpen_radius), |
|
"reflect", |
|
) |
|
sharpened = F.conv2d( |
|
tensor_image, kernel, padding=center, groups=channels |
|
) |
|
|
|
|
|
sharpened = sharpened[ |
|
:, |
|
:, |
|
sharpen_radius:-sharpen_radius, |
|
sharpen_radius:-sharpen_radius, |
|
] |
|
|
|
sharpened = sharpened.permute(0, 2, 3, 1) |
|
result = torch.clamp(sharpened, 0, 1) |
|
|
|
return (result,) |
|
|
|
|
|
|
class MTB_MaskToImage: |
|
"""Converts a mask (alpha) to an RGB image with a color and background""" |
|
|
|
@classmethod |
|
def INPUT_TYPES(cls): |
|
return { |
|
"required": { |
|
"mask": ("MASK",), |
|
"color": ("COLOR",), |
|
"background": ("COLOR", {"default": "#000000"}), |
|
}, |
|
"optional": { |
|
"invert": ("BOOLEAN", {"default": False}), |
|
}, |
|
} |
|
|
|
CATEGORY = "mtb/generate" |
|
|
|
RETURN_TYPES = ("IMAGE",) |
|
|
|
FUNCTION = "render_mask" |
|
|
|
def render_mask(self, mask, color, background, invert=False): |
|
masks = tensor2pil(1.0 - mask) if invert else tensor2pil(mask) |
|
images = [] |
|
|
|
for m in masks: |
|
_mask = m.convert("L") |
|
|
|
log.debug( |
|
f"Converted mask to PIL Image format, size: {_mask.size}" |
|
) |
|
|
|
image = Image.new("RGBA", _mask.size, color=color) |
|
|
|
image = Image.composite( |
|
image, Image.new("RGBA", _mask.size, color=background), _mask |
|
) |
|
|
images.append(image.convert("RGB")) |
|
|
|
return (pil2tensor(images),) |
|
|
|
|
|
class MTB_ColoredImage: |
|
"""Constant color image of given size.""" |
|
|
|
def __init__(self) -> None: |
|
pass |
|
|
|
@classmethod |
|
def INPUT_TYPES(cls): |
|
return { |
|
"required": { |
|
"color": ("COLOR",), |
|
"width": ("INT", {"default": 512, "min": 16, "max": 8160}), |
|
"height": ("INT", {"default": 512, "min": 16, "max": 8160}), |
|
}, |
|
"optional": { |
|
"foreground_image": ("IMAGE",), |
|
"foreground_mask": ("MASK",), |
|
"invert": ("BOOLEAN", {"default": False}), |
|
"mask_opacity": ( |
|
"FLOAT", |
|
{"default": 1.0, "step": 0.1, "min": 0}, |
|
), |
|
}, |
|
} |
|
|
|
CATEGORY = "mtb/generate" |
|
|
|
RETURN_TYPES = ("IMAGE",) |
|
|
|
FUNCTION = "render_img" |
|
|
|
def resize_and_crop(self, img: Image.Image, target_size: tuple[int, int]): |
|
scale = max(target_size[0] / img.width, target_size[1] / img.height) |
|
new_size = (int(img.width * scale), int(img.height * scale)) |
|
img = img.resize(new_size, Image.LANCZOS) |
|
left = (img.width - target_size[0]) // 2 |
|
top = (img.height - target_size[1]) // 2 |
|
return img.crop( |
|
(left, top, left + target_size[0], top + target_size[1]) |
|
) |
|
|
|
def resize_and_crop_thumbnails( |
|
self, img: Image.Image, target_size: tuple[int, int] |
|
): |
|
img.thumbnail(target_size, Image.LANCZOS) |
|
left = (img.width - target_size[0]) / 2 |
|
top = (img.height - target_size[1]) / 2 |
|
right = (img.width + target_size[0]) / 2 |
|
bottom = (img.height + target_size[1]) / 2 |
|
return img.crop((left, top, right, bottom)) |
|
|
|
@staticmethod |
|
def process_mask( |
|
mask: torch.Tensor | None, |
|
invert: bool, |
|
batch_size: int, |
|
    ) -> list[Image.Image | None]:
|
if mask is None: |
|
return [None] * batch_size |
|
|
|
masks = tensor2pil(mask if not invert else 1.0 - mask) |
|
|
|
if len(masks) == 1 and batch_size > 1: |
|
masks = masks * batch_size |
|
|
|
if len(masks) != batch_size: |
|
raise ValueError( |
|
"Foreground image and mask must have the same batch size" |
|
) |
|
|
|
return masks |
|
|
|
def render_img( |
|
self, |
|
color: str, |
|
width: int, |
|
height: int, |
|
foreground_image: torch.Tensor | None = None, |
|
foreground_mask: torch.Tensor | None = None, |
|
invert: bool = False, |
|
mask_opacity: float = 1.0, |
|
) -> tuple[torch.Tensor]: |
|
background = Image.new("RGBA", (width, height), color=color) |
|
|
|
if foreground_image is None: |
|
return (pil2tensor([background.convert("RGB")]),) |
|
|
|
fg_images = tensor2pil(foreground_image) |
|
fg_masks = self.process_mask(foreground_mask, invert, len(fg_images)) |
|
|
|
output: list[Image.Image] = [] |
|
for fg_image, fg_mask in zip(fg_images, fg_masks, strict=False): |
|
fg_image = self.resize_and_crop(fg_image, background.size) |
|
|
|
if fg_mask: |
|
fg_mask = self.resize_and_crop(fg_mask, background.size) |
|
|
|
fg_mask_array = np.array(fg_mask) |
|
fg_mask_array = (fg_mask_array * mask_opacity).astype(np.uint8) |
|
fg_mask = Image.fromarray(fg_mask_array) |
|
output.append( |
|
Image.composite( |
|
fg_image.convert("RGBA"), background, fg_mask |
|
).convert("RGB") |
|
) |
|
else: |
|
if fg_image.mode != "RGBA": |
|
raise ValueError( |
|
f"Foreground image must be in 'RGBA' mode when no mask is provided, got {fg_image.mode}" |
|
) |
|
output.append( |
|
Image.alpha_composite(background, fg_image).convert("RGB") |
|
) |
|
|
|
return (pil2tensor(output),) |
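

# Hypothetical usage of the node above: a plain red 512x512 plate, no
# foreground layered on top.
#
#   (img,) = MTB_ColoredImage().render_img("#ff0000", 512, 512)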
|
|
|
|
|
class MTB_ImagePremultiply: |
|
"""Premultiply image with mask""" |
|
|
|
@classmethod |
|
def INPUT_TYPES(cls): |
|
return { |
|
"required": { |
|
"image": ("IMAGE",), |
|
"mask": ("MASK",), |
|
"invert": ("BOOLEAN", {"default": False}), |
|
} |
|
} |
|
|
|
CATEGORY = "mtb/image" |
|
RETURN_TYPES = ("IMAGE",) |
|
RETURN_NAMES = ("RGBA",) |
|
FUNCTION = "premultiply" |
|
|
|
def premultiply(self, image, mask, invert): |
|
images = tensor2pil(image) |
|
masks = tensor2pil(mask) if invert else tensor2pil(1.0 - mask) |
|
single = len(mask) == 1 |
|
masks = [x.convert("L") for x in masks] |
|
|
|
out = [] |
|
for i, img in enumerate(images): |
|
cur_mask = masks[0] if single else masks[i] |
|
|
|
img.putalpha(cur_mask) |
|
out.append(img) |
|
|
return (pil2tensor(out),) |
|
|
|
|
|
class MTB_ImageResizeFactor:
    """Extracted mostly from WAS Node Suite, with a few edits (most notably multiple image support) and fewer features."""
|
|
|
@classmethod |
|
def INPUT_TYPES(cls): |
|
return { |
|
"required": { |
|
"image": ("IMAGE",), |
|
"factor": ( |
|
"FLOAT", |
|
{"default": 2, "min": 0.01, "max": 16.0, "step": 0.01}, |
|
), |
|
"supersample": ("BOOLEAN", {"default": True}), |
|
"resampling": ( |
|
[ |
|
"nearest", |
|
"linear", |
|
"bilinear", |
|
"bicubic", |
|
"trilinear", |
|
"area", |
|
"nearest-exact", |
|
], |
|
{"default": "nearest"}, |
|
), |
|
}, |
|
"optional": { |
|
"mask": ("MASK",), |
|
}, |
|
} |
|
|
|
CATEGORY = "mtb/image" |
|
RETURN_TYPES = ("IMAGE", "MASK") |
|
FUNCTION = "resize" |
|
|
|
def resize( |
|
self, |
|
image: torch.Tensor, |
|
factor: float, |
|
supersample: bool, |
|
resampling: str, |
|
mask=None, |
|
): |
|
|
|
        if image.dim() not in (3, 4):
            raise ValueError(
                "Expected image tensor of shape (H, W, C) or (B, H, W, C)"
            )

        # remember whether the input was batched so the output shape matches
        batched = image.dim() == 4
        if batched:
            image = image.permute(0, 3, 1, 2)
        else:
            image = image.permute(2, 0, 1).unsqueeze(0)

        B, C, H, W = image.shape
        new_H, new_W = int(H * factor), int(W * factor)

        # align_corners is only legal for the interpolating modes; it must be
        # None (not False) for nearest / area / nearest-exact
        align_corner_filters = ("linear", "bilinear", "bicubic", "trilinear")
        align_corners = True if resampling in align_corner_filters else None

        resized_image = F.interpolate(
            image,
            size=(new_H, new_W),
            mode=resampling,
            align_corners=align_corners,
        )

        if supersample:
            resized_image = F.interpolate(
                resized_image,
                scale_factor=2,
                mode=resampling,
                align_corners=align_corners,
            )

        if batched:
            resized_image = resized_image.permute(0, 2, 3, 1)
        else:
            resized_image = resized_image.squeeze(0).permute(1, 2, 0)

        resized_mask = mask
        if mask is not None:
            # bring the mask to the output resolution before applying it
            resized_mask = F.interpolate(
                mask.unsqueeze(1),
                size=resized_image.shape[-3:-1],
                mode="nearest",
            ).squeeze(1)
            resized_image = resized_image * resized_mask.unsqueeze(-1)

        return (resized_image, resized_mask)
|
|
|
|
|
class MTB_SaveImageGrid: |
|
"""Save all the images in the input batch as a grid of images.""" |
|
|
|
def __init__(self): |
|
self.output_dir = folder_paths.get_output_directory() |
|
self.type = "output" |
|
|
|
@classmethod |
|
def INPUT_TYPES(cls): |
|
return { |
|
"required": { |
|
"images": ("IMAGE",), |
|
"filename_prefix": ("STRING", {"default": "ComfyUI"}), |
|
"save_intermediate": ("BOOLEAN", {"default": False}), |
|
}, |
|
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, |
|
} |
|
|
|
RETURN_TYPES = () |
|
FUNCTION = "save_images" |
|
|
|
OUTPUT_NODE = True |
|
|
|
CATEGORY = "mtb/IO" |
|
|
|
def create_image_grid(self, image_list): |
|
total_images = len(image_list) |
|
|
|
|
|
        # enough columns to hold every image, then enough rows for the rest
        cols = int(math.ceil(math.sqrt(total_images)))
        rows = int(math.ceil(total_images / cols))
        grid_size = (cols, rows)
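        # e.g. 8 images -> cols = ceil(sqrt(8)) = 3, rows = ceil(8 / 3) = 3:
        # a 3x3 grid with one empty cell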
|
|
|
|
|
image_width, image_height = image_list[0].size |
|
|
|
|
|
grid_width = grid_size[0] * image_width |
|
grid_height = grid_size[1] * image_height |
|
grid_image = Image.new("RGB", (grid_width, grid_height)) |
|
|
|
|
|
for i, image in enumerate(image_list): |
|
x = (i % grid_size[0]) * image_width |
|
y = (i // grid_size[0]) * image_height |
|
grid_image.paste(image, (x, y, x + image_width, y + image_height)) |
|
|
|
return grid_image |
|
|
|
def save_images( |
|
self, |
|
images, |
|
filename_prefix="Grid", |
|
save_intermediate=False, |
|
prompt=None, |
|
extra_pnginfo=None, |
|
): |
|
( |
|
full_output_folder, |
|
filename, |
|
counter, |
|
subfolder, |
|
filename_prefix, |
|
) = folder_paths.get_save_image_path( |
|
filename_prefix, |
|
self.output_dir, |
|
images[0].shape[1], |
|
images[0].shape[0], |
|
) |
|
image_list = [] |
|
batch_counter = counter |
|
|
|
metadata = PngInfo() |
|
if prompt is not None: |
|
metadata.add_text("prompt", json.dumps(prompt)) |
|
if extra_pnginfo is not None: |
|
for x in extra_pnginfo: |
|
metadata.add_text(x, json.dumps(extra_pnginfo[x])) |
|
|
|
for idx, image in enumerate(images): |
|
i = 255.0 * image.cpu().numpy() |
|
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) |
|
image_list.append(img) |
|
|
|
if save_intermediate: |
|
file = f"{filename}_batch-{idx:03}_{batch_counter:05}_.png" |
|
img.save( |
|
os.path.join(full_output_folder, file), |
|
pnginfo=metadata, |
|
compress_level=4, |
|
) |
|
|
|
batch_counter += 1 |
|
|
|
file = f"{filename}_{counter:05}_.png" |
|
grid = self.create_image_grid(image_list) |
|
grid.save( |
|
os.path.join(full_output_folder, file), |
|
pnginfo=metadata, |
|
compress_level=4, |
|
) |
|
|
|
results = [ |
|
{"filename": file, "subfolder": subfolder, "type": self.type} |
|
] |
|
return {"ui": {"images": results}} |
|
|
|
|
|
class MTB_ImageTileOffset:
    """Mimics an old Photoshop technique for checking that a texture tiles seamlessly."""
|
|
|
@classmethod |
|
def INPUT_TYPES(cls): |
|
return { |
|
"required": { |
|
"image": ("IMAGE",), |
|
"tilesX": ("INT", {"default": 2, "min": 1}), |
|
"tilesY": ("INT", {"default": 2, "min": 1}), |
|
} |
|
} |
|
|
|
CATEGORY = "mtb/generate" |
|
|
|
RETURN_TYPES = ("IMAGE",) |
|
|
|
FUNCTION = "tile_image" |
|
|
|
def tile_image( |
|
self, image: torch.Tensor, tilesX: int = 2, tilesY: int = 2 |
|
): |
|
if tilesX < 1 or tilesY < 1: |
|
raise ValueError("The number of tiles must be at least 1.") |
|
|
|
batch_size, height, width, channels = image.shape |
|
tile_height = height // tilesY |
|
tile_width = width // tilesX |
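
        # each tile moves one tile step down and right (with wrap-around), so
        # seams that sat on the borders end up in the middle of the output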
|
|
|
output_image = torch.zeros_like(image) |
|
|
|
for i, j in itertools.product(range(tilesY), range(tilesX)): |
|
start_h = i * tile_height |
|
end_h = start_h + tile_height |
|
start_w = j * tile_width |
|
end_w = start_w + tile_width |
|
|
|
tile = image[:, start_h:end_h, start_w:end_w, :] |
|
|
|
output_start_h = (i + 1) % tilesY * tile_height |
|
output_start_w = (j + 1) % tilesX * tile_width |
|
output_end_h = output_start_h + tile_height |
|
output_end_w = output_start_w + tile_width |
|
|
|
output_image[ |
|
:, output_start_h:output_end_h, output_start_w:output_end_w, : |
|
] = tile |
|
|
|
return (output_image,) |
|
|
|
|
|
__nodes__ = [ |
|
MTB_ColorCorrect, |
|
MTB_ColorCorrectGPU, |
|
MTB_ImageCompare, |
|
MTB_ImageTileOffset, |
|
MTB_Blur, |
|
MTB_MaskToImage, |
|
MTB_ColoredImage, |
|
MTB_ImagePremultiply, |
|
MTB_ImageResizeFactor, |
|
MTB_SaveImageGrid, |
|
MTB_LoadImageFromUrl, |
|
MTB_Sharpen, |
|
MTB_ExtractCoordinatesFromImage, |
|
MTB_CoordinatesToString, |
|
] |
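

if __name__ == "__main__":
    # Minimal smoke test (an assumption of this edit, not part of the ComfyUI
    # node interface): check that the pure-tensor HSV helpers round-trip.
    rgb = torch.rand(2, 8, 8, 3)
    hsv = MTB_ColorCorrectGPU.rgb_to_hsv(rgb)
    back = MTB_ColorCorrectGPU.hsv_to_rgb(hsv)
    assert torch.allclose(rgb, back, atol=1e-3)
    print("image_processing smoke test passed")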
|
|