jiaweir commited on
Commit
61e8d2c
·
1 Parent(s): 76c0bbe

add examples

Browse files
Files changed (3) hide show
  1. app.py +22 -15
  2. data/catstatue_rgba.png +0 -0
  3. data/zelda_rgba.png +0 -0
app.py CHANGED
@@ -8,7 +8,6 @@ import hashlib
8
  import shlex
9
 
10
  subprocess.run(shlex.split("pip install wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl"))
11
- # subprocess.run(shlex.split("pip install xformers==0.0.23 --no-deps --index-url https://download.pytorch.org/whl/cu118"))
12
 
13
  import rembg
14
  import glob
@@ -51,7 +50,7 @@ function refresh() {
51
 
52
 
53
  device = torch.device('cuda')
54
- # device = torch.device('cpu')
55
 
56
  session = rembg.new_session(model_name='u2net')
57
 
@@ -160,6 +159,8 @@ def check_img_input(control_image):
160
 
161
  # check if there is a picture uploaded or selected
162
  def check_video_3d_input(image_block: Image.Image):
 
 
163
  img_hash = hashlib.sha256(image_block.tobytes()).hexdigest()
164
  if not os.path.exists(os.path.join('tmp_data', f'{img_hash}_rgba_generated.mp4')):
165
  raise gr.Error("Please generate a video first")
@@ -212,7 +213,7 @@ def optimize_stage_1(image_block: Image.Image, preprocess_chk: bool, seed_slider
212
  # stage 1
213
  # subprocess.run(f'python lgm/infer.py big --resume {ckpt_path} --test_path tmp_data/{img_hash}_rgba.png', shell=True)
214
  process_lgm(opt, f'tmp_data/{img_hash}_rgba.png', pipe_mvdream, model, rays_embeddings, seed_slider)
215
- # return [os.path.join('logs', 'tmp_rgba_model.ply')]
216
  return os.path.join('vis_data', f'{img_hash}_rgba_static.mp4')
217
 
218
  @spaces.GPU(duration=120)
@@ -241,11 +242,12 @@ if __name__ == "__main__":
241
  '''
242
  _IMG_USER_GUIDE = "Please upload an image in the block above (or choose an example above), click **Generate Video** and **Generate 3D** (they can run in parallel). Finally, click **Generate 4D**."
243
 
244
- # load images in 'data' folder as examples
245
  example_folder = os.path.join(os.path.dirname(__file__), 'data')
246
- example_fns = os.listdir(example_folder)
247
- example_fns.sort()
248
- examples_full = [os.path.join(example_folder, x) for x in example_fns if x.endswith('.png')]
 
 
249
 
250
  # Compose demo layout & data flow
251
  with gr.Blocks(title=_TITLE, theme=gr.themes.Soft(), js=js_func) as demo:
@@ -268,18 +270,22 @@ if __name__ == "__main__":
268
  preprocess_chk = gr.Checkbox(True,
269
  label='Preprocess image automatically (remove background and recenter object)')
270
 
 
 
 
 
 
 
 
 
271
  gr.Examples(
272
  examples=examples_full, # NOTE: elements must match inputs list!
273
- inputs=[image_block],
274
  outputs=[image_block],
275
  cache_examples=False,
276
- label='Examples (click one of the images below to start)',
277
  examples_per_page=40
278
  )
279
- img_run_btn = gr.Button("Generate Video")
280
- threed_run_btn = gr.Button("Generate 3D")
281
- fourd_run_btn = gr.Button("Generate 4D")
282
- img_guide_text = gr.Markdown(_IMG_USER_GUIDE, visible=True)
283
 
284
  with gr.Column(scale=5):
285
  with gr.Row():
@@ -287,8 +293,9 @@ if __name__ == "__main__":
287
  dirving_video = gr.Video(label="video",height=290)
288
  with gr.Column(scale=5):
289
  obj3d = gr.Video(label="3D Model",height=290)
290
- video4d = gr.Video(label="4D video",height=290)
291
- obj4d = Model4DGS(label="4D Model", height=500, fps=28)
 
292
 
293
 
294
  img_run_btn.click(check_img_input, inputs=[image_block], queue=False).success(optimize_stage_0,
 
8
  import shlex
9
 
10
  subprocess.run(shlex.split("pip install wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl"))
 
11
 
12
  import rembg
13
  import glob
 
50
 
51
 
52
  device = torch.device('cuda')
53
+ # # device = torch.device('cpu')
54
 
55
  session = rembg.new_session(model_name='u2net')
56
 
 
159
 
160
  # check if there is a picture uploaded or selected
161
  def check_video_3d_input(image_block: Image.Image):
162
+ if image_block is None:
163
+ raise gr.Error("Please select or upload an input image")
164
  img_hash = hashlib.sha256(image_block.tobytes()).hexdigest()
165
  if not os.path.exists(os.path.join('tmp_data', f'{img_hash}_rgba_generated.mp4')):
166
  raise gr.Error("Please generate a video first")
 
213
  # stage 1
214
  # subprocess.run(f'python lgm/infer.py big --resume {ckpt_path} --test_path tmp_data/{img_hash}_rgba.png', shell=True)
215
  process_lgm(opt, f'tmp_data/{img_hash}_rgba.png', pipe_mvdream, model, rays_embeddings, seed_slider)
216
+ # return os.path.join('logs', f'{img_hash}_rgba_model.ply')
217
  return os.path.join('vis_data', f'{img_hash}_rgba_static.mp4')
218
 
219
  @spaces.GPU(duration=120)
 
242
  '''
243
  _IMG_USER_GUIDE = "Please upload an image in the block above (or choose an example above), click **Generate Video** and **Generate 3D** (they can run in parallel). Finally, click **Generate 4D**."
244
 
 
245
  example_folder = os.path.join(os.path.dirname(__file__), 'data')
246
+ examples_full = [
247
+ [example_folder+'/csm_luigi_rgba.png', 10],
248
+ [example_folder+'/anya_rgba.png', 42],
249
+ [example_folder+'/panda.png', 42262],
250
+ ]
251
 
252
  # Compose demo layout & data flow
253
  with gr.Blocks(title=_TITLE, theme=gr.themes.Soft(), js=js_func) as demo:
 
270
  preprocess_chk = gr.Checkbox(True,
271
  label='Preprocess image automatically (remove background and recenter object)')
272
 
273
+ with gr.Row():
274
+ with gr.Column(scale=5):
275
+ img_run_btn = gr.Button("Generate Video")
276
+ with gr.Column(scale=5):
277
+ threed_run_btn = gr.Button("Generate 3D")
278
+ fourd_run_btn = gr.Button("Generate 4D")
279
+ img_guide_text = gr.Markdown(_IMG_USER_GUIDE, visible=True)
280
+
281
  gr.Examples(
282
  examples=examples_full, # NOTE: elements must match inputs list!
283
+ inputs=[image_block, seed_slider],
284
  outputs=[image_block],
285
  cache_examples=False,
286
+ label='Examples (click one of the examples below to start)',
287
  examples_per_page=40
288
  )
 
 
 
 
289
 
290
  with gr.Column(scale=5):
291
  with gr.Row():
 
293
  dirving_video = gr.Video(label="video",height=290)
294
  with gr.Column(scale=5):
295
  obj3d = gr.Video(label="3D Model",height=290)
296
+ # obj3d = gr.Model3D(label="3D Model",height=290)
297
+ video4d = gr.Video(label="4D Render",height=290)
298
+ obj4d = Model4DGS(label="4D Model", height=500, fps=21)
299
 
300
 
301
  img_run_btn.click(check_img_input, inputs=[image_block], queue=False).success(optimize_stage_0,
data/catstatue_rgba.png DELETED
Binary file (48.1 kB)
 
data/zelda_rgba.png DELETED
Binary file (47.2 kB)