fffiloni committed on
Commit 32c9458
1 Parent(s): cdc3395

add infer_compo

Files changed (1)
  1. app.py +131 -17
app.py CHANGED
@@ -272,6 +272,113 @@ def infer(ref_style_file, style_description, caption):
         # Reset the state after inference, regardless of success or failure
         reset_inference_state()
 
+def infer_compo(style_description, ref_style_file, caption, ref_sub_file):
+    global models_rbm, models_b
+    try:
+        caption = f"{caption} in {style_description}"
+        sam_prompt = f"{caption}"
+        use_sam_mask = False
+
+        if low_vram:
+            # Revert the devices of the modules back to their original state
+            models_to(models_rbm, device)
+
+        batch_size = 1
+        height, width = 1024, 1024
+        stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
+
+        extras.sampling_configs['cfg'] = 4
+        extras.sampling_configs['shift'] = 2
+        extras.sampling_configs['timesteps'] = 20
+        extras.sampling_configs['t_start'] = 1.0
+        extras_b.sampling_configs['cfg'] = 1.1
+        extras_b.sampling_configs['shift'] = 1
+        extras_b.sampling_configs['timesteps'] = 10
+        extras_b.sampling_configs['t_start'] = 1.0
+
+        ref_style = resize_image(PIL.Image.open(ref_style_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
+        ref_images = resize_image(PIL.Image.open(ref_sub_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
+
+        batch = {'captions': [caption] * batch_size}
+        batch['style'] = ref_style
+        batch['images'] = ref_images
+
+        x0_forward = models_rbm.effnet(extras.effnet_preprocess(ref_images.to(device)))
+        x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style.to(device)))
+
+        ## SAM Mask for sub
+        use_sam_mask = False
+        x0_preview = models_rbm.previewer(x0_forward)
+        sam_model = LangSAM()
+        sam_mask, boxes, phrases, logits = sam_model.predict(transform(x0_preview[0]), sam_prompt)
+        sam_mask = sam_mask.detach().unsqueeze(dim=0).to(device)
+
+        conditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=False, eval_image_embeds=True, eval_subject_style=True, eval_csd=False)
+        unconditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False, eval_subject_style=True)
+        conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
+        unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
+
+        if low_vram:
+            # The sampling process uses more vram, so we offload everything except two modules to the cpu.
+            models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
+            models_to(sam_model, device="cpu")
+            models_to(sam_model.sam, device="cpu")
+
+        # Stage C reverse process.
+        sampling_c = extras.gdf.sample(
+            models_rbm.generator, conditions, stage_c_latent_shape,
+            unconditions, device=device,
+            **extras.sampling_configs,
+            x0_style_forward=x0_style_forward, x0_forward=x0_forward,
+            apply_pushforward=False, tau_pushforward=5, tau_pushforward_csd=10,
+            num_iter=3, eta=1e-1, tau=20, eval_sub_csd=True,
+            extras=extras, models=models_rbm,
+            use_attn_mask=use_sam_mask,
+            save_attn_mask=False,
+            lam_content=1, lam_style=1,
+            sam_mask=sam_mask, use_sam_mask=use_sam_mask,
+            sam_prompt=sam_prompt
+        )
+
+        for (sampled_c, _, _) in tqdm(sampling_c, total=extras.sampling_configs['timesteps']):
+            sampled_c = sampled_c
+
+        # Stage B reverse process.
+        with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
+            conditions_b['effnet'] = sampled_c
+            unconditions_b['effnet'] = torch.zeros_like(sampled_c)
+
+            sampling_b = extras_b.gdf.sample(
+                models_b.generator, conditions_b, stage_b_latent_shape,
+                unconditions_b, device=device, **extras_b.sampling_configs,
+            )
+            for (sampled_b, _, _) in tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']):
+                sampled_b = sampled_b
+            sampled = models_b.stage_a.decode(sampled_b).float()
+
+        sampled = torch.cat([
+            torch.nn.functional.interpolate(ref_images.cpu(), size=(height, width)),
+            torch.nn.functional.interpolate(ref_style.cpu(), size=(height, width)),
+            sampled.cpu(),
+        ], dim=0)
+
+        # Remove the batch dimension and keep only the generated image
+        sampled = sampled[2]  # This selects the generated image, discarding the reference images
+
+        # Ensure the tensor is in [C, H, W] format
+        if sampled.dim() == 3 and sampled.shape[0] == 3:
+            output_file = 'output_compo.png'
+            sampled_image = T.ToPILImage()(sampled)  # Convert tensor to PIL image
+            sampled_image.save(output_file)  # Save the image as a PNG
+        else:
+            raise ValueError(f"Expected tensor of shape [3, H, W] but got {sampled.shape}")
+
+        return output_file  # Return the path to the saved image
+
+    finally:
+        # Reset the state after inference, regardless of success or failure
+        reset_inference_state()
+
 import gradio as gr
 
 with gr.Blocks() as demo:
@@ -289,27 +396,34 @@ with gr.Blocks() as demo:
     </div>
     """)
     with gr.Row():
-        style_reference_image = gr.Image(
-            label = "Style Reference Image",
-            type = "filepath"
-        )
-        style_description = gr.Textbox(
-            label ="Style Description"
-        )
-        subject_prompt = gr.Textbox(
-            label = "Subject Prompt"
-        )
-        with gr.Accordion("Advanced Settings", open=False):
-            subject_reference = gr.Image(type="filepath")
-            use_subject_ref = gr.Checkbox(label="Use Subject Image as Reference", value=False)
-        submit_btn = gr.Button("Submit")
-    with gr.Row():
-        output_image = gr.Image(label="Output Image")
-
+        with gr.Column():
+            style_reference_image = gr.Image(
+                label = "Style Reference Image",
+                type = "filepath"
+            )
+            style_description = gr.Textbox(
+                label ="Style Description"
+            )
+            subject_prompt = gr.Textbox(
+                label = "Subject Prompt"
+            )
+            with gr.Accordion("Advanced Settings", open=False):
+                subject_reference = gr.Image(type="filepath")
+                use_subject_ref = gr.Checkbox(label="Use Subject Image as Reference", value=False)
+            submit_btn = gr.Button("Submit")
+        with gr.Column():
+            output_image = gr.Image(label="Output Image")
+    '''
     submit_btn.click(
         fn = infer,
         inputs = [style_reference_image, style_description, subject_prompt],
         outputs = [output_image]
     )
+    '''
+    submit_btn.click(
+        fn = infer_compo,
+        inputs = [style_description, style_reference_image, subject_prompt, subject_reference],
+        outputs = [output_image]
+    )
 
 demo.launch()
 
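The new click handler passes the Advanced Settings image straight to infer_compo, while the use_subject_ref checkbox is collected by the UI but not read by either handler. A minimal sketch of how the checkbox value could route between the two inference paths, assuming a hypothetical run_inference wrapper around the infer and infer_compo functions shown above (not part of this commit):

# Sketch only: dispatch on the checkbox value; run_inference is a hypothetical helper.
def run_inference(style_description, style_reference_image, subject_prompt, subject_reference, use_subject_ref):
    if use_subject_ref and subject_reference is not None:
        # Subject + style composition path added in this commit
        return infer_compo(style_description, style_reference_image, subject_prompt, subject_reference)
    # Style-only path kept from the previous revision
    return infer(style_reference_image, style_description, subject_prompt)

submit_btn.click(
    fn = run_inference,
    inputs = [style_description, style_reference_image, subject_prompt, subject_reference, use_subject_ref],
    outputs = [output_image]
)

Because both image components use type="filepath", the wrapper receives plain file paths (or None) plus the checkbox boolean, so it forwards exactly the values the two existing click calls already expect.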