# DrawNGuess / MagicQuill / scribble_color_edit.py
# LIU, Zichen
# Initial Commit
# 0e84795
import torch.nn.functional as F
import torch
import numpy as np
from PIL import Image
import os
import sys
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from .brushnet_nodes import BrushNetLoader, BrushNet, BlendInpaint, get_files_with_extension
from .comfyui_utils import CheckpointLoaderSimple, ControlNetLoader, ControlNetApplyAdvanced, CLIPTextEncode, KSampler, VAEDecode, GrowMask, PIDINET_Preprocessor, LineArt_Preprocessor, Color_Preprocessor
class ScribbleColorEditModel:
    """Scribble- and color-guided inpainting pipeline.

    The class wires together the node wrappers imported by this module
    (checkpoint loader, CLIP text encoder, mask grower, edge/color
    preprocessors, ControlNet loader/applier, BrushNet, sampler, VAE decoder
    and blender) and keeps the loaded weights on the instance so that
    repeated calls to :meth:`process` reuse them.
    """

    def __init__(self):
        """Create every node wrapper and load the default SD1.5 weights."""
        self.checkpoint_loader = CheckpointLoaderSimple()
        self.clip_text_encoder = CLIPTextEncode()
        self.mask_processor = GrowMask()
        self.controlnet_loader = ControlNetLoader()
        self.scribble_processor = PIDINET_Preprocessor()
        self.lineart_processor = LineArt_Preprocessor()
        self.color_processor = Color_Preprocessor()
        self.brushnet_loader = BrushNetLoader()
        self.brushnet_node = BrushNet()
        self.controlnet_apply = ControlNetApplyAdvanced()
        self.ksampler = KSampler()
        self.vae_decoder = VAEDecode()
        self.blender = BlendInpaint()

        # Name of the checkpoint currently held in self.model/clip/vae.
        self.ckpt_name = "SD1.5/realisticVisionV60B1_v51VAE.safetensors"
        with torch.no_grad():
            self.model, self.clip, self.vae = self.checkpoint_loader.load_checkpoint(self.ckpt_name)
        self.load_models('SD1.5', 'float16')

    def load_models(self, base_model_version="SD1.5", dtype='float16'):
        """Load the edge ControlNet, the color ControlNet and BrushNet.

        Args:
            base_model_version: Base model family; only ``"SD1.5"`` is
                supported.
            dtype: Forwarded unchanged to the BrushNet loader.

        Raises:
            ValueError: If ``base_model_version`` is not ``"SD1.5"``.
        """
        if base_model_version != "SD1.5":
            # SDXL weights are not wired in yet; the intended file names are:
            #   edge:     "controlnet-scribble-sdxl-1.0.safetensors"
            #   color:    "colorGridControlnet_v10.safetensors"
            #   brushnet: "brushnet_xl/random_mask_brushnet_ckpt_sdxl_v0/diffusion_pytorch_model.safetensors"
            raise ValueError("Invalid base_model_version, not supported yet!!!: {}".format(base_model_version))

        edge_controlnet_name = "control_v11p_sd15_scribble.safetensors"
        color_controlnet_name = "color_finetune.safetensors"
        brushnet_name = "brushnet/random_mask_brushnet_ckpt/diffusion_pytorch_model.safetensors"

        self.edge_controlnet = self.controlnet_loader.load_controlnet(edge_controlnet_name)[0]
        self.color_controlnet = self.controlnet_loader.load_controlnet(color_controlnet_name)[0]

        # Look the inpaint files up once and reuse the result for both the
        # loader and the diagnostic print.
        inpaint_files = get_files_with_extension('inpaint')
        self.brushnet_loader.inpaint_files = inpaint_files
        print("self.brushnet_loader.inpaint_files: ", inpaint_files)
        self.brushnet = self.brushnet_loader.brushnet_loading(brushnet_name, dtype)[0]

    @staticmethod
    def _resize_mask(mask, hint):
        """Nearest-neighbour resize of ``mask`` to ``hint.shape[1:3]``.

        ``mask`` must be 3-D: two leading axes are added so that the
        3-element target size is valid for ``F.interpolate``, and they are
        squeezed away again afterwards.
        """
        target_size = (1, hint.shape[1], hint.shape[2])
        expanded = mask.unsqueeze(0).unsqueeze(0).float()
        resized = F.interpolate(expanded, size=target_size, mode='nearest')
        return resized.squeeze(0).squeeze(0)

    def _apply_strokes_to_hint(self, hint, add_mask, remove_mask, stroke_as_edge):
        """Edit the edge hint in place according to the user's stroke masks."""
        add_region = self._resize_mask(add_mask, hint) > 0.5
        remove_region = self._resize_mask(remove_mask, hint) > 0.5
        if stroke_as_edge == "enable":
            # Paint the remove_mask region black first ...
            hint[remove_region] = 0.0
            # ... then the add_mask region white, so added strokes win on overlap.
            hint[add_region] = 1.0
        else:
            # Strokes are not edges: only erase, and never under an add stroke.
            hint[remove_region & ~add_region] = 0.0

    def process(
        self,
        ckpt_name,
        image,
        colored_image,
        positive_prompt,
        negative_prompt,
        mask,
        add_mask,
        remove_mask,
        grow_size,
        stroke_as_edge,
        fine_edge,
        edge_strength,
        color_strength,
        inpaint_strength,
        seed,
        steps,
        cfg,
        sampler_name,
        scheduler,
        base_model_version='SD1.5',
        dtype='float16',
        palette_resolution=2048,
    ):
        """Run one guided inpainting pass and return its outputs.

        Args:
            ckpt_name: Checkpoint to use; it is (re)loaded only when it
                differs from the one currently held by the instance.
            image: Source image tensor, given to the preprocessors, BrushNet
                and the final blend.
            colored_image: Image carrying the user's color edits. When it is
                not ``torch.equal`` to ``image`` the color ControlNet branch
                is taken.
            positive_prompt: Text encoded as the positive conditioning.
            negative_prompt: Text encoded as the negative conditioning.
            mask: Inpaint mask; grown by ``grow_size`` before use.
            add_mask: 3-D stroke mask (presumably ``(1, H, W)`` - confirm
                with the caller), thresholded at 0.5.
            remove_mask: Same layout as ``add_mask``, for erased edges.
            grow_size: Expansion amount passed to the mask grower.
            stroke_as_edge: ``"enable"`` turns add strokes into white edges
                and remove strokes into black; any other value only erases
                where no add stroke is present.
            fine_edge: ``"enable"`` selects the lineart preprocessor, any
                other value the scribble preprocessor (edge-only branch).
            edge_strength: Edge ControlNet strength in the edge-only branch.
            color_strength: Color ControlNet strength in the color branch.
            inpaint_strength: Passed to BrushNet as ``scale``.
            seed: Forwarded to the sampler.
            steps: Forwarded to the sampler.
            cfg: Forwarded to the sampler.
            sampler_name: Forwarded to the sampler.
            scheduler: Forwarded to the sampler.
            base_model_version: Used only if the guidance models still have
                to be loaded.
            dtype: Used only if the guidance models still have to be loaded.
            palette_resolution: Resolution given to the color preprocessor.

        Returns:
            Tuple ``(latent_samples, final_image, lineart_output,
            color_output)``.
        """
        if ckpt_name != self.ckpt_name:
            with torch.no_grad():
                loaded = self.checkpoint_loader.load_checkpoint(ckpt_name)
            self.model, self.clip, self.vae = loaded
            # Record the name only after a successful load; otherwise a
            # failed load would make the next call skip the reload and
            # silently keep using the previous weights.
            self.ckpt_name = ckpt_name

        # Load the ControlNet & BrushNet models for the base model version
        # if any of them is missing from the instance.
        required = ('edge_controlnet', 'color_controlnet', 'brushnet')
        if not all(hasattr(self, attr) for attr in required):
            self.load_models(base_model_version, dtype)

        positive = self.clip_text_encoder.encode(self.clip, positive_prompt)[0]
        negative = self.clip_text_encoder.encode(self.clip, negative_prompt)[0]

        # Grow mask for color editing.
        mask = self.mask_processor.expand_mask(mask, expand=grow_size, tapered_corners=True)[0]

        # NOTE(review): this method used to build a stroke-painted clone of
        # `image` here that no later step consumed; that dead work was
        # removed. If strokes are meant to influence edge extraction, such a
        # copy has to be passed to the preprocessors explicitly.

        if not torch.equal(image, colored_image):
            print("Apply color controlnet")
            color_output = self.color_processor.execute(colored_image, resolution=palette_resolution)[0]
            lineart_output = self.lineart_processor.execute(image, resolution=512, coarse=False)[0]
            positive, negative = self.controlnet_apply.apply_controlnet(
                positive, negative, self.color_controlnet, color_output, color_strength, 0.0, 1.0
            )
            # The edge ControlNet uses a fixed strength of 0.8 in this branch.
            positive, negative = self.controlnet_apply.apply_controlnet(
                positive, negative, self.edge_controlnet, lineart_output, 0.8, 0.0, 1.0
            )
        else:
            print("Apply edge controlnet")
            color_output = self.color_processor.execute(image, resolution=palette_resolution)[0]
            if fine_edge == "enable":
                lineart_output = self.lineart_processor.execute(image, resolution=512, coarse=False)[0]
            else:
                lineart_output = self.scribble_processor.execute(image, resolution=512)[0]

            # Resize the stroke masks to the edge hint and apply them.
            self._apply_strokes_to_hint(lineart_output, add_mask, remove_mask, stroke_as_edge)
            positive, negative = self.controlnet_apply.apply_controlnet(
                positive, negative, self.edge_controlnet, lineart_output, edge_strength, 0.0, 1.0
            )

        # BrushNet
        model, positive, negative, latent = self.brushnet_node.model_update(
            model=self.model,
            vae=self.vae,  # the VAE belonging to the loaded checkpoint
            image=image,
            mask=mask,
            brushnet=self.brushnet,
            positive=positive,
            negative=negative,
            scale=inpaint_strength,
            start_at=0,
            end_at=10000,
        )

        # KSampler node
        latent_samples = self.ksampler.sample(
            model=model,
            seed=seed,
            steps=steps,
            cfg=cfg,
            sampler_name=sampler_name,
            scheduler=scheduler,
            positive=positive,
            negative=negative,
            latent_image=latent,
        )[0]

        final_image = self.vae_decoder.decode(self.vae, latent_samples)[0]
        final_image = self.blender.blend_inpaint(final_image, image, mask, kernel=10, sigma=10.0)[0]

        return (latent_samples, final_image, lineart_output, color_output)