pengHTYX commited on
Commit
a72a0f3
·
1 Parent(s): 2736f7e

'update_layout'

Browse files
app.py CHANGED
@@ -141,23 +141,20 @@ def preprocess(predictor, input_image, chk_group=None, segment=True, rescale=Fal
141
  input_image = Image.fromarray((rgb * 255).astype(np.uint8))
142
  else:
143
  input_image = expand2square(input_image, (127, 127, 127, 0))
144
- return input_image, input_image.resize((768, 768), Image.Resampling.LANCZOS)
145
 
146
 
147
  def load_era3d_pipeline(cfg):
148
  # Load scheduler, tokenizer and models.
149
 
150
  pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
151
- cfg.pretrained_model_name_or_path,
152
- torch_dtype=weight_dtype
153
  )
154
 
155
- # pipeline.to('cuda:0')
156
- pipeline.unet.enable_xformers_memory_efficient_attention()
157
-
158
-
159
  if torch.cuda.is_available():
160
  pipeline.to('cuda:0')
 
161
  # sys.main_lock = threading.Lock()
162
  return pipeline
163
 
@@ -165,8 +162,9 @@ def load_era3d_pipeline(cfg):
165
  from mvdiffusion.data.single_image_dataset import SingleImageDataset
166
 
167
 
168
- def prepare_data(single_image, crop_size):
169
- dataset = SingleImageDataset(root_dir='', num_views=6, img_wh=[512, 512], bg_color='white', crop_size=crop_size, single_image=single_image)
 
170
  return dataset[0]
171
 
172
  scene = 'scene'
