shgao committed
Commit bc1676d
1 Parent(s): 0d3a58d

update tile

Files changed (5):
  1. app.py (+2 -1)
  2. sam2edit.py (+5 -8)
  3. sam2edit_beauty.py (+5 -8)
  4. sam2edit_handsome.py (+4 -7)
  5. sam2edit_lora.py (+104 -48)
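For orientation, a minimal sketch (not part of the commit itself) of how the touched entry points fit together after this change; the model paths are the example paths already used in sam2edit_beauty.py, and everything else mirrors the diffs below:

```python
# Sketch only: mirrors the call pattern shown in the per-file diffs below.
from sam2edit_lora import EditAnythingLoraModel
from sam2edit import create_demo

model = EditAnythingLoraModel(
    base_model_path='../chilloutmix_NiPrunedFp32Fix',  # example path from sam2edit_beauty.py
    lora_model_path='../40806/mix4',
    use_blip=True,
    extra_inpaint=True,
    lora_weight=0.5,  # new in this commit: LoRA blending weight forwarded to load_lora_weights()
)
# process() now ends with enable_tile (default True) and condition_model=None,
# so the Gradio demos no longer pass a condition_model dropdown.
demo = create_demo(model.process)
demo.queue().launch(server_name='0.0.0.0')
```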
app.py CHANGED
@@ -41,7 +41,8 @@ with gr.Blocks() as demo:
             lora_model_path=lora_model_path, use_blip=True, extra_inpaint=True,
             sam_generator=sam_generator,
             blip_processor=blip_processor,
-            blip_model=blip_model
+            blip_model=blip_model,
+            lora_weight=0.5,
         )
         create_demo_beauty(model.process)
     with gr.TabItem(' 👨‍🌾Handsome Edit/Generation'):
sam2edit.py CHANGED
@@ -16,7 +16,7 @@ def create_demo(process):
     with block as demo:
         with gr.Row():
             gr.Markdown(
-                "## Generate Your Beauty powered by EditAnything https://github.com/sail-sg/EditAnything ")
+                "## EditAnything https://github.com/sail-sg/EditAnything ")
         with gr.Row():
             with gr.Column():
                 source_image = gr.Image(
@@ -38,12 +38,9 @@ def create_demo(process):
                     label="Images", minimum=1, maximum=12, value=2, step=1)
                 seed = gr.Slider(label="Seed", minimum=-1,
                                  maximum=2147483647, step=1, randomize=True)
+                enable_tile = gr.Checkbox(
+                    label='Tile refinement for high resolution generation.', value=True)
                 with gr.Accordion("Advanced options", open=False):
-                    condition_model = gr.Dropdown(choices=list(config_dict.keys()),
-                                                  value=list(
-                                                      config_dict.keys())[1],
-                                                  label='Model',
-                                                  multiselect=False)
                     mask_image = gr.Image(
                         source='upload', label="(Optional) Upload a predefined mask of edit region if you do not want to write your prompt.", type="numpy", value=None)
                     image_resolution = gr.Slider(
@@ -63,8 +60,8 @@ def create_demo(process):
                 result_gallery = gr.Gallery(
                     label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
                 result_text = gr.Text(label='BLIP2+Human Prompt Text')
-        ips = [condition_model, source_image, enable_all_generate, mask_image, control_scale, enable_auto_prompt, prompt, a_prompt, n_prompt, num_samples, image_resolution,
-               detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
+        ips = [source_image, enable_all_generate, mask_image, control_scale, enable_auto_prompt, prompt, a_prompt, n_prompt, num_samples, image_resolution,
+               detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, enable_tile]
         run_button.click(fn=process, inputs=ips, outputs=[
             result_gallery, result_text])
         # with gr.Row():
sam2edit_beauty.py CHANGED
@@ -49,12 +49,9 @@ def create_demo(process):
                     label="Images", minimum=1, maximum=12, value=2, step=1)
                 seed = gr.Slider(label="Seed", minimum=-1,
                                  maximum=2147483647, step=1, randomize=True)
+                enable_tile = gr.Checkbox(
+                    label='Tile refinement for high resolution generation.', value=True)
                 with gr.Accordion("Advanced options", open=False):
-                    condition_model = gr.Dropdown(choices=list(config_dict.keys()),
-                                                  value=list(
-                                                      config_dict.keys())[0],
-                                                  label='Model',
-                                                  multiselect=False)
                     mask_image = gr.Image(
                         source='upload', label="(Optional) Upload a predefined mask of edit region if you do not want to write your prompt.", type="numpy", value=None)
                     image_resolution = gr.Slider(
@@ -74,8 +71,8 @@ def create_demo(process):
                 result_gallery = gr.Gallery(
                     label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
                 result_text = gr.Text(label='BLIP2+Human Prompt Text')
-        ips = [condition_model, source_image, enable_all_generate, mask_image, control_scale, enable_auto_prompt, prompt, a_prompt, n_prompt, num_samples, image_resolution,
-               detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
+        ips = [source_image, enable_all_generate, mask_image, control_scale, enable_auto_prompt, prompt, a_prompt, n_prompt, num_samples, image_resolution,
+               detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, enable_tile]
         run_button.click(fn=process, inputs=ips, outputs=[
             result_gallery, result_text])
         with gr.Row():
@@ -90,6 +87,6 @@ def create_demo(process):
 
 if __name__ == '__main__':
     model = EditAnythingLoraModel(base_model_path='../chilloutmix_NiPrunedFp32Fix',
-                                  lora_model_path='../40806/mix4', use_blip=True)
+                                  lora_model_path='../40806/mix4', use_blip=True, lora_weight=0.5)
     demo = create_demo(model.process)
     demo.queue().launch(server_name='0.0.0.0')
sam2edit_handsome.py CHANGED
@@ -43,12 +43,9 @@ def create_demo(process):
                     label="Images", minimum=1, maximum=12, value=2, step=1)
                 seed = gr.Slider(label="Seed", minimum=-1,
                                  maximum=2147483647, step=1, randomize=True)
+                enable_tile = gr.Checkbox(
+                    label='Tile refinement for high resolution generation.', value=True)
                 with gr.Accordion("Advanced options", open=False):
-                    condition_model = gr.Dropdown(choices=list(config_dict.keys()),
-                                                  value=list(
-                                                      config_dict.keys())[0],
-                                                  label='Model',
-                                                  multiselect=False)
                     mask_image = gr.Image(
                         source='upload', label="(Optional) Upload a predefined mask of edit region if you do not want to write your prompt.", type="numpy", value=None)
                     image_resolution = gr.Slider(
@@ -68,8 +65,8 @@ def create_demo(process):
                 result_gallery = gr.Gallery(
                     label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
                 result_text = gr.Text(label='BLIP2+Human Prompt Text')
-        ips = [condition_model, source_image, enable_all_generate, mask_image, control_scale, enable_auto_prompt, prompt, a_prompt, n_prompt, num_samples, image_resolution,
-               detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
+        ips = [source_image, enable_all_generate, mask_image, control_scale, enable_auto_prompt, prompt, a_prompt, n_prompt, num_samples, image_resolution,
+               detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, enable_tile]
         run_button.click(fn=process, inputs=ips, outputs=[
             result_gallery, result_text])
         with gr.Row():
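The three UI diffs above all rely on the same Gradio contract: `run_button.click(fn=process, inputs=ips, ...)` binds the `ips` components to `process` positionally, so dropping `condition_model` from the front of `ips` and appending `enable_tile` only works because `process` in sam2edit_lora.py was reordered the same way. A tiny self-contained illustration of that pattern (not the repo's actual UI; the stand-in `process` keeps only a few arguments):

```python
# Minimal Gradio wiring demo: `inputs` are passed to `fn` positionally,
# which is why the order of `ips` must track the process() signature.
import gradio as gr

def process(prompt, seed, enable_tile=True, condition_model=None):
    # stand-in for EditAnythingLoraModel.process (the real one takes many more args)
    return f"prompt={prompt!r}, seed={seed}, tile={enable_tile}, condition_model={condition_model}"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
    enable_tile = gr.Checkbox(
        label='Tile refinement for high resolution generation.', value=True)
    result = gr.Text(label="Result")
    run_button = gr.Button("Run")
    ips = [prompt, seed, enable_tile]  # order matches process(); condition_model stays at its default
    run_button.click(fn=process, inputs=ips, outputs=[result])

if __name__ == '__main__':
    demo.queue().launch()
```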
sam2edit_lora.py CHANGED
@@ -26,6 +26,8 @@ from utils.stable_diffusion_controlnet_inpaint import StableDiffusionControlNetI
 # need the latest transformers
 # pip install git+https://github.com/huggingface/transformers.git
 from transformers import AutoProcessor, Blip2ForConditionalGeneration
+from diffusers import ControlNetModel, DiffusionPipeline
+import PIL.Image
 
 # Segment-Anything init.
 # pip install git+https://github.com/facebookresearch/segment-anything.git
@@ -110,6 +112,7 @@ def get_pipeline_embeds(pipeline, prompt, negative_prompt, device):
     return torch.cat(concat_embeds, dim=1), torch.cat(neg_embeds, dim=1)
 
 
+
 def load_lora_weights(pipeline, checkpoint_path, multiplier, device, dtype):
     LORA_PREFIX_UNET = "lora_unet"
     LORA_PREFIX_TEXT_ENCODER = "lora_te"
@@ -238,34 +241,51 @@ def make_inpaint_condition(image, image_mask):
     image = torch.from_numpy(image)
     return image
 
+def obtain_generation_model(base_model_path, lora_model_path, controlnet_path, generation_only=False, extra_inpaint=True, lora_weight=1.0):
+    controlnet = []
+    controlnet.append(ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16))  # sam control
+    if (not generation_only) and extra_inpaint:  # inpainting control
+        print("Warning: ControlNet based inpainting model only support SD1.5 for now.")
+        controlnet.append(
+            ControlNetModel.from_pretrained(
+                'lllyasviel/control_v11p_sd15_inpaint', torch_dtype=torch.float16)  # inpainting controlnet
+        )
 
-def obtain_generation_model(base_model_path, lora_model_path, controlnet_path, generation_only=False, extra_inpaint=True):
-    if generation_only and extra_inpaint:
-        controlnet = ControlNetModel.from_pretrained(
-            controlnet_path, torch_dtype=torch.float16)
+    if generation_only:
         pipe = StableDiffusionControlNetPipeline.from_pretrained(
             base_model_path, controlnet=controlnet, torch_dtype=torch.float16, safety_checker=None
         )
-    elif extra_inpaint:
-        print("Warning: ControlNet based inpainting model only support SD1.5 for now.")
-        controlnet = [
-            ControlNetModel.from_pretrained(
-                controlnet_path, torch_dtype=torch.float16),
-            ControlNetModel.from_pretrained(
-                'lllyasviel/control_v11p_sd15_inpaint', torch_dtype=torch.float16),  # inpainting controlnet
-        ]
+    else:
         pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
             base_model_path, controlnet=controlnet, torch_dtype=torch.float16, safety_checker=None
         )
+    if lora_model_path is not None:
+        pipe = load_lora_weights(
+            pipe, [lora_model_path], lora_weight, 'cpu', torch.float32)
+    # speed up diffusion process with faster scheduler and memory optimization
+    pipe.scheduler = UniPCMultistepScheduler.from_config(
+        pipe.scheduler.config)
+    # remove following line if xformers is not installed
+    pipe.enable_xformers_memory_efficient_attention()
+
+    pipe.enable_model_cpu_offload()
+    return pipe
+
+def obtain_tile_model(base_model_path, lora_model_path, lora_weight=1.0):
+    controlnet = ControlNetModel.from_pretrained(
+        'lllyasviel/control_v11f1e_sd15_tile', torch_dtype=torch.float16)  # tile controlnet
+    if base_model_path=='runwayml/stable-diffusion-v1-5' or base_model_path=='stabilityai/stable-diffusion-2-inpainting':
+        print("base_model_path", base_model_path)
+        pipe = StableDiffusionControlNetPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16, safety_checker=None
+        )
     else:
-        controlnet = ControlNetModel.from_pretrained(
-            controlnet_path, torch_dtype=torch.float16)
-        pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
-            base_model_path, controlnet=controlnet, torch_dtype=torch.float16, safety_checker=None
+        pipe = StableDiffusionControlNetPipeline.from_pretrained(
+            base_model_path, controlnet=controlnet, torch_dtype=torch.float16, safety_checker=None
         )
     if lora_model_path is not None:
         pipe = load_lora_weights(
-            pipe, [lora_model_path], 1.0, 'cpu', torch.float32)
+            pipe, [lora_model_path], lora_weight, 'cpu', torch.float32)
     # speed up diffusion process with faster scheduler and memory optimization
     pipe.scheduler = UniPCMultistepScheduler.from_config(
         pipe.scheduler.config)
@@ -276,6 +296,7 @@ def obtain_generation_model(base_model_path, lora_model_path, controlnet_path, g
     return pipe
 
 
+
 def show_anns(anns):
     if len(anns) == 0:
         return
@@ -310,8 +331,9 @@ class EditAnythingLoraModel:
                  blip_model=None,
                  sam_generator=None,
                  controlmodel_name='LAION Pretrained(v0-4)-SD15',
-                 # used when the base model is not an inpainting model.
-                 extra_inpaint=True,
+                 extra_inpaint=True,  # used when the base model is not an inpainting model.
+                 tile_model=None,
+                 lora_weight=1.0,
                  ):
         self.device = device
         self.use_blip = use_blip
@@ -323,7 +345,7 @@
         self.defalut_enable_all_generate = False
         self.extra_inpaint = extra_inpaint
         self.pipe = obtain_generation_model(
-            base_model_path, lora_model_path, self.default_controlnet_path, generation_only=False, extra_inpaint=extra_inpaint)
+            base_model_path, lora_model_path, self.default_controlnet_path, generation_only=False, extra_inpaint=extra_inpaint, lora_weight=lora_weight)
 
         # Segment-Anything init.
         if sam_generator is not None:
@@ -343,6 +365,12 @@
         else:
             self.blip_model = init_blip_model()
 
+        # tile model init.
+        if tile_model is not None:
+            self.tile_pipe = tile_model
+        else:
+            self.tile_pipe = obtain_tile_model(base_model_path, lora_model_path, lora_weight=lora_weight)
+
     def get_blip2_text(self, image):
         inputs = self.blip_processor(image, return_tensors="pt").to(
             self.device, torch.float16)
@@ -357,13 +385,23 @@
         return full_img, res
 
     @torch.inference_mode()
-    def process(self, condition_model, source_image, enable_all_generate, mask_image, control_scale, enable_auto_prompt, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
-
+    def process(self, source_image, enable_all_generate, mask_image,
+                control_scale,
+                enable_auto_prompt, prompt, a_prompt, n_prompt,
+                num_samples, image_resolution, detect_resolution,
+                ddim_steps, guess_mode, strength, scale, seed, eta,
+                enable_tile=True, condition_model=None):
+
+        if condition_model is None:
+            this_controlnet_path = self.default_controlnet_path
+        else:
+            this_controlnet_path = config_dict[condition_model]
         input_image = source_image["image"]
         if mask_image is None:
            if enable_all_generate != self.defalut_enable_all_generate:
                 self.pipe = obtain_generation_model(
-                    self.base_model_path, self.lora_model_path, config_dict[condition_model], enable_all_generate, self.extra_inpaint)
+                    self.base_model_path, self.lora_model_path, this_controlnet_path, enable_all_generate, self.extra_inpaint)
+
                 self.defalut_enable_all_generate = enable_all_generate
            if enable_all_generate:
                print("source_image",
@@ -372,13 +410,13 @@
                     (input_image.shape[0], input_image.shape[1], 3))*255
        else:
            mask_image = source_image["mask"]
-        if self.default_controlnet_path != config_dict[condition_model]:
-            print("To Use:", config_dict[condition_model],
+        if self.default_controlnet_path != this_controlnet_path:
+            print("To Use:", this_controlnet_path,
                  "Current:", self.default_controlnet_path)
-            print("Change condition model to:", config_dict[condition_model])
+            print("Change condition model to:", this_controlnet_path)
            self.pipe = obtain_generation_model(
-                self.base_model_path, self.lora_model_path, config_dict[condition_model], enable_all_generate, self.extra_inpaint)
-            self.default_controlnet_path = config_dict[condition_model]
+                self.base_model_path, self.lora_model_path, this_controlnet_path, enable_all_generate, self.extra_inpaint)
+            self.default_controlnet_path = this_controlnet_path
            torch.cuda.empty_cache()
 
        with torch.no_grad():
@@ -411,11 +449,9 @@
            control = einops.rearrange(control, 'b h w c -> b c h w').clone()
 
            mask_image = HWC3(mask_image.astype(np.uint8))
-            mask_image = cv2.resize(
+            mask_image_tmp = cv2.resize(
                mask_image, (W, H), interpolation=cv2.INTER_LINEAR)
-            if self.extra_inpaint:
-                inpaint_image = make_inpaint_condition(img, mask_image)
-            mask_image = Image.fromarray(mask_image)
+            mask_image = Image.fromarray(mask_image_tmp)
 
            if seed == -1:
                seed = random.randint(0, 65535)
@@ -429,7 +465,6 @@
            negative_prompt_embeds = torch.cat(
                [negative_prompt_embeds] * num_samples, dim=0)
            if enable_all_generate and self.extra_inpaint:
-                print(control.shape, control_scale)
                self.pipe.safety_checker = lambda images, clip_input: (
                    images, False)
                x_samples = self.pipe(
@@ -439,10 +474,19 @@
                    generator=generator,
                    height=H,
                    width=W,
-                    image=control.type(torch.float16),
-                    controlnet_conditioning_scale=float(control_scale),
+                    image=[control.type(torch.float16)],
+                    controlnet_conditioning_scale=[float(control_scale)],
                ).images
-            elif self.extra_inpaint:
+            else:
+                multi_condition_image = []
+                multi_condition_scale = []
+                multi_condition_image.append(control.type(torch.float16))
+                multi_condition_scale.append(float(control_scale))
+                if self.extra_inpaint:
+                    inpaint_image = make_inpaint_condition(img, mask_image_tmp)
+                    print(inpaint_image.shape)
+                    multi_condition_image.append(inpaint_image.type(torch.float16))
+                    multi_condition_scale.append(1.0)
                x_samples = self.pipe(
                    image=img,
                    mask_image=mask_image,
@@ -450,27 +494,39 @@
                    num_images_per_prompt=num_samples,
                    num_inference_steps=ddim_steps,
                    generator=generator,
-                    controlnet_conditioning_image=[control.type(
-                        torch.float16), inpaint_image.type(torch.float16)],
+                    controlnet_conditioning_image=multi_condition_image,
                    height=H,
                    width=W,
-                    controlnet_conditioning_scale=(float(control_scale), 1.0),
+                    controlnet_conditioning_scale=multi_condition_scale,
                ).images
-            else:
-                x_samples = self.pipe(
-                    image=img,
-                    mask_image=mask_image,
+            results = [x_samples[i] for i in range(num_samples)]
+
+            if True:
+                img_tile = [PIL.Image.fromarray(resize_image(np.array(x_samples[i]), 1024)) for i in range(num_samples)]
+                # for each in img_tile:
+                #     print("tile",each.size)
+                prompt_embeds, negative_prompt_embeds = get_pipeline_embeds(
+                    self.tile_pipe, postive_prompt, negative_prompt, "cuda")
+                prompt_embeds = torch.cat([prompt_embeds] * num_samples, dim=0)
+                negative_prompt_embeds = torch.cat(
+                    [negative_prompt_embeds] * num_samples, dim=0)
+                x_samples_tile = self.tile_pipe(
                    prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
                    num_images_per_prompt=num_samples,
                    num_inference_steps=ddim_steps,
                    generator=generator,
-                    controlnet_conditioning_image=control.type(torch.float16),
-                    height=H,
-                    width=W,
-                    controlnet_conditioning_scale=float(control_scale),
+                    height=img_tile[0].size[1],
+                    width=img_tile[0].size[0],
+                    image=img_tile,
+                    controlnet_conditioning_scale=1.0,
                ).images
 
-            results = [x_samples[i] for i in range(num_samples)]
+                results_tile = [x_samples_tile[i] for i in range(num_samples)]
+                results = results_tile + results
+
            return [full_segmask, mask_image] + results, prompt
 
 def download_image(url):
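The core of this commit is the second-stage "tile" pass in sam2edit_lora.py: after the ControlNet inpainting pipeline produces `x_samples`, each sample is upscaled (to roughly 1024 px on the long side via `resize_image`) and re-denoised by a separate pipeline built around `lllyasviel/control_v11f1e_sd15_tile`, and the refined images are prepended to the results. A standalone, hedged sketch of that two-stage idea using the public diffusers API (the helper names `build_tile_pipe` and `tile_refine` are illustrative, not from the repo):

```python
# Sketch of tile refinement: condition a second diffusion pass on an upscaled
# copy of the first-stage output using the SD1.5 tile ControlNet.
import PIL.Image
import torch
from diffusers import (ControlNetModel, StableDiffusionControlNetPipeline,
                       UniPCMultistepScheduler)

def build_tile_pipe(base_model="runwayml/stable-diffusion-v1-5"):
    tile_controlnet = ControlNetModel.from_pretrained(
        "lllyasviel/control_v11f1e_sd15_tile", torch_dtype=torch.float16)
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        base_model, controlnet=tile_controlnet,
        torch_dtype=torch.float16, safety_checker=None)
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.enable_model_cpu_offload()
    return pipe

def tile_refine(pipe, low_res: PIL.Image.Image, prompt: str, target=1024, steps=20):
    # upscale the first-stage sample, then denoise it again under tile conditioning
    w, h = low_res.size
    scale = target / max(w, h)
    new_w, new_h = int(w * scale) // 8 * 8, int(h * scale) // 8 * 8  # SD needs multiples of 8
    upscaled = low_res.resize((new_w, new_h))
    return pipe(prompt, image=upscaled,
                height=new_h, width=new_w,
                num_inference_steps=steps,
                controlnet_conditioning_scale=1.0).images[0]
```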