Bugrahan Donmez committed on
Commit
97e39c4
·
1 Parent(s): 8b79fdf

Add canny detector

Browse files
Files changed (2) hide show
  1. annotator/canny/__init__.py +6 -0
  2. app.py +32 -5
annotator/canny/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import cv2
2
+
3
+
4
+ class CannyDetector:
5
+ def __call__(self, img, low_threshold=100, high_threshold=200, safe=False, threshold=200):
6
+ return cv2.Canny(img, low_threshold, high_threshold)
app.py CHANGED
@@ -8,6 +8,8 @@ import cv2
8
  from annotator.util import resize_image
9
  from annotator.hed import SOFT_HEDdetector
10
  from annotator.lineart import LineartDetector
 
 
11
  from diffusers import UNet2DConditionModel, ControlNetModel
12
  from transformers import CLIPVisionModelWithProjection
13
  from huggingface_hub import snapshot_download
@@ -18,6 +20,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
 
19
  contour_detector = SOFT_HEDdetector()
20
  lineart_detector = LineartDetector()
 
21
 
22
  base_model_path = "runwayml/stable-diffusion-v1-5"
23
  transformer_block_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
@@ -48,6 +51,16 @@ contour_content_fusion_encoder = ControlNetModel.from_unet(contour_unet)
48
  contour_pipe = StyleContentStableDiffusionControlNetPipeline.from_pretrained(base_model_path, controlnet=contour_content_fusion_encoder)
49
  contour_styleshot = StyleShot(device, contour_pipe, contour_ip_ckpt, contour_style_aware_encoder_path, contour_transformer_block_path)
50
 
 
 
 
 
 
 
 
 
 
 
51
  lineart_ip_ckpt = os.path.join(styleshot_lineart_model_path, "pretrained_weight/ip.bin")
52
  lineart_style_aware_encoder_path = os.path.join(styleshot_lineart_model_path, "pretrained_weight/style_aware_encoder.bin")
53
  lineart_transformer_block_path = transformer_block_path
