import folder_paths

import impact.mmdet_nodes as mmdet_nodes
from impact.utils import *
from impact.core import SEG
import impact.core as core
import nodes
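
# Sentinel placeholders returned in place of the model type a checkpoint does not provide.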
class NO_BBOX_MODEL:
    pass


class NO_SEGM_MODEL:
    pass
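
# Legacy loader node: picks an mmdet checkpoint from the "bbox/" or "segm/" lists and
# returns it on the matching output, with a sentinel placeholder on the other output.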
class MMDetLoader:
    @classmethod
    def INPUT_TYPES(s):
        bboxs = ["bbox/" + x for x in folder_paths.get_filename_list("mmdets_bbox")]
        segms = ["segm/" + x for x in folder_paths.get_filename_list("mmdets_segm")]
        return {"required": {"model_name": (bboxs + segms, )}}

    RETURN_TYPES = ("BBOX_MODEL", "SEGM_MODEL")
    FUNCTION = "load_mmdet"

    CATEGORY = "ImpactPack/Legacy"

    def load_mmdet(self, model_name):
        mmdet_path = folder_paths.get_full_path("mmdets", model_name)
        model = mmdet_nodes.load_mmdet(mmdet_path)

        if model_name.startswith("bbox"):
            return model, NO_SEGM_MODEL()
        else:
            return NO_BBOX_MODEL(), model
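
# Legacy bbox detector: runs mmdet bbox inference and emits a SEGS result with one SEG
# (cropped image, mask, confidence, crop region) per detection.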
class BboxDetectorForEach:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "bbox_model": ("BBOX_MODEL", ),
                    "image": ("IMAGE", ),
                    "threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                    "dilation": ("INT", {"default": 10, "min": 0, "max": 255, "step": 1}),
                    "crop_factor": ("FLOAT", {"default": 3.0, "min": 1.0, "max": 100, "step": 0.1}),
                    }
                }

    RETURN_TYPES = ("SEGS", )
    FUNCTION = "doit"

    CATEGORY = "ImpactPack/Legacy"

    @staticmethod
    def detect(bbox_model, image, threshold, dilation, crop_factor, drop_size=1, detailer_hook=None):
        mmdet_results = mmdet_nodes.inference_bbox(bbox_model, image, threshold)
        segmasks = core.create_segmasks(mmdet_results)

        if dilation > 0:
            segmasks = dilate_masks(segmasks, dilation)

        items = []
        h = image.shape[1]
        w = image.shape[2]
        for x in segmasks:
            item_bbox = x[0]
            item_mask = x[1]

            y1, x1, y2, x2 = item_bbox

            # skip detections smaller than drop_size in either dimension
            if x2 - x1 > drop_size and y2 - y1 > drop_size:
                crop_region = make_crop_region(w, h, item_bbox, crop_factor)
                cropped_image = crop_image(image, crop_region)
                cropped_mask = crop_ndarray2(item_mask, crop_region)
                confidence = x[2]

                item = SEG(cropped_image, cropped_mask, confidence, crop_region, item_bbox, None, None)
                items.append(item)

        shape = h, w
        return shape, items

    def doit(self, bbox_model, image, threshold, dilation, crop_factor):
        return (BboxDetectorForEach.detect(bbox_model, image, threshold, dilation, crop_factor), )
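
# Legacy segm detector: combines every detected instance mask into a single MASK output.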
class SegmDetectorCombined:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "segm_model": ("SEGM_MODEL", ),
                    "image": ("IMAGE", ),
                    "threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                    "dilation": ("INT", {"default": 0, "min": 0, "max": 255, "step": 1}),
                    }
                }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "doit"

    CATEGORY = "ImpactPack/Legacy"

    def doit(self, segm_model, image, threshold, dilation):
        mmdet_results = mmdet_nodes.inference_segm(image, segm_model, threshold)
        segmasks = core.create_segmasks(mmdet_results)
        if dilation > 0:
            segmasks = dilate_masks(segmasks, dilation)

        mask = combine_masks(segmasks)
        return (mask,)
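
# Same combined-mask output as SegmDetectorCombined, but driven by bbox inference.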
class BboxDetectorCombined(SegmDetectorCombined):
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "bbox_model": ("BBOX_MODEL", ),
                    "image": ("IMAGE", ),
                    "threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                    "dilation": ("INT", {"default": 4, "min": 0, "max": 255, "step": 1}),
                    }
                }

    def doit(self, bbox_model, image, threshold, dilation):
        mmdet_results = mmdet_nodes.inference_bbox(bbox_model, image, threshold)
        segmasks = core.create_segmasks(mmdet_results)
        if dilation > 0:
            segmasks = dilate_masks(segmasks, dilation)

        mask = combine_masks(segmasks)
        return (mask,)
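
# Legacy segm detector: emits one SEG per detected instance mask, cropped by crop_factor.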
class SegmDetectorForEach:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "segm_model": ("SEGM_MODEL", ),
                    "image": ("IMAGE", ),
                    "threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                    "dilation": ("INT", {"default": 10, "min": 0, "max": 255, "step": 1}),
                    "crop_factor": ("FLOAT", {"default": 3.0, "min": 1.0, "max": 100, "step": 0.1}),
                    }
                }

    RETURN_TYPES = ("SEGS", )
    FUNCTION = "doit"

    CATEGORY = "ImpactPack/Legacy"

    def doit(self, segm_model, image, threshold, dilation, crop_factor):
        mmdet_results = mmdet_nodes.inference_segm(image, segm_model, threshold)
        segmasks = core.create_segmasks(mmdet_results)

        if dilation > 0:
            segmasks = dilate_masks(segmasks, dilation)

        items = []
        h = image.shape[1]
        w = image.shape[2]
        for x in segmasks:
            item_bbox = x[0]
            item_mask = x[1]

            crop_region = make_crop_region(w, h, item_bbox, crop_factor)
            cropped_image = crop_image(image, crop_region)
            cropped_mask = crop_ndarray2(item_mask, crop_region)
            confidence = x[2]

            item = SEG(cropped_image, cropped_mask, confidence, crop_region, item_bbox, None, None)
            items.append(item)

        shape = h, w
        return ((shape, items), )
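
# Rebuilds a full-resolution MASK tensor by pasting each SEG's cropped mask into its crop region.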
class SegsMaskCombine:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "segs": ("SEGS", ),
                    "image": ("IMAGE", ),
                    }
                }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "doit"

    CATEGORY = "ImpactPack/Legacy"

    @staticmethod
    def combine(segs, image):
        h = image.shape[1]
        w = image.shape[2]

        mask = np.zeros((h, w), dtype=np.uint8)

        for seg in segs[1]:
            cropped_mask = seg.cropped_mask
            crop_region = seg.crop_region
            # paste each cropped mask back into its crop region of the full-size mask
            mask[crop_region[1]:crop_region[3], crop_region[0]:crop_region[2]] |= (cropped_mask * 255).astype(np.uint8)

        return torch.from_numpy(mask.astype(np.float32) / 255.0)

    def doit(self, segs, image):
        return (SegsMaskCombine.combine(segs, image), )
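
# Legacy mask-painting node: previews the input image and loads the painted mask back via nodes.LoadImage().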
class MaskPainter(nodes.PreviewImage):
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"images": ("IMAGE",), },
                "hidden": {
                    "prompt": "PROMPT",
                    "extra_pnginfo": "EXTRA_PNGINFO",
                },
                "optional": {"mask_image": ("IMAGE_PATH",),
                             "image": (["#placeholder"], )},
                }

    RETURN_TYPES = ("MASK",)

    FUNCTION = "save_painted_images"

    CATEGORY = "ImpactPack/Legacy"

    def save_painted_images(self, images, filename_prefix="impact-mask",
                            prompt=None, extra_pnginfo=None, mask_image=None, image=None):
        if image == "#placeholder" or image['image_hash'] != id(images):
            # new or changed input image: save a preview and load its mask via LoadImage
            res = self.save_images(images, filename_prefix, prompt, extra_pnginfo)

            item = res['ui']['images'][0]

            if not item['filename'].endswith(']'):
                filepath = f"{item['filename']} [{item['type']}]"
            else:
                filepath = item['filename']

            _, mask = nodes.LoadImage().load_image(filepath)

            res['ui']['aux'] = [id(images), res['ui']['images']]
            res['result'] = (mask, )

            return res

        else:
            # input unchanged: rebuild the path from the forwarded image info and reload its mask
            if '0' in image:
                image = image['0']

            forward = {'filename': image['forward_filename'],
                       'subfolder': image['forward_subfolder'],
                       'type': image['forward_type'], }

            res = {'ui': {'images': [forward]}}

            imgpath = ""
            if 'subfolder' in image and image['subfolder'] != "":
                imgpath = image['subfolder'] + "/"

            imgpath += f"{image['filename']}"

            if 'type' in image and image['type'] != "":
                imgpath += f" [{image['type']}]"

            res['ui']['aux'] = [id(images), [forward]]
            _, mask = nodes.LoadImage().load_image(imgpath)
            res['result'] = (mask, )

            return res