Spaces:
Running
on
L4
Running
on
L4
import torch.nn.functional as F | |
import torch | |
import numpy as np | |
from PIL import Image | |
import os | |
import sys | |
sys.path.append(os.path.dirname(os.path.abspath(__file__))) | |
from .brushnet_nodes import BrushNetLoader, BrushNet, BlendInpaint, get_files_with_extension | |
from .comfyui_utils import CheckpointLoaderSimple, ControlNetLoader, ControlNetApplyAdvanced, CLIPTextEncode, KSampler, VAEDecode, GrowMask, PIDINET_Preprocessor, LineArt_Preprocessor, Color_Preprocessor | |
class ScribbleColorEditModel:
    """Scribble- and colour-guided image editing pipeline.

    Chains the node wrappers imported from ``brushnet_nodes`` and
    ``comfyui_utils`` (checkpoint loader, CLIP text encoder, edge/colour
    preprocessors, ControlNet, BrushNet, KSampler, VAE decoder, blender)
    into a single ``process`` call.
    """

    def __init__(self):
        """Create the node wrappers and load the default SD1.5 weights."""
        self.checkpoint_loader = CheckpointLoaderSimple()
        self.clip_text_encoder = CLIPTextEncode()
        self.mask_processor = GrowMask()
        self.controlnet_loader = ControlNetLoader()
        self.scribble_processor = PIDINET_Preprocessor()
        self.lineart_processor = LineArt_Preprocessor()
        self.color_processor = Color_Preprocessor()
        self.brushnet_loader = BrushNetLoader()
        self.brushnet_node = BrushNet()
        self.controlnet_apply = ControlNetApplyAdvanced()
        self.ksampler = KSampler()
        self.vae_decoder = VAEDecode()
        self.blender = BlendInpaint()
        self.ckpt_name = "SD1.5/realisticVisionV60B1_v51VAE.safetensors"
        with torch.no_grad():
            self.model, self.clip, self.vae = self.checkpoint_loader.load_checkpoint(self.ckpt_name)
        self.load_models('SD1.5', 'float16')

    def load_models(self, base_model_version="SD1.5", dtype='float16'):
        """Load the edge ControlNet, colour ControlNet and BrushNet weights.

        Args:
            base_model_version: Base model family; only ``"SD1.5"`` is accepted.
            dtype: Passed through to ``brushnet_loader.brushnet_loading``.

        Raises:
            ValueError: If ``base_model_version`` is not ``"SD1.5"``.
        """
        if base_model_version == "SD1.5":
            edge_controlnet_name = "control_v11p_sd15_scribble.safetensors"
            color_controlnet_name = "color_finetune.safetensors"
            brushnet_name = "brushnet/random_mask_brushnet_ckpt/diffusion_pytorch_model.safetensors"
        else:
            # TODO: SDXL is not wired up yet. Candidate weights noted by the
            # original author:
            #   edge:     controlnet-scribble-sdxl-1.0.safetensors
            #   colour:   colorGridControlnet_v10.safetensors
            #   brushnet: brushnet_xl/random_mask_brushnet_ckpt_sdxl_v0/diffusion_pytorch_model.safetensors
            raise ValueError("Invalid base_model_version, not supported yet!!!: {}".format(base_model_version))

        self.edge_controlnet = self.controlnet_loader.load_controlnet(edge_controlnet_name)[0]
        self.color_controlnet = self.controlnet_loader.load_controlnet(color_controlnet_name)[0]

        # Look the inpaint files up once and reuse the result for both the
        # loader attribute and the log line.
        inpaint_files = get_files_with_extension('inpaint')
        self.brushnet_loader.inpaint_files = inpaint_files
        print("self.brushnet_loader.inpaint_files: ", inpaint_files)
        self.brushnet = self.brushnet_loader.brushnet_loading(brushnet_name, dtype)[0]

    def _ensure_checkpoint(self, ckpt_name):
        """Swap to ``ckpt_name`` if it differs from the currently loaded checkpoint."""
        if ckpt_name == self.ckpt_name:
            return
        with torch.no_grad():
            model, clip, vae = self.checkpoint_loader.load_checkpoint(ckpt_name)
        # Commit only after the load succeeded, so a failed load cannot leave
        # self.ckpt_name pointing at weights that were never loaded.
        self.model, self.clip, self.vae = model, clip, vae
        self.ckpt_name = ckpt_name

    @staticmethod
    def _resize_mask(mask, height, width):
        """Nearest-neighbour resize of a stroke mask to ``height`` x ``width``.

        Two leading dims are added before interpolation and removed again
        afterwards. The three-entry ``size`` means ``mask`` is expected to be
        3-D -- presumably (1, H, W); TODO confirm against callers.
        """
        resized = F.interpolate(
            mask.unsqueeze(0).unsqueeze(0).float(),
            size=(1, height, width),
            mode='nearest',
        )
        return resized.squeeze(0).squeeze(0)

    def process(self, ckpt_name, image, colored_image, positive_prompt, negative_prompt, mask, add_mask, remove_mask, grow_size, stroke_as_edge, fine_edge, edge_strength, color_strength, inpaint_strength, seed, steps, cfg, sampler_name, scheduler, base_model_version='SD1.5', dtype='float16', palette_resolution=2048):
        """Run one edit: conditioning -> BrushNet -> sampling -> decode -> blend.

        Args:
            ckpt_name: Checkpoint to use; reloaded only when it differs from
                the one currently held.
            image: Source image tensor fed to the preprocessors, BrushNet and
                the final blend.
            colored_image: Compared with ``image`` via ``torch.equal``; when
                they differ the colour branch is taken.
            positive_prompt: Text encoded as the positive conditioning.
            negative_prompt: Text encoded as the negative conditioning.
            mask: Edit-region mask; expanded by ``grow_size`` before use.
            add_mask: Stroke mask, thresholded at 0.5 (edge branch).
            remove_mask: Stroke mask, thresholded at 0.5 (edge branch).
            grow_size: ``expand`` amount given to the mask processor.
            stroke_as_edge: ``"enable"`` paints both stroke masks into the
                edge map; any other value only clears remove-only pixels.
            fine_edge: ``"enable"`` selects the lineart preprocessor, anything
                else the scribble preprocessor (edge branch only).
            edge_strength: Edge ControlNet strength in the edge branch. The
                colour branch uses a fixed 0.8 instead.
            color_strength: Colour ControlNet strength (colour branch).
            inpaint_strength: ``scale`` given to BrushNet.
            seed, steps, cfg, sampler_name, scheduler: Forwarded to KSampler.
            base_model_version, dtype: Only used if the ControlNet/BrushNet
                attributes are missing and ``load_models`` has to run.
            palette_resolution: ``resolution`` for the colour preprocessor.

        Returns:
            Tuple ``(latent_samples, final_image, lineart_output, color_output)``.
        """
        self._ensure_checkpoint(ckpt_name)

        # Load the ControlNet & BrushNet models matching the base model
        # version if any of them is missing.
        models_ready = (
            hasattr(self, 'edge_controlnet')
            and hasattr(self, 'color_controlnet')
            and hasattr(self, 'brushnet')
        )
        if not models_ready:
            self.load_models(base_model_version, dtype)

        positive = self.clip_text_encoder.encode(self.clip, positive_prompt)[0]
        negative = self.clip_text_encoder.encode(self.clip, negative_prompt)[0]

        # Grow Mask for Color Editing
        mask = self.mask_processor.expand_mask(mask, expand=grow_size, tapered_corners=True)[0]

        # Realistic Lineart
        # NOTE(review): image_copy is modified here but never read afterwards;
        # the preprocessors below all receive `image`. Kept as-is because it is
        # unclear whether they were meant to consume image_copy -- confirm and
        # either wire it in or delete this section.
        image_copy = image.clone()
        if stroke_as_edge == "disable":
            bool_add_mask = add_mask > 0.5
            mean_brightness = image_copy[bool_add_mask].mean()
            if mean_brightness > 0.8:
                image_copy[bool_add_mask] = 0.0
            else:
                image_copy[bool_add_mask] = 1.0

        if not torch.equal(image, colored_image):
            print("Apply color controlnet")
            color_output = self.color_processor.execute(colored_image, resolution=palette_resolution)[0]
            lineart_output = self.lineart_processor.execute(image, resolution=512, coarse=False)[0]
            positive, negative = self.controlnet_apply.apply_controlnet(positive, negative, self.color_controlnet, color_output, color_strength, 0.0, 1.0)
            positive, negative = self.controlnet_apply.apply_controlnet(positive, negative, self.edge_controlnet, lineart_output, 0.8, 0.0, 1.0)
        else:
            print("Apply edge controlnet")
            color_output = self.color_processor.execute(image, resolution=palette_resolution)[0]
            if fine_edge == "enable":
                lineart_output = self.lineart_processor.execute(image, resolution=512, coarse=False)[0]
            else:
                lineart_output = self.scribble_processor.execute(image, resolution=512)[0]

            # Resize masks to match the dimensions of lineart_output
            # (dims 1 and 2 are used as height/width -- presumably the tensor
            # is (batch, height, width, channels); TODO confirm).
            target_h, target_w = lineart_output.shape[1], lineart_output.shape[2]
            bool_add_mask_resized = self._resize_mask(add_mask, target_h, target_w) > 0.5
            bool_remove_mask_resized = self._resize_mask(remove_mask, target_h, target_w) > 0.5

            if stroke_as_edge == "enable":
                # Turn the pixels inside remove_mask black
                lineart_output[bool_remove_mask_resized] = 0.0
                # Turn the pixels inside add_mask white (applied second, so it
                # wins where the two masks overlap)
                lineart_output[bool_add_mask_resized] = 1.0
            else:
                lineart_output[bool_remove_mask_resized & ~bool_add_mask_resized] = 0.0

            positive, negative = self.controlnet_apply.apply_controlnet(positive, negative, self.edge_controlnet, lineart_output, edge_strength, 0.0, 1.0)

        # BrushNet
        model, positive, negative, latent = self.brushnet_node.model_update(
            model=self.model,
            vae=self.vae,  # the VAE model has to be supplied as appropriate
            image=image,
            mask=mask,
            brushnet=self.brushnet,
            positive=positive,
            negative=negative,
            scale=inpaint_strength,
            start_at=0,
            end_at=10000
        )

        # KSampler Node
        latent_samples = self.ksampler.sample(
            model=model,
            seed=seed,
            steps=steps,
            cfg=cfg,
            sampler_name=sampler_name,
            scheduler=scheduler,
            positive=positive,
            negative=negative,
            latent_image=latent,
        )[0]

        final_image = self.vae_decoder.decode(self.vae, latent_samples)[0]
        final_image = self.blender.blend_inpaint(final_image, image, mask, kernel=10, sigma=10.0)[0]

        # Return the final image
        return (latent_samples, final_image, lineart_output, color_output)