Spaces:
Runtime error
Runtime error
LIU, Zichen
commited on
Commit
·
ca78dbf
1
Parent(s):
79ecf3f
update
Browse files
MagicQuill/magic_utils.py
CHANGED
@@ -110,14 +110,12 @@ def draw_contour(img, mask):
|
|
110 |
img_np = img_np.astype(np.uint8)
|
111 |
img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
|
112 |
|
113 |
-
# 膨胀掩码
|
114 |
kernel = np.ones((5, 5), np.uint8)
|
115 |
mask_dilated = cv2.dilate(mask_np, kernel, iterations=3)
|
116 |
contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
117 |
for contour in contours:
|
118 |
-
cv2.drawContours(img_bgr, [contour], -1, (0, 0, 255), thickness=10)
|
119 |
img_np = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
|
120 |
-
# 转换回tensor
|
121 |
transform = transforms.ToTensor()
|
122 |
img_tensor = transform(img_np)
|
123 |
|
@@ -128,7 +126,6 @@ def draw_contour(img, mask):
|
|
128 |
def get_colored_contour(img1, img2, threshold=10):
|
129 |
diff = torch.abs(img1 - img2).float()
|
130 |
diff_gray = torch.mean(diff, dim=-1)
|
131 |
-
# 阈值处理以生成二进制掩码
|
132 |
mask = diff_gray > threshold
|
133 |
|
134 |
return draw_contour(img2, mask), mask
|
@@ -153,9 +150,7 @@ def rgb_to_name(rgb_tuple):
|
|
153 |
def find_different_colors(img1, img2, threshold=10):
|
154 |
img1 = img1.to(torch.uint8)
|
155 |
img2 = img2.to(torch.uint8)
|
156 |
-
# 计算图像之间的绝对差异
|
157 |
diff = torch.abs(img1 - img2).float().mean(dim=-1)
|
158 |
-
# 找到大于阈值的差异区域
|
159 |
diff_mask = diff > threshold
|
160 |
diff_indices = torch.nonzero(diff_mask, as_tuple=True)
|
161 |
|
@@ -165,14 +160,10 @@ def find_different_colors(img1, img2, threshold=10):
|
|
165 |
else:
|
166 |
sampled_diff_indices = diff_indices
|
167 |
|
168 |
-
# 提取不同区域的颜色
|
169 |
diff_colors = img2[sampled_diff_indices[0], sampled_diff_indices[1], :]
|
170 |
-
# 将颜色值转换为颜色名称
|
171 |
color_names = [rgb_to_name(tuple(color)) for color in diff_colors]
|
172 |
name_counter = Counter(color_names)
|
173 |
-
# 过滤出现超过10次的颜色
|
174 |
filtered_colors = {name: count for name, count in name_counter.items() if count > 10}
|
175 |
-
# 按出现次数从大到小排序
|
176 |
sorted_color_names = [name for name, count in sorted(filtered_colors.items(), key=lambda item: item[1], reverse=True)]
|
177 |
if len(sorted_color_names) >= 3:
|
178 |
return "colorful"
|
@@ -183,19 +174,15 @@ def get_bounding_box_from_mask(mask, padded=False):
|
|
183 |
# Ensure the mask is a binary mask (0s and 1s)
|
184 |
mask = mask.squeeze()
|
185 |
rows, cols = torch.where(mask > 0.5)
|
186 |
-
# If there are no '1's in the mask, return None or an appropriate bounding box like (0,0,0,0)
|
187 |
if len(rows) == 0 or len(cols) == 0:
|
188 |
return (0, 0, 0, 0)
|
189 |
height, width = mask.shape
|
190 |
if padded:
|
191 |
padded_size = max(width, height)
|
192 |
-
# 检查填充发生在哪个方向
|
193 |
if width < height:
|
194 |
-
# 宽度较小,填充发生在宽度上
|
195 |
offset_x = (padded_size - width) / 2
|
196 |
offset_y = 0
|
197 |
else:
|
198 |
-
# 高度较小,填充发生在高度上
|
199 |
offset_y = (padded_size - height) / 2
|
200 |
offset_x = 0
|
201 |
# Find the bounding box coordinates
|
|
|
110 |
img_np = img_np.astype(np.uint8)
|
111 |
img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
|
112 |
|
|
|
113 |
kernel = np.ones((5, 5), np.uint8)
|
114 |
mask_dilated = cv2.dilate(mask_np, kernel, iterations=3)
|
115 |
contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
116 |
for contour in contours:
|
117 |
+
cv2.drawContours(img_bgr, [contour], -1, (0, 0, 255), thickness=10)
|
118 |
img_np = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
|
|
|
119 |
transform = transforms.ToTensor()
|
120 |
img_tensor = transform(img_np)
|
121 |
|
|
|
126 |
def get_colored_contour(img1, img2, threshold=10):
|
127 |
diff = torch.abs(img1 - img2).float()
|
128 |
diff_gray = torch.mean(diff, dim=-1)
|
|
|
129 |
mask = diff_gray > threshold
|
130 |
|
131 |
return draw_contour(img2, mask), mask
|
|
|
150 |
def find_different_colors(img1, img2, threshold=10):
|
151 |
img1 = img1.to(torch.uint8)
|
152 |
img2 = img2.to(torch.uint8)
|
|
|
153 |
diff = torch.abs(img1 - img2).float().mean(dim=-1)
|
|
|
154 |
diff_mask = diff > threshold
|
155 |
diff_indices = torch.nonzero(diff_mask, as_tuple=True)
|
156 |
|
|
|
160 |
else:
|
161 |
sampled_diff_indices = diff_indices
|
162 |
|
|
|
163 |
diff_colors = img2[sampled_diff_indices[0], sampled_diff_indices[1], :]
|
|
|
164 |
color_names = [rgb_to_name(tuple(color)) for color in diff_colors]
|
165 |
name_counter = Counter(color_names)
|
|
|
166 |
filtered_colors = {name: count for name, count in name_counter.items() if count > 10}
|
|
|
167 |
sorted_color_names = [name for name, count in sorted(filtered_colors.items(), key=lambda item: item[1], reverse=True)]
|
168 |
if len(sorted_color_names) >= 3:
|
169 |
return "colorful"
|
|
|
174 |
# Ensure the mask is a binary mask (0s and 1s)
|
175 |
mask = mask.squeeze()
|
176 |
rows, cols = torch.where(mask > 0.5)
|
|
|
177 |
if len(rows) == 0 or len(cols) == 0:
|
178 |
return (0, 0, 0, 0)
|
179 |
height, width = mask.shape
|
180 |
if padded:
|
181 |
padded_size = max(width, height)
|
|
|
182 |
if width < height:
|
|
|
183 |
offset_x = (padded_size - width) / 2
|
184 |
offset_y = 0
|
185 |
else:
|
|
|
186 |
offset_y = (padded_size - height) / 2
|
187 |
offset_x = 0
|
188 |
# Find the bounding box coordinates
|
MagicQuill/scribble_color_edit.py
CHANGED
@@ -53,7 +53,6 @@ class ScribbleColorEditModel():
|
|
53 |
self.model, self.clip, self.vae = self.checkpoint_loader.load_checkpoint(ckpt_name)
|
54 |
if not hasattr(self, 'edge_controlnet') or not hasattr(self, 'color_controlnet') or not hasattr(self, 'brushnet'):
|
55 |
self.load_models(base_model_version, dtype)
|
56 |
-
# 根据基础模型版本加载相应的 ControlNet&BrushNet 模型
|
57 |
positive = self.clip_text_encoder.encode(self.clip, positive_prompt)[0]
|
58 |
negative = self.clip_text_encoder.encode(self.clip, negative_prompt)[0]
|
59 |
# Grow Mask for Color Editing
|
@@ -90,9 +89,7 @@ class ScribbleColorEditModel():
|
|
90 |
bool_remove_mask_resized = (remove_mask_resized > 0.5)
|
91 |
|
92 |
if stroke_as_edge == "enable":
|
93 |
-
# 将remove_mask区域的像素变成黑色
|
94 |
lineart_output[bool_remove_mask_resized] = 0.0
|
95 |
-
# 将add_mask区域的像素变成白色
|
96 |
lineart_output[bool_add_mask_resized] = 1.0
|
97 |
else:
|
98 |
lineart_output[bool_remove_mask_resized & ~bool_add_mask_resized] = 0.0
|
@@ -101,7 +98,7 @@ class ScribbleColorEditModel():
|
|
101 |
# BrushNet
|
102 |
model, positive, negative, latent = self.brushnet_node.model_update(
|
103 |
model=self.model,
|
104 |
-
vae=self.vae,
|
105 |
image=image,
|
106 |
mask=mask,
|
107 |
brushnet=self.brushnet,
|
|
|
53 |
self.model, self.clip, self.vae = self.checkpoint_loader.load_checkpoint(ckpt_name)
|
54 |
if not hasattr(self, 'edge_controlnet') or not hasattr(self, 'color_controlnet') or not hasattr(self, 'brushnet'):
|
55 |
self.load_models(base_model_version, dtype)
|
|
|
56 |
positive = self.clip_text_encoder.encode(self.clip, positive_prompt)[0]
|
57 |
negative = self.clip_text_encoder.encode(self.clip, negative_prompt)[0]
|
58 |
# Grow Mask for Color Editing
|
|
|
89 |
bool_remove_mask_resized = (remove_mask_resized > 0.5)
|
90 |
|
91 |
if stroke_as_edge == "enable":
|
|
|
92 |
lineart_output[bool_remove_mask_resized] = 0.0
|
|
|
93 |
lineart_output[bool_add_mask_resized] = 1.0
|
94 |
else:
|
95 |
lineart_output[bool_remove_mask_resized & ~bool_add_mask_resized] = 0.0
|
|
|
98 |
# BrushNet
|
99 |
model, positive, negative, latent = self.brushnet_node.model_update(
|
100 |
model=self.model,
|
101 |
+
vae=self.vae,
|
102 |
image=image,
|
103 |
mask=mask,
|
104 |
brushnet=self.brushnet,
|