@@ -66,11 +79,14 @@ def process(style_image, content_image, prompt, num_samples, image_resolution, c
66
  btns = []
67
  contour_content_images = []
68
  contour_results = []
 
 
69
  lineart_content_images = []
70
  lineart_results = []
71
 
72
  type1 = 'Contour'
73
  type2 = 'Lineart'
 
74
 
75
  if btn1 == type1 or content_image is None:
76
  style_shots = [contour_styleshot]
@@ -78,9 +94,12 @@ def process(style_image, content_image, prompt, num_samples, image_resolution, c
78
  elif btn1 == type2:
79
  style_shots = [lineart_styleshot]
80
  btns = [type2]
 
 
 
81
  elif btn1 == "Both":
82
- style_shots = [contour_styleshot, lineart_styleshot]
83
- btns = [type1, type2]
84
 
85
  ori_style_image = style_image.copy()
86
 
@@ -103,6 +122,9 @@ def process(style_image, content_image, prompt, num_samples, image_resolution, c
103
  elif btn == type2:
104
  content_image = resize_image(ori_content_image, image_resolution)
105
  content_image = lineart_detector(content_image, coarse=False)
 
 
 
106
 
107
  content_image = Image.fromarray(content_image)
108
  else:
@@ -127,12 +149,17 @@ def process(style_image, content_image, prompt, num_samples, image_resolution, c
127
  elif btn == type2:
128
  lineart_content_images = [content_image]
129
  lineart_results = g_images[0]
 
 
 
130
  if ori_content_image is None:
131
  contour_content_images = []
132
  lineart_results = []
133
  lineart_content_images = []
 
 
134
 
135
- return [contour_results, contour_content_images, lineart_results, lineart_content_images]
136
 
137
 
138
  block = gr.Blocks().queue()
@@ -147,10 +174,10 @@ with block:
147
  with gr.Column():
148
  content_image = gr.Image(sources=['upload'], type="numpy", label='Content Image (optional)')
149
  btn1 = gr.Radio(
150
- choices=["Contour", "Lineart", "Both"],
151
  interactive=True,
152
  label="Preprocessor",
153
- value="Both",
154
  )
155
  gr.Markdown("We recommend using 'Contour' for sparse control and 'Lineart' for detailed control. If you choose 'Both', we will provide results for two types of control. If you choose 'Contour', you can adjust the 'Contour Threshold' under the 'Advanced options' for the level of detail in control. ")
156
  with gr.Row():
 
8
  from annotator.util import resize_image
9
  from annotator.hed import SOFT_HEDdetector
10
  from annotator.lineart import LineartDetector
11
+ from annotator.lineart import LineartDetector
12
+ from annotator.canny import CannyDetector
13
  from diffusers import UNet2DConditionModel, ControlNetModel
14
  from transformers import CLIPVisionModelWithProjection
15
  from huggingface_hub import snapshot_download
 
20
 
21
  contour_detector = SOFT_HEDdetector()
22
  lineart_detector = LineartDetector()
23
+ canny_detector = CannyDetector()
24
 
25
  base_model_path = "runwayml/stable-diffusion-v1-5"
26
  transformer_block_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
 
51
  contour_pipe = StyleContentStableDiffusionControlNetPipeline.from_pretrained(base_model_path, controlnet=contour_content_fusion_encoder)
52
  contour_styleshot = StyleShot(device, contour_pipe, contour_ip_ckpt, contour_style_aware_encoder_path, contour_transformer_block_path)
53
 
54
+ # weights for ip-adapter and our content-fusion encoder
55
+ canny_ip_ckpt = os.path.join(styleshot_model_path, "pretrained_weight/ip.bin")
56
+ canny_style_aware_encoder_path = os.path.join(styleshot_model_path, "pretrained_weight/style_aware_encoder.bin")
57
+ canny_transformer_block_path = transformer_block_path
58
+ canny_unet = UNet2DConditionModel.from_pretrained(base_model_path, subfolder="unet")
59
+ canny_content_fusion_encoder = ControlNetModel.from_unet(canny_unet)
60
+
61
+ canny_pipe = StyleContentStableDiffusionControlNetPipeline.from_pretrained(base_model_path, controlnet=canny_content_fusion_encoder)
62
+ canny_styleshot = StyleShot(device, canny_pipe, canny_ip_ckpt, canny_style_aware_encoder_path, canny_transformer_block_path)
63
+
64
  lineart_ip_ckpt = os.path.join(styleshot_lineart_model_path, "pretrained_weight/ip.bin")
65
  lineart_style_aware_encoder_path = os.path.join(styleshot_lineart_model_path, "pretrained_weight/style_aware_encoder.bin")
66
  lineart_transformer_block_path = transformer_block_path
 
79
  btns = []
80
  contour_content_images = []
81
  contour_results = []
82
+ canny_content_images = []
83
+ canny_results = []
84
  lineart_content_images = []
85
  lineart_results = []
86
 
87
  type1 = 'Contour'
88
  type2 = 'Lineart'
89
+ type3 = 'Canny'
90
 
91
  if btn1 == type1 or content_image is None:
92
  style_shots = [contour_styleshot]
 
94
  elif btn1 == type2:
95
  style_shots = [lineart_styleshot]
96
  btns = [type2]
97
+ elif btn1 == type3:
98
+ style_shots = [canny_styleshot]
99
+ btns = [type3]
100
  elif btn1 == "Both":
101
+ style_shots = [contour_styleshot, lineart_styleshot, canny_styleshot]
102
+ btns = [type1, type2, type3]
103
 
104
  ori_style_image = style_image.copy()
105
 
 
122
  elif btn == type2:
123
  content_image = resize_image(ori_content_image, image_resolution)
124
  content_image = lineart_detector(content_image, coarse=False)
125
+ elif btn == type3:
126
+ content_image = resize_image(ori_content_image, image_resolution)
127
+ content_image = canny_detector(content_image)
128
 
129
  content_image = Image.fromarray(content_image)
130
  else:
 
149
  elif btn == type2:
150
  lineart_content_images = [content_image]
151
  lineart_results = g_images[0]
152
+ elif btn == type3:
153
+ canny_content_images = [content_image]
154
+ canny_results = g_images[0]
155
  if ori_content_image is None:
156
  contour_content_images = []
157
  lineart_results = []
158
  lineart_content_images = []
159
+ canny_results = []
160
+ canny_content_images = []
161
 
162
+ return [contour_results, contour_content_images, lineart_results, lineart_content_images, canny_results, canny_content_images]
163
 
164
 
165
  block = gr.Blocks().queue()
 
174
  with gr.Column():
175
  content_image = gr.Image(sources=['upload'], type="numpy", label='Content Image (optional)')
176
  btn1 = gr.Radio(
177
+ choices=["Contour", "Lineart", "Canny", "All"],
178
  interactive=True,
179
  label="Preprocessor",
180
+ value="All",
181
  )
182
  gr.Markdown("We recommend using 'Contour' for sparse control and 'Lineart' for detailed control. If you choose 'Both', we will provide results for two types of control. If you choose 'Contour', you can adjust the 'Contour Threshold' under the 'Advanced options' for the level of detail in control. ")
183
  with gr.Row():