wyysf commited on
Commit
9bc9474
1 Parent(s): dd81b16

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +7 -2
gradio_app.py CHANGED
@@ -107,9 +107,10 @@ era3d_pipeline = None
107
  def gen_mvimg(
108
  mvimg_model, image, seed, guidance_scale, step, text, neg_text, elevation, backgroud_color
109
  ):
 
110
  if seed == 0:
111
  seed = np.random.randint(1, 65535)
112
- global generator, device
113
  generator = torch.Generator(device)
114
  generator.manual_seed(seed)
115
 
@@ -144,6 +145,9 @@ def gen_mvimg(
144
 
145
  elif mvimg_model == "Era3D":
146
  global era3d_pipeline
 
 
 
147
  crop_size = 420
148
  batch = SingleImageDataset(root_dir='', num_views=6, img_wh=[512, 512], bg_color='white',
149
  crop_size=crop_size, single_image=image, prompt_embeds_path='apps/third_party/Era3D/data/fixed_prompt_embeds_6view')[0]
@@ -254,7 +258,8 @@ if __name__=="__main__":
254
  # schema = OmegaConf.structured(TestConfig)
255
  # cfg = OmegaConf.merge(schema, cfg)
256
  era3d_pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
257
- 'pengHTYX/MacLab-Era3D-512-6view'
 
258
  )
259
  # enable xformers
260
  # era3d_pipeline.unet.enable_xformers_memory_efficient_attention()
 
107
  def gen_mvimg(
108
  mvimg_model, image, seed, guidance_scale, step, text, neg_text, elevation, backgroud_color
109
  ):
110
+ global device
111
  if seed == 0:
112
  seed = np.random.randint(1, 65535)
113
+ global generator
114
  generator = torch.Generator(device)
115
  generator.manual_seed(seed)
116
 
 
145
 
146
  elif mvimg_model == "Era3D":
147
  global era3d_pipeline
148
+ era3d_pipeline.to(device=device)
149
+ era3d_pipeline.unet.enable_xformers_memory_efficient_attention()
150
+
151
  crop_size = 420
152
  batch = SingleImageDataset(root_dir='', num_views=6, img_wh=[512, 512], bg_color='white',
153
  crop_size=crop_size, single_image=image, prompt_embeds_path='apps/third_party/Era3D/data/fixed_prompt_embeds_6view')[0]
 
258
  # schema = OmegaConf.structured(TestConfig)
259
  # cfg = OmegaConf.merge(schema, cfg)
260
  era3d_pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
261
+ 'pengHTYX/MacLab-Era3D-512-6view',
262
+ dtype=torch.float16,
263
  )
264
  # enable xformers
265
  # era3d_pipeline.unet.enable_xformers_memory_efficient_attention()