CharlieAmalet committed on
Commit c283f36 · verified · 1 Parent(s): 8f8d235

Update app.py

Files changed (1)
  1. app.py +13 -7
app.py CHANGED
@@ -1,11 +1,10 @@
 import torch
 torch.jit.script = lambda f: f
-from zoedepth.utils.config import get_config
-from zoedepth.models.builder import build_model
 from zoedepth.utils.misc import colorize, save_raw_16bit
 from zoedepth.utils.geometry import depth_to_points, create_triangles
 import gradio as gr
 import spaces
+
 from PIL import Image
 import numpy as np
 import trimesh
@@ -31,6 +30,7 @@ DEVICE = 'cuda'
 model = torch.hub.load('isl-org/ZoeDepth', "ZoeD_N", pretrained=True).to("cpu").eval()
 
 # ----------- Depth functions
+@spaces.GPU(enable_queue=True)
 def save_raw_16bit(depth, fpath="raw.png"):
     if isinstance(depth, torch.Tensor):
         depth = depth.squeeze().cpu().numpy()
@@ -42,7 +42,8 @@ def save_raw_16bit(depth, fpath="raw.png"):
     return depth
 
 @spaces.GPU(enable_queue=True)
-def process_image(model, image: Image.Image):
+def process_image(image: Image.Image):
+    global model
     image = image.convert("RGB")
 
     model.to(DEVICE)
@@ -54,8 +55,9 @@ def process_image(model, image: Image.Image):
 # ----------- Depth functions
 
 # ----------- Mesh functions
-
+@spaces.GPU(enable_queue=True)
 def depth_edges_mask(depth):
+    global model
     """Returns a mask of edges in the depth map.
     Args:
         depth: 2D numpy array of shape (H, W) with dtype float32.
@@ -72,12 +74,14 @@ def depth_edges_mask(depth):
 
 @spaces.GPU(enable_queue=True)
 def predict_depth(model, image):
+    global model
     model.to(DEVICE)
     depth = model.infer_pil(image)
     return depth
 
 @spaces.GPU(enable_queue=True)
-def get_mesh(model, image: Image.Image, keep_edges=True):
+def get_mesh(image: Image.Image, keep_edges=True):
+    global model
     image.thumbnail((1024,1024)) # limit the size of the input image
 
     depth = predict_depth(model, image)
@@ -117,7 +121,8 @@ with gr.Blocks(css=css) as API:
             inputs=gr.Image(label="Input Image", type='pil', height=500) # Input is an image
             outputs=gr.Image(label="Depth Map", type='pil', height=500) # Output is also an image
             generate_btn = gr.Button(value="Generate")
-            generate_btn.click(partial(process_image, model), inputs=inputs, outputs=outputs, api_name="generate_depth")
+            # generate_btn.click(partial(process_image, model), inputs=inputs, outputs=outputs, api_name="generate_depth")
+            generate_btn.click(process_image, inputs=inputs, outputs=outputs, api_name="generate_depth")
 
     with gr.Tab("Image to 3D"):
         with gr.Row():
@@ -125,7 +130,8 @@ with gr.Blocks(css=css) as API:
             inputs=[gr.Image(label="Input Image", type='pil', height=500), gr.Checkbox(label="Keep occlusion edges", value=True)]
             outputs=gr.Model3D(label="3D Mesh", clear_color=[1.0, 1.0, 1.0, 1.0], height=500)
             generate_btn = gr.Button(value="Generate")
-            generate_btn.click(partial(get_mesh, model), inputs=inputs, outputs=outputs, api_name="generate_mesh")
+            # generate_btn.click(partial(get_mesh, model), inputs=inputs, outputs=outputs, api_name="generate_mesh")
+            generate_btn.click(get_mesh, inputs=inputs, outputs=outputs, api_name="generate_mesh")
 
 if __name__ == '__main__':
     API.launch()
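
After this change the click handlers call process_image and get_mesh directly instead of wrapping them in functools.partial, presumably so the @spaces.GPU-decorated callables receive only the user-facing Gradio inputs while the model is picked up from module scope via global. A minimal sketch of calling the two named endpoints remotely, assuming a recent gradio_client release; the Space id below is a placeholder, not the real repo id:

from gradio_client import Client, handle_file

# Placeholder Space id -- replace with the actual <user>/<space> of this app.
client = Client("CharlieAmalet/example-space")

# Depth endpoint: a single input image; api_name matches the click() registration above.
depth_map = client.predict(
    handle_file("photo.jpg"),
    api_name="/generate_depth",
)

# Mesh endpoint: input image plus the "Keep occlusion edges" checkbox value.
mesh_file = client.predict(
    handle_file("photo.jpg"),
    True,
    api_name="/generate_mesh",
)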