Spaces:
Paused
Paused
import numpy as np | |
import scipy.ndimage | |
import torch | |
import comfy.utils | |
from nodes import MAX_RESOLUTION | |
def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
    """Blend ``source`` onto ``destination`` in place at offset (x, y).

    Both tensors are indexed as (batch, channels, rows, cols): ``y`` pairs
    with ``shape[2]`` and ``x`` with ``shape[3]`` below.

    Args:
        destination: tensor that is written to in place and returned.
        source: tensor pasted onto ``destination``. It is moved to the
            destination's device and repeated to its batch size.
        x, y: paste offset in pixels. Each is divided (floor) by
            ``multiplier`` to get a destination index. Negative offsets are
            allowed: the part of ``source`` left of / above the canvas is
            dropped rather than wrapping around.
        mask: optional blend weights (1 -> source, 0 -> destination). They
            are resized bilinearly to the source's spatial size. ``None``
            pastes the source opaquely.
        multiplier: pixels per destination cell (8 for latents, 1 for images
            in this file).
        resize_source: if True, first resize ``source`` bilinearly to the
            destination's spatial size.

    Returns:
        ``destination`` (the same tensor object, modified in place).
    """
    source = source.to(destination.device)
    if resize_source:
        source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")

    source = comfy.utils.repeat_to_batch_size(source, destination.shape[0])

    # Clamp so the source is at most entirely off-canvas on either side.
    x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier))
    y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier))

    left, top = x // multiplier, y // multiplier

    if mask is None:
        mask = torch.ones_like(source)
    else:
        mask = mask.to(destination.device, copy=True)
        mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear")
        mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0])

    # Intersect the pasted rectangle with the destination, in destination
    # coordinates. A negative start index must never reach a slice: Python
    # would interpret it as an offset from the opposite edge.
    dst_left, dst_top = max(left, 0), max(top, 0)
    dst_right = min(left + source.shape[3], destination.shape[3])
    dst_bottom = min(top + source.shape[2], destination.shape[2])
    visible_width, visible_height = dst_right - dst_left, dst_bottom - dst_top
    if visible_width <= 0 or visible_height <= 0:
        return destination  # no part of the source lands on the canvas

    # Matching window inside the source; skips the part hanging off-canvas.
    src_left, src_top = dst_left - left, dst_top - top
    src_rows = slice(src_top, src_top + visible_height)
    src_cols = slice(src_left, src_left + visible_width)
    dst_rows = slice(dst_top, dst_bottom)
    dst_cols = slice(dst_left, dst_right)

    mask = mask[:, :, src_rows, src_cols]
    inverse_mask = torch.ones_like(mask) - mask

    source_portion = mask * source[:, :, src_rows, src_cols]
    destination_portion = inverse_mask * destination[:, :, dst_rows, dst_cols]

    destination[:, :, dst_rows, dst_cols] = source_portion + destination_portion
    return destination
class LatentCompositeMasked:
    """Node: paste one latent onto another, optionally through a mask."""

    @classmethod
    def INPUT_TYPES(s):
        # Must be a classmethod: it is meant to be callable on the class itself.
        return {
            "required": {
                "destination": ("LATENT",),
                "source": ("LATENT",),
                "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
                "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
                "resize_source": ("BOOLEAN", {"default": False}),
            },
            "optional": {
                "mask": ("MASK",),
            }
        }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "composite"
    CATEGORY = "latent"

    def composite(self, destination, source, x, y, resize_source, mask = None):
        """Return a copy of ``destination`` with ``source`` composited at (x, y).

        ``destination`` and ``source`` are dicts holding a ``"samples"``
        tensor. ``x``/``y`` are in pixels; the module-level ``composite``
        divides them by 8 to get latent indices. The destination samples are
        cloned, and the dict is shallow-copied, so the input is not modified.
        """
        output = destination.copy()
        destination = destination["samples"].clone()
        source = source["samples"]
        # Bare name resolves to the module-level helper, not this method.
        output["samples"] = composite(destination, source, x, y, mask, 8, resize_source)
        return (output,)