@@ -179,7 +177,7 @@ def run_pipeline(pipeline, cfg, single_image, guidance_scale, steps, seed, crop_
179
  if chk_group is not None:
180
  write_image = "Write Results" in chk_group
181
 
182
- batch = prepare_data(single_image, crop_size)
183
 
184
  pipeline.set_progress_bar_config(disable=True)
185
  seed = int(seed)
@@ -203,7 +201,7 @@ def run_pipeline(pipeline, cfg, single_image, guidance_scale, steps, seed, crop_
203
  guidance_scale=guidance_scale,
204
  output_type='pt',
205
  num_images_per_prompt=1,
206
- return_elevation_focal=cfg.log_elevation_focal_length,
207
  **cfg.pipe_validation_kwargs
208
  ).images
209
 
@@ -314,6 +312,7 @@ def run_demo():
314
  custom_css = '''#disp_image {
315
  text-align: center; /* Horizontally center the content */
316
  }'''
 
317
 
318
  with gr.Blocks(title=_TITLE, theme=custom_theme, css=custom_css) as demo:
319
  with gr.Row():
@@ -322,14 +321,16 @@ def run_demo():
322
  gr.Markdown(_DESCRIPTION)
323
  with gr.Row(variant='panel'):
324
  with gr.Column(scale=1):
325
- input_image = gr.Image(type='pil', image_mode='RGBA', height=768, label='Input image')
326
 
327
  with gr.Column(scale=1):
 
 
328
  processed_image = gr.Image(
329
  type='pil',
330
  label="Processed Image",
331
  interactive=False,
332
- height=768,
333
  image_mode='RGBA',
334
  elem_id="disp_image",
335
  visible=True,
@@ -341,7 +342,7 @@ def run_demo():
341
  # label="3D Model", height=320,
342
  # # camera_position=[0,0,2.0]
343
  # )
344
- processed_image_highres = gr.Image(type='pil', image_mode='RGBA', visible=False)
345
  with gr.Row(variant='panel'):
346
  with gr.Column(scale=1):
347
  example_folder = os.path.join(os.path.dirname(__file__), "./examples")
@@ -391,6 +392,7 @@ def run_demo():
391
  view_1 = gr.Image(interactive=False, height=512, show_label=False)
392
  view_2 = gr.Image(interactive=False, height=512, show_label=False)
393
  view_3 = gr.Image(interactive=False, height=512, show_label=False)
 
394
  view_4 = gr.Image(interactive=False, height=512, show_label=False)
395
  view_5 = gr.Image(interactive=False, height=512, show_label=False)
396
  view_6 = gr.Image(interactive=False, height=512, show_label=False)
@@ -398,10 +400,11 @@ def run_demo():
398
  normal_1 = gr.Image(interactive=False, height=512, show_label=False)
399
  normal_2 = gr.Image(interactive=False, height=512, show_label=False)
400
  normal_3 = gr.Image(interactive=False, height=512, show_label=False)
 
401
  normal_4 = gr.Image(interactive=False, height=512, show_label=False)
402
  normal_5 = gr.Image(interactive=False, height=512, show_label=False)
403
  normal_6 = gr.Image(interactive=False, height=512, show_label=False)
404
-
405
  run_btn.click(
406
  fn=partial(preprocess, predictor), inputs=[input_image, input_processing], outputs=[processed_image_highres, processed_image], queue=True
407
  ).success(
@@ -414,7 +417,7 @@ def run_demo():
414
  # )
415
 
416
  demo.queue().launch(share=True, max_threads=80)
417
-
418
 
419
  if __name__ == '__main__':
420
  fire.Fire(run_demo)
 
141
  input_image = Image.fromarray((rgb * 255).astype(np.uint8))
142
  else:
143
  input_image = expand2square(input_image, (127, 127, 127, 0))
144
+ return input_image, input_image.resize((320, 320), Image.Resampling.LANCZOS)
145
 
146
 
147
  def load_era3d_pipeline(cfg):
148
  # Load scheduler, tokenizer and models.
149
 
150
  pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
151
+ cfg.pretrained_model_name_or_path,
152
+ torch_dtype=weight_dtype
153
  )
154
 
 
 
 
 
155
  if torch.cuda.is_available():
156
  pipeline.to('cuda:0')
157
+ pipeline.unet.enable_xformers_memory_efficient_attention()
158
  # sys.main_lock = threading.Lock()
159
  return pipeline
160
 
 
162
  from mvdiffusion.data.single_image_dataset import SingleImageDataset
163
 
164
 
165
+ def prepare_data(single_image, crop_size, cfg):
166
+ dataset = SingleImageDataset(root_dir='', num_views=6, img_wh=[512, 512], bg_color='white',
167
+ crop_size=crop_size, single_image=single_image, prompt_embeds_path=cfg.validation_dataset.prompt_embeds_path)
168
  return dataset[0]
169
 
170
  scene = 'scene'
 
177
  if chk_group is not None:
178
  write_image = "Write Results" in chk_group
179
 
180
+ batch = prepare_data(single_image, crop_size, cfg)
181
 
182
  pipeline.set_progress_bar_config(disable=True)
183
  seed = int(seed)
 
201
  guidance_scale=guidance_scale,
202
  output_type='pt',
203
  num_images_per_prompt=1,
204
+ # return_elevation_focal=cfg.log_elevation_focal_length,
205
  **cfg.pipe_validation_kwargs
206
  ).images
207
 
 
312
  custom_css = '''#disp_image {
313
  text-align: center; /* Horizontally center the content */
314
  }'''
315
+
316
 
317
  with gr.Blocks(title=_TITLE, theme=custom_theme, css=custom_css) as demo:
318
  with gr.Row():
 
321
  gr.Markdown(_DESCRIPTION)
322
  with gr.Row(variant='panel'):
323
  with gr.Column(scale=1):
324
+ input_image = gr.Image(type='pil', image_mode='RGBA', height=320, label='Input image')
325
 
326
  with gr.Column(scale=1):
327
+ processed_image_highres = gr.Image(type='pil', image_mode='RGBA', visible=False)
328
+
329
  processed_image = gr.Image(
330
  type='pil',
331
  label="Processed Image",
332
  interactive=False,
333
+ # height=320,
334
  image_mode='RGBA',
335
  elem_id="disp_image",
336
  visible=True,
 
342
  # label="3D Model", height=320,
343
  # # camera_position=[0,0,2.0]
344
  # )
345
+
346
  with gr.Row(variant='panel'):
347
  with gr.Column(scale=1):
348
  example_folder = os.path.join(os.path.dirname(__file__), "./examples")
 
392
  view_1 = gr.Image(interactive=False, height=512, show_label=False)
393
  view_2 = gr.Image(interactive=False, height=512, show_label=False)
394
  view_3 = gr.Image(interactive=False, height=512, show_label=False)
395
+ with gr.Row():
396
  view_4 = gr.Image(interactive=False, height=512, show_label=False)
397
  view_5 = gr.Image(interactive=False, height=512, show_label=False)
398
  view_6 = gr.Image(interactive=False, height=512, show_label=False)
 
400
  normal_1 = gr.Image(interactive=False, height=512, show_label=False)
401
  normal_2 = gr.Image(interactive=False, height=512, show_label=False)
402
  normal_3 = gr.Image(interactive=False, height=512, show_label=False)
403
+ with gr.Row():
404
  normal_4 = gr.Image(interactive=False, height=512, show_label=False)
405
  normal_5 = gr.Image(interactive=False, height=512, show_label=False)
406
  normal_6 = gr.Image(interactive=False, height=512, show_label=False)
407
+ print('Launching...')
408
  run_btn.click(
409
  fn=partial(preprocess, predictor), inputs=[input_image, input_processing], outputs=[processed_image_highres, processed_image], queue=True
410
  ).success(
 
417
  # )
418
 
419
  demo.queue().launch(share=True, max_threads=80)
420
+
421
 
422
  if __name__ == '__main__':
423
  fire.Fire(run_demo)
mvdiffusion/data/single_image_dataset.py CHANGED
@@ -236,10 +236,10 @@ class SingleImageDataset(Dataset):
236
  color_prompt_embeddings = self.color_text_embeds if hasattr(self, 'color_text_embeds') else None
237
 
238
  out = {
239
- 'imgs_in': img_tensors_in,
240
- 'alphas': alpha_tensors_in,
241
- 'normal_prompt_embeddings': normal_prompt_embeddings,
242
- 'color_prompt_embeddings': color_prompt_embeddings,
243
  'filename': filename,
244
  }
245
 
 
236
  color_prompt_embeddings = self.color_text_embeds if hasattr(self, 'color_text_embeds') else None
237
 
238
  out = {
239
+ 'imgs_in': img_tensors_in.unsqueeze(0),
240
+ 'alphas': alpha_tensors_in.unsqueeze(0),
241
+ 'normal_prompt_embeddings': normal_prompt_embeddings.unsqueeze(0),
242
+ 'color_prompt_embeddings': color_prompt_embeddings.unsqueeze(0),
243
  'filename': filename,
244
  }
245
 
mvdiffusion/pipelines/pipeline_mvdiffusion_unclip.py CHANGED
@@ -239,7 +239,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline):
239
  image_embeds = torch.cat([negative_prompt_embeds, normal_image_embeds, negative_prompt_embeds, color_image_embeds], 0)
240
 
241
  # _____________________________vae input latents__________________________________________________
242
- image_pt = torch.stack([TF.to_tensor(img) for img in image_pil], dim=0).to(device)
243
  image_pt = image_pt * 2.0 - 1.0
244
  image_latents = self.vae.encode(image_pt).latent_dist.mode() * self.vae.config.scaling_factor
245
  # Note: repeat differently from official pipelines
 
239
  image_embeds = torch.cat([negative_prompt_embeds, normal_image_embeds, negative_prompt_embeds, color_image_embeds], 0)
240
 
241
  # _____________________________vae input latents__________________________________________________
242
+ image_pt = torch.stack([TF.to_tensor(img) for img in image_pil], dim=0).to(dtype=self.vae.dtype, device=device)
243
  image_pt = image_pt * 2.0 - 1.0
244
  image_latents = self.vae.encode(image_pt).latent_dist.mode() * self.vae.config.scaling_factor
245
  # Note: repeat differently from official pipelines
requirements.txt CHANGED
@@ -30,7 +30,7 @@ torch_efficient_distloss
30
  tensorboard
31
  rembg
32
  segment_anything
33
- gradio==3.50.2
34
  moviepy
35
  kornia
36
  fire
 
30
  tensorboard
31
  rembg
32
  segment_anything
33
+ gradio==4.29.0
34
  moviepy
35
  kornia
36
  fire