|
import torch |
|
import torchvision.transforms.v2 as T |
|
import torch.nn.functional as F |
|
from .utils import expand_mask |
|
|
|
class LoadCLIPSegModels:
    """Load the CLIPSeg processor + model pair and expose it as a CLIP_SEG value."""

    @classmethod
    def INPUT_TYPES(s):
        # No configurable inputs: the checkpoint is fixed.
        return {"required": {}}

    RETURN_TYPES = ("CLIP_SEG",)
    FUNCTION = "execute"
    CATEGORY = "essentials/segmentation"

    def execute(self):
        """Return a one-tuple holding (processor, model) for downstream nodes."""
        # Imported lazily so `transformers` is only required when this node runs.
        from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

        checkpoint = "CIDAS/clipseg-rd64-refined"
        pair = (
            CLIPSegProcessor.from_pretrained(checkpoint),
            CLIPSegForImageSegmentation.from_pretrained(checkpoint),
        )
        return (pair,)
|
|
|
class ApplyCLIPSeg:
    """Produce a segmentation MASK from an image and a text prompt via CLIPSeg."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "clip_seg": ("CLIP_SEG",),
                "image": ("IMAGE",),
                "prompt": ("STRING", { "multiline": False, "default": "" }),
                "threshold": ("FLOAT", { "default": 0.4, "min": 0.0, "max": 1.0, "step": 0.05 }),
                "smooth": ("INT", { "default": 9, "min": 0, "max": 32, "step": 1 }),
                "dilate": ("INT", { "default": 0, "min": -32, "max": 32, "step": 1 }),
                "blur": ("INT", { "default": 0, "min": 0, "max": 64, "step": 1 }),
            },
        }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "execute"
    CATEGORY = "essentials/segmentation"

    def execute(self, image, clip_seg, prompt, threshold, smooth, dilate, blur):
        """Run CLIPSeg per batch image, threshold the logits, optionally
        smooth/dilate/blur the mask, and resize it to the input resolution.

        Returns a one-tuple with a float mask tensor sized to match `image`.
        """
        processor, model = clip_seg

        # `image` is scaled from [0,1] to uint8; assumes ComfyUI's
        # (batch, H, W, C) IMAGE layout — the H/W taken from shape[1]/shape[2]
        # below are used to restore the output size.
        imagenp = image.mul(255).clamp(0, 255).byte().cpu().numpy()

        outputs = []
        for i in imagenp:
            inputs = processor(text=prompt, images=[i], return_tensors="pt")
            # Inference only: no_grad avoids building an autograd graph and
            # the associated memory overhead for every batch item.
            with torch.no_grad():
                out = model(**inputs, interpolate_pos_encoding=True)
            out = out.logits.unsqueeze(1)
            out = torch.sigmoid(out[0][0])
            out = (out > threshold)
            outputs.append(out)

        # Free the uint8 copy before the mask post-processing allocations.
        del imagenp

        outputs = torch.stack(outputs, dim=0)

        # BUGFIX: cast the boolean mask to float *before* smoothing —
        # torchvision's gaussian_blur cannot convolve bool tensors, so the
        # original order raised whenever smooth > 0.
        outputs = outputs.float()

        if smooth > 0:
            if smooth % 2 == 0:
                smooth += 1  # gaussian kernel size must be odd
            outputs = T.functional.gaussian_blur(outputs, smooth)

        if dilate != 0:
            outputs = expand_mask(outputs, dilate, True)

        if blur > 0:
            if blur % 2 == 0:
                blur += 1  # gaussian kernel size must be odd
            outputs = T.functional.gaussian_blur(outputs, blur)

        # Scale the low-resolution CLIPSeg mask back up to the image size.
        outputs = F.interpolate(outputs.unsqueeze(1), size=(image.shape[1], image.shape[2]), mode='bicubic').squeeze(1)

        return (outputs,)
|
|
|
# Node registrations: maps node identifiers to their implementing classes.
# NOTE(review): presumably merged into the extension's NODE_CLASS_MAPPINGS
# elsewhere — confirm against the package __init__.
SEG_CLASS_MAPPINGS = {
    "ApplyCLIPSeg+": ApplyCLIPSeg,
    "LoadCLIPSegModels+": LoadCLIPSegModels,
}

# Human-readable display names for the same node identifiers.
SEG_NAME_MAPPINGS = {
    "ApplyCLIPSeg+": "🔧 Apply CLIPSeg",
    "LoadCLIPSegModels+": "🔧 Load CLIPSeg Models",
}