class ImageCompositeMasked:
    """Node: paste one image onto another, optionally through a mask."""

    @classmethod
    def INPUT_TYPES(s):
        # Must be a classmethod: it is meant to be callable on the class itself.
        return {
            "required": {
                "destination": ("IMAGE",),
                "source": ("IMAGE",),
                "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "resize_source": ("BOOLEAN", {"default": False}),
            },
            "optional": {
                "mask": ("MASK",),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "composite"
    CATEGORY = "image"

    def composite(self, destination, source, x, y, resize_source, mask = None):
        """Return ``destination`` with ``source`` composited at pixel offset (x, y).

        The last axis is moved to position 1 (``movedim(-1, 1)``) for the
        shared module-level ``composite`` helper (multiplier 1), then moved
        back. The images therefore presumably arrive channels-last (TODO
        confirm). ``destination`` is cloned first, so the caller's tensor is
        left untouched.
        """
        destination = destination.clone().movedim(-1, 1)
        # Bare name resolves to the module-level helper, not this method.
        output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)
        return (output,)
class MaskToImage:
    """Node: turn a mask into a 3-channel grey image."""

    @classmethod
    def INPUT_TYPES(s):
        # Must be a classmethod: it is meant to be callable on the class itself.
        return {
            "required": {
                "mask": ("MASK",),
            }
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "mask_to_image"

    def mask_to_image(self, mask):
        """Return the mask as a (batch, H, W, 3) tensor.

        The same value is repeated in all three channels. ``expand`` creates
        a broadcast view rather than copying the data.
        """
        result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
        return (result,)
class ImageToMask:
    """Node: take one colour channel of an image as a mask."""

    @classmethod
    def INPUT_TYPES(s):
        # Must be a classmethod: it is meant to be callable on the class itself.
        return {
            "required": {
                "image": ("IMAGE",),
                "channel": (["red", "green", "blue", "alpha"],),
            }
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)
    FUNCTION = "image_to_mask"

    def image_to_mask(self, image, channel):
        """Return ``image[..., k]``, where k is the position of ``channel``.

        ``image`` is indexed as (batch, H, W, channels). Asking for
        ``"alpha"`` on an image without a fourth channel raises IndexError.
        """
        channels = ["red", "green", "blue", "alpha"]
        mask = image[:, :, :, channels.index(channel)]
        return (mask,)
class ImageColorToMask:
    """Node: mask the pixels of an image that exactly match an RGB colour."""

    @classmethod
    def INPUT_TYPES(s):
        # Must be a classmethod: it is meant to be callable on the class itself.
        return {
            "required": {
                "image": ("IMAGE",),
                "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}),
            }
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)
    FUNCTION = "image_to_mask"

    def image_to_mask(self, image, color):
        """Return a float mask: 1.0 where a pixel equals ``color``, else 0.0.

        ``color`` is packed as 0xRRGGBB. Each of the first three channels is
        clamped to [0, 1] and quantised to 0..255 before packing, so the match
        is exact at 8-bit precision.
        """
        temp = (torch.clamp(image, 0, 1.0) * 255.0).round().to(torch.int)
        temp = torch.bitwise_left_shift(temp[:,:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,:,1], 8) + temp[:,:,:,2]
        # Emit 1.0 (not 255) so the result is a [0, 1] mask like every other
        # mask produced and consumed in this module.
        mask = (temp == color).float()
        return (mask,)
class SolidMask:
    """Node: create a single mask filled with a constant value."""

    @classmethod
    def INPUT_TYPES(cls):
        # Must be a classmethod: it is meant to be callable on the class itself.
        return {
            "required": {
                "value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
                "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
            }
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)
    FUNCTION = "solid"

    def solid(self, value, width, height):
        """Return a float32 CPU tensor of shape (1, height, width) filled with ``value``."""
        out = torch.full((1, height, width), value, dtype=torch.float32, device="cpu")
        return (out,)
class InvertMask:
    """Node: invert a mask (``1 - mask``)."""

    @classmethod
    def INPUT_TYPES(cls):
        # Must be a classmethod: it is meant to be callable on the class itself.
        return {
            "required": {
                "mask": ("MASK",),
            }
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)
    FUNCTION = "invert"

    def invert(self, mask):
        """Return ``1.0 - mask`` as a new tensor; the input is not modified."""
        out = 1.0 - mask
        return (out,)
class CropMask:
    """Node: crop a rectangular region out of a mask."""

    @classmethod
    def INPUT_TYPES(cls):
        # Must be a classmethod: it is meant to be callable on the class itself.
        return {
            "required": {
                "mask": ("MASK",),
                "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
                "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
            }
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)
    FUNCTION = "crop"

    def crop(self, mask, x, y, width, height):
        """Return rows ``y:y+height`` and cols ``x:x+width`` of each mask in the batch.

        The mask is first reshaped to (batch, H, W). Ordinary slicing
        semantics apply, so a region running past the edge is truncated.
        """
        mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
        out = mask[:, y:y + height, x:x + width]
        return (out,)
class MaskComposite:
    """Node: combine a source mask into a destination mask at an offset."""

    # Per-pixel combine operations, in the order they are offered as inputs.
    # The logical ones first round both operands to {0, 1}.
    _OPERATIONS = {
        "multiply": lambda dst, src: dst * src,
        "add": lambda dst, src: dst + src,
        "subtract": lambda dst, src: dst - src,
        "and": lambda dst, src: torch.bitwise_and(dst.round().bool(), src.round().bool()).float(),
        "or": lambda dst, src: torch.bitwise_or(dst.round().bool(), src.round().bool()).float(),
        "xor": lambda dst, src: torch.bitwise_xor(dst.round().bool(), src.round().bool()).float(),
    }

    @classmethod
    def INPUT_TYPES(cls):
        # Must be a classmethod: it is meant to be callable on the class itself.
        return {
            "required": {
                "destination": ("MASK",),
                "source": ("MASK",),
                "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "operation": (list(cls._OPERATIONS),),
            }
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)
    FUNCTION = "combine"

    def combine(self, destination, source, x, y, operation):
        """Combine ``source`` into a copy of ``destination`` with its top-left at (x, y).

        Both masks are reshaped to (batch, H, W). Only the overlap of the
        placed source with the destination is combined; an offset past the
        destination's edge leaves it unchanged. An unrecognised ``operation``
        also leaves it unchanged. The whole result, including untouched
        areas, is clamped to [0, 1].
        """
        output = destination.reshape((-1, destination.shape[-2], destination.shape[-1])).clone()
        source = source.reshape((-1, source.shape[-2], source.shape[-1]))

        left, top = x, y
        right = min(left + source.shape[-1], output.shape[-1])
        bottom = min(top + source.shape[-2], output.shape[-2])
        visible_width, visible_height = right - left, bottom - top

        op = self._OPERATIONS.get(operation)
        # A negative visible size would turn into "all but the last N" slices
        # below and mismatch the empty destination window.
        if op is not None and visible_width > 0 and visible_height > 0:
            source_portion = source[:, :visible_height, :visible_width]
            # Read from the reshaped copy: the raw ``destination`` may be 2-D.
            destination_portion = output[:, top:bottom, left:right]
            output[:, top:bottom, left:right] = op(destination_portion, source_portion)

        output = torch.clamp(output, 0.0, 1.0)
        return (output,)
class FeatherMask:
    """Node: fade a mask linearly towards zero along its edges."""

    @classmethod
    def INPUT_TYPES(cls):
        # Must be a classmethod: it is meant to be callable on the class itself.
        return {
            "required": {
                "mask": ("MASK",),
                "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
            }
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)
    FUNCTION = "feather"

    def feather(self, mask, left, top, right, bottom):
        """Return a copy of ``mask`` with linear ramps applied at each edge.

        ``left``/``right`` are widths in columns and ``top``/``bottom`` are
        heights in rows; each is clamped to the mask size. For an edge of
        width n, the outermost column/row is scaled by 1/n, and the factor
        rises to 1.0 at the n-th one. Where ramps overlap, their factors
        multiply.
        """
        output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone()

        left = min(left, output.shape[-1])
        right = min(right, output.shape[-1])
        top = min(top, output.shape[-2])
        bottom = min(bottom, output.shape[-2])

        # Far edges are indexed with -(i + 1): using -i would make i == 0 hit
        # the first column/row instead of the last.
        for i in range(left):
            output[:, :, i] *= (i + 1.0) / left
        for i in range(right):
            output[:, :, -(i + 1)] *= (i + 1.0) / right
        for i in range(top):
            output[:, i, :] *= (i + 1.0) / top
        for i in range(bottom):
            output[:, -(i + 1), :] *= (i + 1.0) / bottom
        return (output,)
class GrowMask:
    """Node: grow or shrink a mask by repeated 3x3 grey-level morphology."""

    @classmethod
    def INPUT_TYPES(cls):
        # Must be a classmethod: it is meant to be callable on the class itself.
        return {
            "required": {
                "mask": ("MASK",),
                "expand": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 1}),
                "tapered_corners": ("BOOLEAN", {"default": True}),
            },
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)
    FUNCTION = "expand_mask"

    def expand_mask(self, mask, expand, tapered_corners):
        """Dilate (``expand`` > 0) or erode (``expand`` < 0) each mask ``abs(expand)`` times.

        Each step uses a 3x3 footprint. With ``tapered_corners`` the corners
        are excluded (a plus shape); otherwise it is the full square. Returns
        a CPU tensor of shape (batch, H, W) with the input's dtype.
        """
        c = 0 if tapered_corners else 1
        kernel = np.array([[c, 1, c],
                           [1, 1, 1],
                           [c, 1, c]])
        # The sign of ``expand`` is fixed, so choose the operation once.
        step = scipy.ndimage.grey_erosion if expand < 0 else scipy.ndimage.grey_dilation
        mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
        out = []
        for m in mask:
            # scipy needs a host array; .cpu() is a no-op for CPU tensors.
            output = m.cpu().numpy()
            for _ in range(abs(expand)):
                output = step(output, footprint=kernel)
            out.append(torch.from_numpy(output))
        # stack copies, so the result never aliases the input mask.
        return (torch.stack(out, dim=0),)
