YulianSa commited on
Commit
508502c
·
1 Parent(s): 068de59
Files changed (1) hide show
  1. infer_api.py +4 -4
infer_api.py CHANGED
@@ -304,13 +304,13 @@ def save_image_numpy(ndarr):
304
  return im
305
 
306
  @spaces.GPU
307
- def run_multiview_infer(data, pipeline, cfg: TestConfig, num_levels=3):
308
  pipeline.unet.enable_xformers_memory_efficient_attention()
309
 
310
- if cfg.seed is None:
311
  generator = None
312
  else:
313
- generator = torch.Generator(device=pipeline.unet.device).manual_seed(cfg.seed)
314
 
315
  images_cond = []
316
  results = {}
@@ -821,7 +821,7 @@ def infer_multiview_gen(img, seed, num_levels):
821
  data["normal_prompt_embeddings"] = infer_multiview_normal_text_embeds[None, ...]
822
  data["color_prompt_embeddings"] = infer_multiview_color_text_embeds[None, ...]
823
 
824
- results = run_multiview_infer(data, infer_multiview_pipeline, infer_multiview_cfg, num_levels=num_levels)
825
  return results
826
 
827
  infer_canonicalize_config = {
 
304
  return im
305
 
306
  @spaces.GPU
307
+ def run_multiview_infer(data, pipeline, cfg: TestConfig, num_levels=3, seed=None):
308
  pipeline.unet.enable_xformers_memory_efficient_attention()
309
 
310
+ if seed is None:
311
  generator = None
312
  else:
313
+ generator = torch.Generator(device=pipeline.unet.device).manual_seed(seed)
314
 
315
  images_cond = []
316
  results = {}
 
821
  data["normal_prompt_embeddings"] = infer_multiview_normal_text_embeds[None, ...]
822
  data["color_prompt_embeddings"] = infer_multiview_color_text_embeds[None, ...]
823
 
824
+ results = run_multiview_infer(data, infer_multiview_pipeline, infer_multiview_cfg, num_levels=num_levels, seed=seed)
825
  return results
826
 
827
  infer_canonicalize_config = {