|
import cv2 |
|
import numpy as np |
|
import torch |
|
from huggingface_hub import hf_hub_download |
|
|
|
from ..utils import models_dir, np2tensor |
|
|
|
|
|
|
|
|
|
|
|
class MTB_LoadVitMatteModel:
    """Load a TorchScript ViT-Matte checkpoint from the Hugging Face hub."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                # which training-dataset flavor of the checkpoint to load
                "kind": (("Composition-1K", "Distinctions-646"),),
                # when False, only a locally cached file is used (no network)
                "autodownload": ("BOOLEAN", {"default": True}),
            },
        }

    RETURN_TYPES = ("VITMATTE_MODEL",)
    RETURN_NAMES = ("torch_script",)
    CATEGORY = "mtb/vitmatte"
    FUNCTION = "execute"

    def execute(self, *, kind: str, autodownload: bool):
        """Fetch (if allowed) and load the scripted ViT-Matte model.

        Args:
            kind: "Composition-1K" or "Distinctions-646" checkpoint variant.
            autodownload: allow downloading when the file is not cached.

        Returns:
            A 1-tuple with the loaded ``torch.jit.ScriptModule``.
        """
        dest = models_dir / "vitmatte"
        # parents=True: also create models_dir itself on a fresh install,
        # where mkdir(exist_ok=True) alone would raise FileNotFoundError.
        dest.mkdir(parents=True, exist_ok=True)
        name = "dist" if kind == "Distinctions-646" else "com"

        file = hf_hub_download(
            repo_id="melmass/pytorch-scripts",
            filename=f"vitmatte_b_{name}.pt",
            local_dir=dest.as_posix(),
            local_files_only=not autodownload,
        )
        # Prefer CUDA but fall back to CPU so loading doesn't hard-crash
        # on machines without a GPU; map_location avoids a pointless
        # GPU deserialization round-trip in that case.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = torch.jit.load(file, map_location=device).to(device)

        return (model,)
|
|
|
|
|
class MTB_GenerateTrimap:
    """Build trimaps (bg=0, unknown=128, fg=255) from a batch of binary masks."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "mask": ("MASK",),
                # square structuring-element sizes for the unknown band
                "erode": ("INT", {"default": 10}),
                "dilate": ("INT", {"default": 10}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("trimap",)

    CATEGORY = "mtb/vitmatte"
    FUNCTION = "execute"

    def execute(
        self,
        mask: torch.Tensor,
        erode: int = 10,
        dilate: int = 10,
    ):
        """Derive a per-mask trimap via morphological erosion/dilation.

        Args:
            mask: batch of masks, iterated along dim 0; each mask is
                converted to a uint8 numpy array on the CPU.
            erode: kernel size shrinking the confident-foreground region.
            dilate: kernel size growing the unknown region.

        Returns:
            A 1-tuple with the stacked trimaps converted via ``np2tensor``.
        """
        # NOTE: the previous implementation moved the whole batch to CUDA
        # here, only to pull every mask straight back to CPU numpy below —
        # a wasted transfer that also crashed on CPU-only machines.

        # Kernels are loop-invariant: build them once. Clamp to >= 1 so a
        # zero size doesn't produce an empty (invalid) structuring element.
        erode_kernel = np.ones((max(erode, 1), max(erode, 1)), np.uint8)
        dilate_kernel = np.ones((max(dilate, 1), max(dilate, 1)), np.uint8)

        trimaps = []
        for m in mask:
            # binarize to {0, 255} uint8, the format OpenCV expects
            mask_arr = m.squeeze(0).to(torch.uint8).cpu().numpy() * 255
            eroded = cv2.erode(mask_arr, erode_kernel, iterations=5)
            dilated = cv2.dilate(mask_arr, dilate_kernel, iterations=5)
            trimap = np.zeros_like(mask_arr)
            trimap[dilated == 255] = 128  # unknown band around the edge
            trimap[eroded == 255] = 255  # confident foreground core
            trimaps.append(trimap)

        return (np2tensor(trimaps),)
|
|
|
|
|
class MTB_ApplyVitMatte:
    """Run a TorchScript ViT-Matte model over image/trimap pairs to predict fine mattes."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("VITMATTE_MODEL",),
                "image": ("IMAGE",),
                "trimap": ("IMAGE",),
                # output mode: RGB foreground only, or RGBA with matte as alpha
                "returns": (("RGB", "RGBA"),),
            },
        }

    RETURN_TYPES = ("IMAGE", "MASK")
    RETURN_NAMES = ("image (rgba)", "mask")
    CATEGORY = "mtb/utils"
    FUNCTION = "execute"

    def execute(
        self, model, image: torch.Tensor, trimap: torch.Tensor, returns: str
    ):
        """Predict an alpha matte for each image in the batch.

        Args:
            model: TorchScript ViT-Matte module (see MTB_LoadVitMatteModel);
                called with a dict of batched "image"/"trimap" tensors.
            image: channels-last image batch, (B, H, W, C).
            trimap: trimap batch indexed along dim 0; the per-item permute
                below implies each entry is 2-D (H, W) — TODO confirm
                against the producing node.
            returns: "RGB" or "RGBA" output selection.

        Returns:
            (images, mattes): channels-last image batch and the raw matte
            outputs concatenated along the batch dimension.

        Raises:
            ValueError: if image and trimap batch sizes differ.
        """
        im_count = image.shape[0]
        tm_count = trimap.shape[0]

        if im_count != tm_count:
            raise ValueError("image and trimap must have the same batch size")

        outputs_m: list[torch.Tensor] = []
        outputs_i: list[torch.Tensor] = []
        for i, im in enumerate(image):
            # trimap entry -> (1, H, W) channels-first half tensor on CUDA
            # NOTE(review): requires a CUDA device — no CPU fallback here.
            tm = trimap[i].half().unsqueeze(2).permute(2, 0, 1).to("cuda")
            # (H, W, C) image -> (C, H, W) channels-first half on CUDA
            im = im.half().permute(2, 0, 1).to("cuda")

            # the scripted model takes a single dict of batched inputs
            inputs = {"image": im.unsqueeze(0), "trimap": tm.unsqueeze(0)}

            fine_mask = model(inputs)
            # composite over white: fg * alpha + 1 * (1 - alpha)
            foreground = im * fine_mask + (1 - fine_mask)

            if returns == "RGBA":
                # stack the matte under the RGB channels as alpha
                rgba_image = torch.cat(
                    (foreground, fine_mask.unsqueeze(0)), dim=0
                )
                outputs_i.append(rgba_image.unsqueeze(0))
            else:
                outputs_i.append(foreground.unsqueeze(0))

            outputs_m.append(fine_mask.unsqueeze(0))

        result_m = torch.cat(outputs_m, dim=0)
        result_i = torch.cat(outputs_i, dim=0)

        # back to channels-last (B, H, W, C) for ComfyUI consumers
        return (result_i.permute(0, 2, 3, 1), result_m)
|
|
|
|
|
# Node classes exported by this module for ComfyUI registration.
__nodes__ = [MTB_LoadVitMatteModel, MTB_GenerateTrimap, MTB_ApplyVitMatte]
|
|