LIU, Zichen commited on
Commit
ca78dbf
·
1 Parent(s): 79ecf3f
MagicQuill/magic_utils.py CHANGED
@@ -110,14 +110,12 @@ def draw_contour(img, mask):
110
  img_np = img_np.astype(np.uint8)
111
  img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
112
 
113
- # 膨胀掩码
114
  kernel = np.ones((5, 5), np.uint8)
115
  mask_dilated = cv2.dilate(mask_np, kernel, iterations=3)
116
  contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
117
  for contour in contours:
118
- cv2.drawContours(img_bgr, [contour], -1, (0, 0, 255), thickness=10) # 红色线条绘制轮廓
119
  img_np = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
120
- # 转换回tensor
121
  transform = transforms.ToTensor()
122
  img_tensor = transform(img_np)
123
 
@@ -128,7 +126,6 @@ def draw_contour(img, mask):
128
  def get_colored_contour(img1, img2, threshold=10):
129
  diff = torch.abs(img1 - img2).float()
130
  diff_gray = torch.mean(diff, dim=-1)
131
- # 阈值处理以生成二进制掩码
132
  mask = diff_gray > threshold
133
 
134
  return draw_contour(img2, mask), mask
@@ -153,9 +150,7 @@ def rgb_to_name(rgb_tuple):
153
  def find_different_colors(img1, img2, threshold=10):
154
  img1 = img1.to(torch.uint8)
155
  img2 = img2.to(torch.uint8)
156
- # 计算图像之间的绝对差异
157
  diff = torch.abs(img1 - img2).float().mean(dim=-1)
158
- # 找到大于阈值的差异区域
159
  diff_mask = diff > threshold
160
  diff_indices = torch.nonzero(diff_mask, as_tuple=True)
161
 
@@ -165,14 +160,10 @@ def find_different_colors(img1, img2, threshold=10):
165
  else:
166
  sampled_diff_indices = diff_indices
167
 
168
- # 提取不同区域的颜色
169
  diff_colors = img2[sampled_diff_indices[0], sampled_diff_indices[1], :]
170
- # 将颜色值转换为颜色名称
171
  color_names = [rgb_to_name(tuple(color)) for color in diff_colors]
172
  name_counter = Counter(color_names)
173
- # 过滤出现超过10次的颜色
174
  filtered_colors = {name: count for name, count in name_counter.items() if count > 10}
175
- # 按出现次数从大到小排序
176
  sorted_color_names = [name for name, count in sorted(filtered_colors.items(), key=lambda item: item[1], reverse=True)]
177
  if len(sorted_color_names) >= 3:
178
  return "colorful"
@@ -183,19 +174,15 @@ def get_bounding_box_from_mask(mask, padded=False):
183
  # Ensure the mask is a binary mask (0s and 1s)
184
  mask = mask.squeeze()
185
  rows, cols = torch.where(mask > 0.5)
186
- # If there are no '1's in the mask, return None or an appropriate bounding box like (0,0,0,0)
187
  if len(rows) == 0 or len(cols) == 0:
188
  return (0, 0, 0, 0)
189
  height, width = mask.shape
190
  if padded:
191
  padded_size = max(width, height)
192
- # 检查填充发生在哪个方向
193
  if width < height:
194
- # 宽度较小,填充发生在宽度上
195
  offset_x = (padded_size - width) / 2
196
  offset_y = 0
197
  else:
198
- # 高度较小,填充发生在高度上
199
  offset_y = (padded_size - height) / 2
200
  offset_x = 0
201
  # Find the bounding box coordinates
 
110
  img_np = img_np.astype(np.uint8)
111
  img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
112
 
 
113
  kernel = np.ones((5, 5), np.uint8)
114
  mask_dilated = cv2.dilate(mask_np, kernel, iterations=3)
115
  contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
116
  for contour in contours:
117
+ cv2.drawContours(img_bgr, [contour], -1, (0, 0, 255), thickness=10)
118
  img_np = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
 
119
  transform = transforms.ToTensor()
120
  img_tensor = transform(img_np)
121
 
 
126
  def get_colored_contour(img1, img2, threshold=10):
127
  diff = torch.abs(img1 - img2).float()
128
  diff_gray = torch.mean(diff, dim=-1)
 
129
  mask = diff_gray > threshold
130
 
131
  return draw_contour(img2, mask), mask
 
150
  def find_different_colors(img1, img2, threshold=10):
