chaojiemao committed
Commit 06f5716 · verified · 1 Parent(s): f2838d1

Update app.py

Files changed (1): app.py (+194, -184)

app.py CHANGED
@@ -17,15 +17,10 @@ from scepter.modules.transform.io import pillow_convert
 from scepter.modules.utils.config import Config
 from scepter.modules.utils.distribute import we
 from scepter.modules.utils.file_system import FS
-
-from inference.ace_plus_diffusers import ACEPlusDiffuserInference
+from examples.examples import fft_examples
+from inference.registry import INFERENCES
 from inference.utils import edit_preprocess
-from examples.examples import all_examples
-

-inference_dict = {
-    "ACE_DIFFUSER_PLUS": ACEPlusDiffuserInference
-}

 fs_list = [
     Config(cfg_dict={"NAME": "HuggingfaceFs", "TEMP_DIR": "./cache"}, load=False),
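The hunk above replaces the hand-maintained `inference_dict` mapping with a registry lookup (`INFERENCES.build`). A minimal sketch of that pattern follows; the names and structure are hypothetical, not the repo's actual `inference/registry.py`:

```python
# Hypothetical minimal registry, illustrating the pattern behind
# `from inference.registry import INFERENCES` / `INFERENCES.build(cfg)`.
class Registry:
    def __init__(self):
        self._classes = {}

    def register(self, name):
        def decorator(cls):
            self._classes[name] = cls  # map a config NAME to its class
            return cls
        return decorator

    def build(self, cfg):
        # Instantiate whichever class the config names.
        return self._classes[cfg["NAME"]](cfg)


INFERENCES = Registry()


@INFERENCES.register("ACEPlusFFTInference")  # hypothetical class name
class ACEPlusFFTInference:
    def __init__(self, cfg):
        self.cfg = cfg


pipe = INFERENCES.build({"NAME": "ACEPlusFFTInference"})
```

Registering classes at import time keeps the app free of per-model `if`/`elif` chains: adding a new inference type only requires a new decorated class and a config entry.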
@@ -38,15 +33,10 @@ for one_fs in fs_list:
     FS.init_fs_client(one_fs)

 os.environ["FLUX_FILL_PATH"]="hf://black-forest-labs/FLUX.1-Fill-dev"
-os.environ["PORTRAIT_MODEL_PATH"]="hf://ali-vilab/ACE_Plus@portrait/comfyui_portrait_lora64.safetensors"
-os.environ["SUBJECT_MODEL_PATH"]="hf://ali-vilab/ACE_Plus@subject/comfyui_subject_lora16.safetensors"
-os.environ["LOCAL_MODEL_PATH"]="hf://ali-vilab/ACE_Plus@local_editing/comfyui_local_lora16.safetensors"
+os.environ["ACE_PLUS_FFT_MODEL"]="hf://ali-vilab/ACE_Plus@ace_plus_fft.safetensors"

 FS.get_dir_to_local_dir(os.environ["FLUX_FILL_PATH"])
-FS.get_from(os.environ["PORTRAIT_MODEL_PATH"])
-FS.get_from(os.environ["SUBJECT_MODEL_PATH"])
-FS.get_from(os.environ["LOCAL_MODEL_PATH"])
-
+FS.get_from(os.environ["ACE_PLUS_FFT_MODEL"])

 csv.field_size_limit(sys.maxsize)
 refresh_sty = '\U0001f504' # 🔄
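The three per-task LoRA checkpoints collapse into one FFT checkpoint here, prefetched through scepter's `FS` abstraction. Roughly the same effect, expressed directly with `huggingface_hub` (assuming access to the gated FLUX.1-Fill-dev repo; repo and file names are taken from the diff):

```python
from huggingface_hub import hf_hub_download, snapshot_download

# Whole-repo download, comparable to FS.get_dir_to_local_dir for the Fill model.
fill_dir = snapshot_download(repo_id="black-forest-labs/FLUX.1-Fill-dev",
                             cache_dir="./cache")
# Single-file download, comparable to FS.get_from for the new FFT checkpoint.
fft_ckpt = hf_hub_download(repo_id="ali-vilab/ACE_Plus",
                           filename="ace_plus_fft.safetensors",
                           cache_dir="./cache")
print(fill_dir, fft_ckpt)
```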
@@ -60,51 +50,39 @@ lock = threading.Lock()
 class DemoUI(object):
     #@spaces.GPU(duration=60)
     def __init__(self,
-                 infer_dir = "./config",
-                 model_list='./models/model_zoo.yaml'
+                 infer_dir="./config/ace_plus_fft.yaml"
                  ):
-        self.model_yamls = glob.glob(os.path.join(infer_dir,
-                                                  '*.yaml'))
+        self.model_yamls = [infer_dir]
         self.model_choices = dict()
         self.default_model_name = ''
+        self.edit_type_dict = {}
+        self.edit_type_list = []
+        self.default_type_list = []
         for i in self.model_yamls:
             model_cfg = Config(load=True, cfg_file=i)
-            model_name = model_cfg.NAME
+            model_name = model_cfg.VERSION
             if model_cfg.IS_DEFAULT: self.default_model_name = model_name
             self.model_choices[model_name] = model_cfg
+            for preprocessor in model_cfg.get("PREPROCESSOR", []):
+                if preprocessor["TYPE"] in self.edit_type_dict:
+                    continue
+                self.edit_type_dict[preprocessor["TYPE"]] = preprocessor
+                self.default_type_list.append(preprocessor["TYPE"])
         print('Models: ', self.model_choices.keys())
         assert len(self.model_choices) > 0
         if self.default_model_name == "": self.default_model_name = list(self.model_choices.keys())[0]
         self.model_name = self.default_model_name
         pipe_cfg = self.model_choices[self.default_model_name]
-        infer_name = pipe_cfg.get("INFERENCE_TYPE", "ACE")
-        self.pipe = inference_dict[infer_name]()
-        self.pipe.init_from_cfg(pipe_cfg)
-
-        # choose different model
-        self.task_model_cfg = Config(load=True, cfg_file=model_list)
-        self.task_model = {}
-        self.task_model_list = []
-        self.edit_type_dict = {"repainting": None}
-        self.edit_type_list = ["repainting"]
-        for task_name, task_model in self.task_model_cfg.MODEL.items():
-            self.task_model[task_name.lower()] = task_model
-            self.task_model_list.append(task_name.lower())
-            for preprocessor in task_model.get("PREPROCESSOR", []):
-                if preprocessor["TYPE"] in self.edit_type_dict:
-                    continue
-                preprocessor["REPAINTING_SCALE"] = task_model.get("REPAINTING_SCALE", 1.0)
-                self.edit_type_dict[preprocessor["TYPE"]] = preprocessor
-        self.max_msgs = 20
+        self.pipe = INFERENCES.build(pipe_cfg)
         # reformat examples
         self.all_examples = [
             [
-                one_example["task_type"], one_example["edit_type"], one_example["instruction"],
-                one_example["input_reference_image"], one_example["input_image"],
-                one_example["input_mask"], one_example["output_h"],
-                one_example["output_w"], one_example["seed"]
-            ]
-            for one_example in all_examples
+                one_example["edit_type"], one_example["instruction"],
+                one_example["input_reference_image"], one_example["input_image"],
+                one_example["input_mask"], one_example["output_h"],
+                one_example["output_w"], one_example["seed"]
+            ]
+            for one_example in fft_examples
         ]

     def construct_edit_image(self, edit_image, edit_mask):
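The comprehension above fixes the shape each `fft_examples` entry must have after this commit (the old `task_type` field is gone). A hypothetical entry with placeholder values, inferred from the keys being read:

```python
# Inferred shape of one `fft_examples` entry; all values are placeholders.
example = {
    "edit_type": "repainting",
    "instruction": "a red hat on the table",
    "input_reference_image": "assets/ref.png",
    "input_image": "assets/edit.png",
    "input_mask": "assets/mask.png",
    "output_h": 1024,
    "output_w": 1024,
    "seed": 42,
}

row = [example[k] for k in ("edit_type", "instruction",
                            "input_reference_image", "input_image",
                            "input_mask", "output_h", "output_w", "seed")]
print(row)
```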
@@ -127,9 +105,6 @@ class DemoUI(object):
         else:
             return None

-
-
-
     def create_ui(self):
         with gr.Row(equal_height=True, visible=True):
             with gr.Column(scale=2):
@@ -146,40 +121,102 @@ class DemoUI(object):
                         height=600,
                         interactive=False,
                         type='pil',
-                        elem_id='preprocess_image'
+                        elem_id='preprocess_image',
+                        label='edit image'
                     )

                     self.edit_preprocess_mask_preview = gr.Image(
                         height=600,
                         interactive=False,
                         type='pil',
-                        elem_id='preprocess_image_mask'
+                        elem_id='preprocess_image_mask',
+                        label='edit mask'
+                    )
+
+                    self.change_preprocess_preview = gr.Image(
+                        height=600,
+                        interactive=False,
+                        type='pil',
+                        elem_id='preprocess_change_image',
+                        label='change image'
                     )
                 with gr.Row():
                     instruction = """
                     **Instruction**:
-                    1. Please choose the Task Type based on the scenario of the generation task. We provide three types of generation capabilities: Portrait ID Preservation Generation(portrait),
-                    Object ID Preservation Generation(subject), and Local Controlled Generation(local editing), which can be selected from the task dropdown menu.
-                    2. When uploading images in the Reference Image section, the generated image will reference the ID information of that image. Please ensure that the ID information is clear.
-                    In the Edit Image section, the uploaded image will maintain its structural and content information, and you must draw a mask area to specify the region to be regenerated.
-                    3. When the task type is local editing, there are various editing types to choose from. Users can select different information preserving dimensions, such as edge information,
-                    color information, and more. The pre-processing information can be viewed in the 'related input image' tab.
-                    4. More details can be found in [page](https://ali-vilab.github.io/ACE_plus_page).
+                    Users can perform reference generation or editing tasks by uploading reference images
+                    and editing images. When uploading the editing image, various editing types are available
+                    for selection. Users can choose different dimensions of information preservation,
+                    such as edge information, color information, and more. Pre-processing information
+                    can be viewed in the 'related input image' tab.
                     """
                     self.instruction = gr.Markdown(value=instruction)
+                with gr.Row():
+                    self.icon = gr.Image(
+                        value=None,
+                        interactive=False,
+                        height=150,
+                        type='pil',
+                        elem_id='icon',
+                        label='icon'
+                    )
                 with gr.Row():
                     self.model_name_dd = gr.Dropdown(
                         choices=self.model_choices,
                         value=self.default_model_name,
                         label='Model Version')
-                    self.task_type = gr.Dropdown(choices=self.task_model_list,
+                    self.edit_type = gr.Dropdown(choices=self.default_type_list,
                                                  interactive=True,
-                                                 value=self.task_model_list[0],
-                                                 label='Task Type')
-                    self.edit_type = gr.Dropdown(choices=self.edit_type_list,
-                                                 interactive=True,
-                                                 value=self.edit_type_list[0],
+                                                 value=self.default_type_list[0],
                                                  label='Edit Type')
+                with gr.Row():
+                    self.step = gr.Slider(minimum=1,
+                                          maximum=1000,
+                                          value=self.pipe.input.get("sample_steps", 20),
+                                          visible=self.pipe.input.get("sample_steps", None) is not None,
+                                          label='Sample Step')
+                    self.cfg_scale = gr.Slider(
+                        minimum=1.0,
+                        maximum=100.0,
+                        value=self.pipe.input.get("guide_scale", 4.5),
+                        visible=self.pipe.input.get("guide_scale", None) is not None,
+                        label='Guidance Scale')
+                    self.seed = gr.Slider(minimum=-1,
+                                          maximum=1000000000000,
+                                          value=-1,
+                                          label='Seed')
+                    self.output_height = gr.Slider(
+                        minimum=256,
+                        maximum=1440,
+                        value=self.pipe.input.get("image_size", [1024, 1024])[0],
+                        visible=self.pipe.input.get("image_size", None) is not None,
+                        label='Output Height')
+                    self.output_width = gr.Slider(
+                        minimum=256,
+                        maximum=1440,
+                        value=self.pipe.input.get("image_size", [1024, 1024])[1],
+                        visible=self.pipe.input.get("image_size", None) is not None,
+                        label='Output Width')
+
+                    self.repainting_scale = gr.Slider(
+                        minimum=0.0,
+                        maximum=1.0,
+                        value=self.pipe.input.get("repainting_scale", 1.0),
+                        visible=True,
+                        label='Repainting Scale')
+                    self.use_change = gr.Checkbox(
+                        value=self.pipe.input.get("use_change", True),
+                        visible=True,
+                        label='Use Change')
+                    self.keep_pixel = gr.Checkbox(
+                        value=self.pipe.input.get("keep_pixel", True),
+                        visible=True,
+                        label='Keep Pixels')
+                    self.keep_pixels_rate = gr.Slider(
+                        minimum=0.5,
+                        maximum=1.0,
+                        value=0.8,
+                        visible=True,
+                        label='keep_pixel rate')
                 with gr.Row():
                     self.generation_info_preview = gr.Markdown(
                         label='System Log.',
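The sliders and checkboxes moved into this hunk all follow one pattern: the default comes from the pipeline's input spec, and the control is only visible when that spec defines the key, so controls a model does not declare simply disappear. A self-contained sketch of that pattern, where `input_spec` stands in for `self.pipe.input`:

```python
# Config-driven controls: hidden unless the pipeline's input spec defines them.
import gradio as gr

input_spec = {"sample_steps": 28, "guide_scale": 50.0}  # stand-in for pipe.input

with gr.Blocks() as demo:
    step = gr.Slider(minimum=1, maximum=1000,
                     value=input_spec.get("sample_steps", 20),
                     visible=input_spec.get("sample_steps", None) is not None,
                     label='Sample Step')
    cfg_scale = gr.Slider(minimum=1.0, maximum=100.0,
                          value=input_spec.get("guide_scale", 4.5),
                          visible=input_spec.get("guide_scale", None) is not None,
                          label='Guidance Scale')

# demo.launch()  # uncomment to serve the sketch
```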
@@ -192,11 +229,11 @@ class DemoUI(object):
                             placeholder='Input "@" find history of image',
                             label='Instruction',
                             container=False,
-                            lines = 1)
+                            lines=1)
                     with gr.Column(scale=2, min_width=100):
                         with gr.Row():
                             with gr.Column(scale=1, min_width=100):
-                                self.chat_btn = gr.Button(value='Generate', variant = "primary")
+                                self.chat_btn = gr.Button(value='Generate', variant="primary")

                 with gr.Accordion(label='Advance', open=True):
                     with gr.Row(visible=True):
@@ -223,45 +260,8 @@ class DemoUI(object):
                             format="png"
                         )

-                    with gr.Row():
-                        self.step = gr.Slider(minimum=1,
-                                              maximum=1000,
-                                              value=self.pipe.input.get("sample_steps", 20),
-                                              visible=self.pipe.input.get("sample_steps", None) is not None,
-                                              label='Sample Step')
-                        self.cfg_scale = gr.Slider(
-                            minimum=1.0,
-                            maximum=100.0,
-                            value=self.pipe.input.get("guide_scale", 4.5),
-                            visible=self.pipe.input.get("guide_scale", None) is not None,
-                            label='Guidance Scale')
-                        self.seed = gr.Slider(minimum=-1,
-                                              maximum=10000000,
-                                              value=-1,
-                                              label='Seed')
-                        self.output_height = gr.Slider(
-                            minimum=256,
-                            maximum=1440,
-                            value=self.pipe.input.get("output_height", 1024),
-                            visible=self.pipe.input.get("output_height", None) is not None,
-                            label='Output Height')
-                        self.output_width = gr.Slider(
-                            minimum=256,
-                            maximum=1440,
-                            value=self.pipe.input.get("output_width", 1024),
-                            visible=self.pipe.input.get("output_width", None) is not None,
-                            label='Output Width')
-
-                        self.repainting_scale = gr.Slider(
-                            minimum=0.0,
-                            maximum=1.0,
-                            value=self.pipe.input.get("repainting_scale", 1.0),
-                            visible=True,
-                            label='Repainting Scale')
-                    with gr.Row():
-                        self.eg = gr.Column(visible=True)
-
-
+                    with gr.Row():
+                        self.eg = gr.Column(visible=True)

     def set_callbacks(self, *args, **kwargs):
         ########################################
@@ -276,25 +276,23 @@ class DemoUI(object):
             torch.cuda.empty_cache()
             torch.cuda.ipc_collect()
             pipe_cfg = self.model_choices[model_name]
-            infer_name = pipe_cfg.get("INFERENCE_TYPE", "ACE")
-            self.pipe = inference_dict[infer_name]()
-            self.pipe.init_from_cfg(pipe_cfg)
+            self.pipe = INFERENCES.build(pipe_cfg)
             self.model_name = model_name
             lock.release()

             return (model_name, gr.update(),
                     gr.Slider(
-                        value=self.pipe.input.get("sample_steps", 20),
-                        visible=self.pipe.input.get("sample_steps", None) is not None),
+                        value=self.pipe.input.get("sample_steps", 20),
+                        visible=self.pipe.input.get("sample_steps", None) is not None),
                     gr.Slider(
                         value=self.pipe.input.get("guide_scale", 4.5),
                         visible=self.pipe.input.get("guide_scale", None) is not None),
                     gr.Slider(
-                        value=self.pipe.input.get("output_height", 1024),
-                        visible=self.pipe.input.get("output_height", None) is not None),
+                        value=self.pipe.input.get("image_size", [1024, 1024])[0],
+                        visible=self.pipe.input.get("image_size", None) is not None),
                     gr.Slider(
-                        value=self.pipe.input.get("output_width", 1024),
-                        visible=self.pipe.input.get("output_width", None) is not None),
+                        value=self.pipe.input.get("image_size", [1024, 1024])[1],
+                        visible=self.pipe.input.get("image_size", None) is not None),
                     gr.Slider(value=self.pipe.input.get("repainting_scale", 1.0))
                     )

@@ -309,31 +307,21 @@ class DemoUI(object):
                                            self.output_width,
                                            self.repainting_scale])

-        def change_task_type(task_type):
-            task_info = self.task_model[task_type]
-            edit_type_list = [self.edit_type_list[0]]
-            for preprocessor in task_info.get("PREPROCESSOR", []):
-                preprocessor["REPAINTING_SCALE"] = task_info.get("REPAINTING_SCALE", 1.0)
-                self.edit_type_dict[preprocessor["TYPE"]] = preprocessor
-                edit_type_list.append(preprocessor["TYPE"])
-
-            return gr.update(choices=edit_type_list, value=edit_type_list[0])
-
-        self.task_type.change(change_task_type, inputs=[self.task_type], outputs=[self.edit_type])
-
         def change_edit_type(edit_type):
             edit_info = self.edit_type_dict[edit_type]
             edit_info = edit_info or {}
             repainting_scale = edit_info.get("REPAINTING_SCALE", 1.0)
-            if edit_type == self.edit_type_list[0]:
-                return gr.Slider(value=1.0)
-            else:
-                return gr.Slider(
-                    value=repainting_scale)
+            return gr.Slider(value=repainting_scale)

         self.edit_type.change(change_edit_type, inputs=[self.edit_type], outputs=[self.repainting_scale])

-        def preprocess_input(ref_image, edit_image_dict, preprocess = None):
+        def resize_image(image, h):
+            ow, oh = image.size
+            w = int(h * ow / oh)
+            image = image.resize((w, h), Image.LANCZOS)
+            return image
+
+        def preprocess_input(ref_image, edit_image_dict, preprocess=None):
             err_msg = ""
             is_suc = True
             if ref_image is not None:
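The new `resize_image` helper scales to a fixed height while preserving aspect ratio; the app uses it with h=150 to fill the new `icon` preview. A usage sketch:

```python
# Aspect-ratio-preserving resize to a fixed height, as resize_image does.
from PIL import Image

img = Image.new("RGB", (640, 480))
ow, oh = img.size
h = 150
icon = img.resize((int(h * ow / oh), h), Image.LANCZOS)
assert icon.size == (200, 150)  # 4:3 ratio kept at the new height
```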
@@ -349,8 +337,9 @@ class DemoUI(object):
                 edit_image = None
                 edit_mask = None
             elif np.sum(np.array(edit_mask)) < 1:
-                err_msg = "You must draw the repainting area for the edited image."
-                return None, None, None, False, err_msg
+                edit_image = pillow_convert(edit_image, "RGB")
+                w, h = edit_image.size
+                edit_mask = Image.new("L", (w, h), 255)
             else:
                 edit_image = pillow_convert(edit_image, "RGB")
                 edit_mask = Image.fromarray(edit_mask).convert('L')
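A behavior change worth noting in this hunk: an empty mask used to abort with "You must draw the repainting area"; it now falls back to an all-white mask, meaning the whole image becomes the repaint region. A minimal check of that fallback:

```python
# The new fallback: an all-white "L" mask marks the entire image for repainting.
from PIL import Image

edit_image = Image.new("RGB", (64, 64), (120, 30, 30))
edit_mask = Image.new("L", edit_image.size, 255)  # repaint everything
assert edit_mask.getextrema() == (255, 255)       # every pixel is 255
```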
@@ -358,43 +347,38 @@ class DemoUI(object):
                 err_msg = "Please provide the reference image or edited image."
                 return None, None, None, False, err_msg
             return edit_image, edit_mask, ref_image, is_suc, err_msg
+
         @spaces.GPU(duration=80)
         def run_chat(
-                prompt,
-                ref_image,
-                edit_image,
-                task_type,
-                edit_type,
-                cfg_scale,
-                step,
-                seed,
-                output_h,
-                output_w,
-                repainting_scale,
-                progress=gr.Progress(track_tqdm=True)
+                prompt,
+                ref_image,
+                edit_image,
+                edit_type,
+                cfg_scale,
+                step,
+                seed,
+                output_h,
+                output_w,
+                repainting_scale,
+                use_change,
+                keep_pixel,
+                keep_pixels_rate,
+                progress=gr.Progress(track_tqdm=True)
         ):
-            print(prompt)
-            model_path = self.task_model[task_type]["MODEL_PATH"]
             edit_info = self.edit_type_dict[edit_type]
-
-            if task_type in ["portrait", "subject"] and ref_image is None:
-                err_msg = "<mark>Please provide the reference image.</mark>"
-                return (gr.Image(), gr.Column(visible=True),
-                        gr.Image(),
-                        gr.Image(),
-                        gr.Text(value=err_msg))
-
             pre_edit_image, pre_edit_mask, pre_ref_image, is_suc, err_msg = preprocess_input(ref_image, edit_image)
+            icon = pre_edit_image or pre_ref_image
             if not is_suc:
                 err_msg = f"<mark>{err_msg}</mark>"
                 return (gr.Image(), gr.Column(visible=True),
+                        gr.Image(),
                         gr.Image(),
                         gr.Image(),
                         gr.Text(value=err_msg))
-            pre_edit_image = edit_preprocess(edit_info, we.device_id, pre_edit_image, pre_edit_mask)
+            pre_edit_image = edit_preprocess(edit_info.ANNOTATOR, we.device_id, pre_edit_image, pre_edit_mask)
             # edit_image["background"] = pre_edit_image
             st = time.time()
-            image, seed = self.pipe(
+            image, edit_image, change_image, mask, seed = self.pipe(
                 reference_image=pre_ref_image,
                 edit_image=pre_edit_image,
                 edit_mask=pre_edit_mask,
@@ -406,32 +390,43 @@ class DemoUI(object):
                 guide_scale=cfg_scale,
                 seed=seed,
                 repainting_scale=repainting_scale,
-                lora_path = model_path
+                use_change=use_change,
+                keep_pixels=keep_pixel,
+                keep_pixels_rate=keep_pixels_rate
             )
             et = time.time()
             msg = f"prompt: {prompt}; seed: {seed}; cost time: {et - st}s; repaiting scale: {repainting_scale}"

+            if icon is not None:
+                icon = resize_image(icon, 150)
+
             return (gr.Image(value=image), gr.Column(visible=True),
-                    gr.Image(value=pre_edit_image if pre_edit_image is not None else pre_ref_image),
+                    gr.Image(value=edit_image if edit_image is not None else edit_image),
+                    gr.Image(value=change_image),
                     gr.Image(value=pre_edit_mask if pre_edit_mask is not None else None),
-                    gr.Text(value=msg))
+                    gr.Text(value=msg),
+                    gr.Image(value=icon))

         chat_inputs = [
             self.reference_image,
             self.edit_image,
-            self.task_type,
             self.edit_type,
             self.cfg_scale,
             self.step,
             self.seed,
             self.output_height,
             self.output_width,
-            self.repainting_scale
+            self.repainting_scale,
+            self.use_change,
+            self.keep_pixel,
+            self.keep_pixels_rate
         ]

         chat_outputs = [
-            self.gallery_image, self.edit_preprocess_panel, self.edit_preprocess_preview,
-            self.edit_preprocess_mask_preview, self.generation_info_preview
+            self.gallery_image, self.edit_preprocess_panel, self.edit_preprocess_preview,
+            self.change_preprocess_preview,
+            self.edit_preprocess_mask_preview, self.generation_info_preview,
+            self.icon
         ]

         self.chat_btn.click(run_chat,
@@ -445,23 +440,26 @@ class DemoUI(object):
                             queue=True)

         @spaces.GPU(duration=80)
-        def run_example(task_type, edit_type, prompt, ref_image, edit_image, edit_mask,
-                        output_h, output_w, seed, progress=gr.Progress(track_tqdm=True)):
-            model_path = self.task_model[task_type]["MODEL_PATH"]
+        def run_example(edit_type, prompt, ref_image, edit_image, edit_mask,
+                        output_h, output_w, seed, use_change, keep_pixel,
+                        keep_pixels_rate,
+                        progress=gr.Progress(track_tqdm=True)):

             step = self.pipe.input.get("sample_steps", 20)
             cfg_scale = self.pipe.input.get("guide_scale", 20)
-
             edit_info = self.edit_type_dict[edit_type]

             edit_image = self.construct_edit_image(edit_image, edit_mask)

-            pre_edit_image, pre_edit_mask, pre_ref_image, is_suc, err_msg = preprocess_input(ref_image, edit_image)
-            pre_edit_image = edit_preprocess(edit_info, we.device_id, pre_edit_image, pre_edit_mask)
+            pre_edit_image, pre_edit_mask, pre_ref_image, _, _ = preprocess_input(ref_image, edit_image)
+
+            icon = pre_edit_image or pre_ref_image
+
+            pre_edit_image = edit_preprocess(edit_info.ANNOTATOR, we.device_id, pre_edit_image, pre_edit_mask)
             edit_info = edit_info or {}
             repainting_scale = edit_info.get("REPAINTING_SCALE", 1.0)
             st = time.time()
-            image, seed = self.pipe(
+            image, edit_image, change_image, mask, seed = self.pipe(
                 reference_image=pre_ref_image,
                 edit_image=pre_edit_image,
                 edit_mask=pre_edit_mask,
@@ -473,40 +471,52 @@ class DemoUI(object):
                 guide_scale=cfg_scale,
                 seed=seed,
                 repainting_scale=repainting_scale,
-                lora_path=model_path
+                use_change=use_change,
+                keep_pixels=keep_pixel,
+                keep_pixels_rate=keep_pixels_rate
             )
             et = time.time()
             msg = f"prompt: {prompt}; seed: {seed}; cost time: {et - st}s; repaiting scale: {repainting_scale}"
             if pre_edit_image is not None:
-                ret_image = Image.composite(Image.new("RGB", pre_edit_image.size, (0, 0, 0)), pre_edit_image, pre_edit_mask)
+                ret_image = Image.composite(Image.new("RGB", pre_edit_image.size, (0, 0, 0)), pre_edit_image,
+                                            pre_edit_mask)
             else:
                 ret_image = None
+
+            if icon is not None:
+                icon = resize_image(icon, 150)
             return (gr.Image(value=image), gr.Column(visible=True),
-                    gr.Image(value=pre_edit_image if pre_edit_image is not None else pre_ref_image),
+                    gr.Image(value=edit_image if edit_image is not None else edit_image),
+                    gr.Image(value=change_image),
                     gr.Image(value=pre_edit_mask if pre_edit_mask is not None else None),
                     gr.Text(value=msg),
-                    gr.update(value=ret_image))
+                    gr.update(value=ret_image),
+                    gr.Image(value=icon))

         with self.eg:
             self.example_edit_image = gr.Image(label='Edit Image',
-                                               type='pil',
-                                               image_mode='RGB',
-                                               visible=False)
+                                               type='pil',
+                                               image_mode='RGB',
+                                               visible=False)
             self.example_edit_mask = gr.Image(label='Edit Image Mask',
-                                              type='pil',
-                                              image_mode='L',
-                                              visible=False)
+                                              type='pil',
+                                              image_mode='L',
+                                              visible=False)

             self.examples = gr.Examples(
                 fn=run_example,
                 examples=self.all_examples,
                 inputs=[
-                    self.task_type, self.edit_type, self.text, self.reference_image, self.example_edit_image,
-                    self.example_edit_mask, self.output_height, self.output_width, self.seed
+                    self.edit_type, self.text, self.reference_image, self.example_edit_image,
+                    self.example_edit_mask, self.output_height, self.output_width, self.seed,
+                    self.use_change, self.keep_pixel, self.keep_pixels_rate
                 ],
                 outputs=[self.gallery_image, self.edit_preprocess_panel, self.edit_preprocess_preview,
-                         self.edit_preprocess_mask_preview, self.generation_info_preview, self.edit_image],
-                examples_per_page=6,
+                         self.change_preprocess_preview,
+                         self.edit_preprocess_mask_preview, self.generation_info_preview,
+                         self.edit_image,
+                         self.icon],
+                examples_per_page=15,
                 cache_examples=False,
                 run_on_click=True)
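For the `ret_image` preview built with `Image.composite` above: PIL takes pixels from the first image where the mask is 255 and from the second where it is 0, so the drawn mask region shows up blacked out in the example preview. A small demonstration:

```python
# Image.composite(a, b, mask) selects from `a` where mask == 255, else from `b`.
from PIL import Image

base = Image.new("RGB", (4, 4), (200, 200, 200))
mask = Image.new("L", (4, 4), 0)
mask.putpixel((0, 0), 255)

preview = Image.composite(Image.new("RGB", base.size, (0, 0, 0)), base, mask)
assert preview.getpixel((0, 0)) == (0, 0, 0)        # masked pixel -> black
assert preview.getpixel((1, 1)) == (200, 200, 200)  # unmasked pixel kept
```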
 
 