haotongl commited on
Commit
fa2db85
·
1 Parent(s): 21731aa

inital version

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -22,7 +22,8 @@ from promptda.utils.depth_utils import visualize_depth, unproject_depth
22
  DEVICE = 'cuda'
23
  # if torch.cuda.is_available(
24
  # ) else 'mps' if torch.backends.mps.is_available() else 'cpu'
25
- model = PromptDA.from_pretrained('depth-anything/promptda_vitl').to(DEVICE).eval()
 
26
  thread_pool_executor = ThreadPoolExecutor(max_workers=1)
27
 
28
  def delete_later(path: Union[str, os.PathLike], delay: int = 300):
@@ -44,10 +45,11 @@ def delete_later(path: Union[str, os.PathLike], delay: int = 300):
44
  atexit.register(_delete)
45
 
46
 
47
- # @spaces.GPU
48
  def run_with_gpu(image, prompt_depth):
49
  image = image.to(DEVICE)
50
  prompt_depth = prompt_depth.to(DEVICE)
 
51
  depth = model.predict(image, prompt_depth)
52
  depth = depth[0, 0].detach().cpu().numpy()
53
  return depth
@@ -56,7 +58,6 @@ def check_is_stray_scanner_app_capture(input_dir):
56
  assert os.path.exists(os.path.join(input_dir, 'rgb.mp4')), 'rgb.mp4 not found'
57
  pass
58
 
59
- @spaces.GPU
60
  def run(input_file, resolution):
61
  # unzip zip file
62
  input_file = input_file.name
 
22
  DEVICE = 'cuda'
23
  # if torch.cuda.is_available(
24
  # ) else 'mps' if torch.backends.mps.is_available() else 'cpu'
25
+ # model = PromptDA.from_pretrained('depth-anything/promptda_vitl').to(DEVICE).eval()
26
+ model = PromptDA.from_pretrained('depth-anything/promptda_vitl').eval()
27
  thread_pool_executor = ThreadPoolExecutor(max_workers=1)
28
 
29
  def delete_later(path: Union[str, os.PathLike], delay: int = 300):
 
45
  atexit.register(_delete)
46
 
47
 
48
+ @spaces.GPU
49
  def run_with_gpu(image, prompt_depth):
50
  image = image.to(DEVICE)
51
  prompt_depth = prompt_depth.to(DEVICE)
52
+ model.to(DEVICE)
53
  depth = model.predict(image, prompt_depth)
54
  depth = depth[0, 0].detach().cpu().numpy()
55
  return depth
 
58
  assert os.path.exists(os.path.join(input_dir, 'rgb.mp4')), 'rgb.mp4 not found'
59
  pass
60
 
 
61
  def run(input_file, resolution):
62
  # unzip zip file
63
  input_file = input_file.name