class ThresholdMask:
    """Node: binarise a mask against a threshold."""

    @classmethod
    def INPUT_TYPES(s):
        # Must be a classmethod: it is meant to be callable on the class itself.
        return {
            "required": {
                "mask": ("MASK",),
                "value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
            }
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)
    FUNCTION = "image_to_mask"

    def image_to_mask(self, mask, value):
        """Return 1.0 where ``mask > value`` (strictly greater), else 0.0."""
        mask = (mask > value).float()
        return (mask,)
# Node type name -> node class for every node defined in this module
# (presumably collected by the host application's node loader; TODO confirm).
NODE_CLASS_MAPPINGS = dict(
    LatentCompositeMasked=LatentCompositeMasked,
    ImageCompositeMasked=ImageCompositeMasked,
    MaskToImage=MaskToImage,
    ImageToMask=ImageToMask,
    ImageColorToMask=ImageColorToMask,
    SolidMask=SolidMask,
    InvertMask=InvertMask,
    CropMask=CropMask,
    MaskComposite=MaskComposite,
    FeatherMask=FeatherMask,
    GrowMask=GrowMask,
    ThresholdMask=ThresholdMask,
)

# Human-readable titles keyed by node type name. Nodes not listed here
# presumably fall back to their type name.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    ImageToMask="Convert Image to Mask",
    MaskToImage="Convert Mask to Image",
)