151
  img1 = img1.to(torch.uint8)
152
  img2 = img2.to(torch.uint8)
 
153
  diff = torch.abs(img1 - img2).float().mean(dim=-1)
 
154
  diff_mask = diff > threshold
155
  diff_indices = torch.nonzero(diff_mask, as_tuple=True)
156
 
 
160
  else:
161
  sampled_diff_indices = diff_indices
162
 
 
163
  diff_colors = img2[sampled_diff_indices[0], sampled_diff_indices[1], :]
 
164
  color_names = [rgb_to_name(tuple(color)) for color in diff_colors]
165
  name_counter = Counter(color_names)
 
166
  filtered_colors = {name: count for name, count in name_counter.items() if count > 10}
 
167
  sorted_color_names = [name for name, count in sorted(filtered_colors.items(), key=lambda item: item[1], reverse=True)]
168
  if len(sorted_color_names) >= 3:
169
  return "colorful"
 
174
  # Ensure the mask is a binary mask (0s and 1s)
175
  mask = mask.squeeze()
176
  rows, cols = torch.where(mask > 0.5)
 
177
  if len(rows) == 0 or len(cols) == 0:
178
  return (0, 0, 0, 0)
179
  height, width = mask.shape
180
  if padded:
181
  padded_size = max(width, height)
 
182
  if width < height:
 
183
  offset_x = (padded_size - width) / 2
184
  offset_y = 0
185
  else:
 
186
  offset_y = (padded_size - height) / 2
187
  offset_x = 0
188
  # Find the bounding box coordinates
MagicQuill/scribble_color_edit.py CHANGED
@@ -53,7 +53,6 @@ class ScribbleColorEditModel():
53
  self.model, self.clip, self.vae = self.checkpoint_loader.load_checkpoint(ckpt_name)
54
  if not hasattr(self, 'edge_controlnet') or not hasattr(self, 'color_controlnet') or not hasattr(self, 'brushnet'):
55
  self.load_models(base_model_version, dtype)
56
- # 根据基础模型版本加载相应的 ControlNet&BrushNet 模型
57
  positive = self.clip_text_encoder.encode(self.clip, positive_prompt)[0]
58
  negative = self.clip_text_encoder.encode(self.clip, negative_prompt)[0]
59
  # Grow Mask for Color Editing
@@ -90,9 +89,7 @@ class ScribbleColorEditModel():
90
  bool_remove_mask_resized = (remove_mask_resized > 0.5)
91
 
92
  if stroke_as_edge == "enable":
93
- # 将remove_mask区域的像素变成黑色
94
  lineart_output[bool_remove_mask_resized] = 0.0
95
- # 将add_mask区域的像素变成白色
96
  lineart_output[bool_add_mask_resized] = 1.0
97
  else:
98
  lineart_output[bool_remove_mask_resized & ~bool_add_mask_resized] = 0.0
@@ -101,7 +98,7 @@ class ScribbleColorEditModel():
101
  # BrushNet
102
  model, positive, negative, latent = self.brushnet_node.model_update(
103
  model=self.model,
104
- vae=self.vae, # 需要根据实际情况提供 VAE 模型
105
  image=image,
106
  mask=mask,
107
  brushnet=self.brushnet,
 
53
  self.model, self.clip, self.vae = self.checkpoint_loader.load_checkpoint(ckpt_name)
54
  if not hasattr(self, 'edge_controlnet') or not hasattr(self, 'color_controlnet') or not hasattr(self, 'brushnet'):
55
  self.load_models(base_model_version, dtype)
 
56
  positive = self.clip_text_encoder.encode(self.clip, positive_prompt)[0]
57
  negative = self.clip_text_encoder.encode(self.clip, negative_prompt)[0]
58
  # Grow Mask for Color Editing
 
89
  bool_remove_mask_resized = (remove_mask_resized > 0.5)
90
 
91
  if stroke_as_edge == "enable":
 
92
  lineart_output[bool_remove_mask_resized] = 0.0
 
93
  lineart_output[bool_add_mask_resized] = 1.0
94
  else:
95
  lineart_output[bool_remove_mask_resized & ~bool_add_mask_resized] = 0.0
 
98
  # BrushNet
99
  model, positive, negative, latent = self.brushnet_node.model_update(
100
  model=self.model,
101
+ vae=self.vae,
102
  image=image,
103
  mask=mask,
104
  brushnet=self.brushnet,