qubvel-hf (HF staff) committed
Commit
42ade60
1 Parent(s): 0811d77
Files changed (2)
  1. app.py +98 -52
  2. requirements.txt +4 -1
app.py CHANGED
@@ -1,16 +1,18 @@
-import gradio as gr
-import cv2
+import os
+import torch
+import spaces
 import matplotlib
+
 import numpy as np
-import os
+import gradio as gr
+
 from PIL import Image
-import spaces
-import torch
-import tempfile
-from gradio_imageslider import ImageSlider
+from transformers import pipeline
 from huggingface_hub import hf_hub_download
+from gradio_imageslider import ImageSlider

 from depth_anything_v2.dpt import DepthAnythingV2
+from loguru import logger

 css = """
 #img-display-container {
@@ -26,78 +28,122 @@ css = """
   height: 62px;
 }
 """
+
+title = "# Depth Anything: Watch V1 and V2 side by side."
+description1 = """Please refer to **Depth Anything V2** [paper](https://arxiv.org/abs/2406.09414) for more details."""
+
 DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
-model_configs = {
+DEFAULT_V2_MODEL_NAME = "Base"
+DEFAULT_V1_MODEL_NAME = "Base"
+
+cmap = matplotlib.colormaps.get_cmap('Spectral_r')
+
+# --------------------------------------------------------------------
+# Depth anything V1 configuration
+# --------------------------------------------------------------------
+depth_anything_v1_name2checkpoint = {
+    "Small": "LiheYoung/depth-anything-small-hf",
+    "Base": "LiheYoung/depth-anything-base-hf",
+    "Large": "LiheYoung/depth-anything-large-hf",
+}
+
+depth_anything_v1_pipelines = {}
+# --------------------------------------------------------------------
+# Depth anything V2 configuration
+# --------------------------------------------------------------------
+
+depth_anything_v2_configs = {
     'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
     'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
     'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
     'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
 }
-encoder2name = {
+depth_anything_v2_encoder2name = {
     'vits': 'Small',
     'vitb': 'Base',
     'vitl': 'Large',
-    'vitg': 'Giant', # we are undergoing company review procedures to release our giant model checkpoint
+    # 'vitg': 'Giant', # we are undergoing company review procedures to release our giant model checkpoint
 }
-encoder = 'vitl'
-model_name = encoder2name[encoder]
-model = DepthAnythingV2(**model_configs[encoder])
-filepath = hf_hub_download(repo_id=f"depth-anything/Depth-Anything-V2-{model_name}", filename=f"depth_anything_v2_{encoder}.pth", repo_type="model")
-state_dict = torch.load(filepath, map_location="cpu")
-model.load_state_dict(state_dict)
-model = model.to(DEVICE).eval()
-
-title = "# Depth Anything V2"
-description1 = """Official demo for **Depth Anything V2**.
-Please refer to our [paper](https://arxiv.org/abs/2406.09414) for more details."""
-description2 = """**Due to the issue with our V2 Github repositories, we temporarily upload the content to [Huggingface space](https://huggingface.co/spaces/depth-anything/Depth-Anything-V2/blob/main/README_Github.md).**"""
+depth_anything_v2_name2encoder = {v: k for k, v in depth_anything_v2_encoder2name.items()}
+
+depth_anything_v2_models = {}
+# --------------------------------------------------------------------
+
+
+def get_v1_pipe(model_name):
+    return pipeline(task="depth-estimation", model=depth_anything_v1_name2checkpoint[model_name], device=DEVICE)
+
+
+def get_v2_model(model_name):
+    encoder = depth_anything_v2_name2encoder[model_name]
+    model = DepthAnythingV2(**depth_anything_v2_configs[encoder])
+    filepath = hf_hub_download(repo_id=f"depth-anything/Depth-Anything-V2-{model_name}", filename=f"depth_anything_v2_{encoder}.pth", repo_type="model")
+    state_dict = torch.load(filepath, map_location="cpu")
+    model.load_state_dict(state_dict)
+    model = model.to(DEVICE).eval()
+    return model
+

 @spaces.GPU
-def predict_depth(image):
-    return model.infer_image(image)
+def predict_depth_v1(image, model_name):
+    if model_name not in depth_anything_v1_pipelines:
+        depth_anything_v1_pipelines[model_name] = get_v1_pipe(model_name)
+    pipe = depth_anything_v1_pipelines[model_name]
+    return pipe(image)

-with gr.Blocks(css=css) as demo:
-    gr.Markdown(title)
-    gr.Markdown(description1)
-    gr.Markdown(description2)
-    gr.Markdown("### Depth Prediction demo")

-    with gr.Row():
-        input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
-        depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output', position=0.5)
-    submit = gr.Button(value="Compute Depth")
-    gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download",)
-    raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download",)
+@spaces.GPU
+def predict_depth_v2(image, model_name):
+    if model_name not in depth_anything_v2_models:
+        depth_anything_v2_models[model_name] = get_v2_model(model_name)
+    model = depth_anything_v2_models[model_name]
+    return model.infer_image(image)

-    cmap = matplotlib.colormaps.get_cmap('Spectral_r')

-    def on_submit(image):
-        original_image = image.copy()
+def compute_depth_map_v2(image, model_select: str):
+    depth = predict_depth_v2(image[:, :, ::-1], model_select)
+    depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
+    depth = depth.astype(np.uint8)
+    colored_depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8)
+    return colored_depth

-        h, w = image.shape[:2]

-        depth = predict_depth(image[:, :, ::-1])
+def compute_depth_map_v1(image, model_select):
+    pil_image = Image.fromarray(image)
+    depth = predict_depth_v1(pil_image, model_select)
+    depth = np.array(depth["depth"]).astype(np.uint8)
+    colored_depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8)
+    return colored_depth

-        raw_depth = Image.fromarray(depth.astype('uint16'))
-        tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
-        raw_depth.save(tmp_raw_depth.name)

-        depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
-        depth = depth.astype(np.uint8)
-        colored_depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8)
+def on_submit(image, model_v1_select, model_v2_select):
+    logger.info(f"Computing depth for V1 model: {model_v1_select} and V2 model: {model_v2_select}")
+    colored_depth_v1 = compute_depth_map_v1(image, model_v1_select)
+    colored_depth_v2 = compute_depth_map_v2(image, model_v2_select)
+    return colored_depth_v1, colored_depth_v2

-        gray_depth = Image.fromarray(depth)
-        tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
-        gray_depth.save(tmp_gray_depth.name)

-        return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name]
+with gr.Blocks(css=css) as demo:
+    gr.Markdown(title)
+    gr.Markdown(description1)
+    gr.Markdown("### Depth Prediction demo")
+    with gr.Row():
+        model_select_v1 = gr.Dropdown(label="Depth Anything V1 Model", choices=list(depth_anything_v1_name2checkpoint.keys()), value=DEFAULT_V1_MODEL_NAME)
+        model_select_v2 = gr.Dropdown(label="Depth Anything V2 Model", choices=list(depth_anything_v2_encoder2name.values()), value=DEFAULT_V2_MODEL_NAME)
+    with gr.Row():
+        gr.Markdown()
+        gr.Markdown("Depth Maps: V1 <-> V2")
+    with gr.Row():
+        input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
+        depth_image_slider = ImageSlider(elem_id='img-display-output', position=0.5)

-    submit.click(on_submit, inputs=[input_image], outputs=[depth_image_slider, gray_depth_file, raw_file])
+    submit = gr.Button(value="Compute Depth")
+    submit.click(on_submit, inputs=[input_image, model_select_v1, model_select_v2], outputs=[depth_image_slider])

     example_files = os.listdir('assets/examples')
     example_files.sort()
     example_files = [os.path.join('assets/examples', filename) for filename in example_files]
-    examples = gr.Examples(examples=example_files, inputs=[input_image], outputs=[depth_image_slider, gray_depth_file, raw_file], fn=on_submit)
+    examples = gr.Examples(examples=example_files, inputs=[input_image], outputs=[depth_image_slider], fn=on_submit)


 if __name__ == '__main__':
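For readers who want to exercise the new V2 code path outside the Space, the `get_v2_model` / `predict_depth_v2` pair boils down to a checkpoint download plus `infer_image` on a BGR array. The sketch below is illustrative and not part of the commit: the "Small"/`vits` choice and the `example.jpg` path are placeholder assumptions, while `DepthAnythingV2` and `infer_image` come from the `depth_anything_v2` package already present in this repo.

```python
# Minimal sketch of the V2 path used in app.py (not part of the commit).
import numpy as np
import torch
from PIL import Image
from huggingface_hub import hf_hub_download
from depth_anything_v2.dpt import DepthAnythingV2

# Same constructor arguments as depth_anything_v2_configs['vits'] in app.py.
model = DepthAnythingV2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])

# Checkpoint download mirrors get_v2_model with model_name="Small", encoder="vits".
filepath = hf_hub_download(
    repo_id="depth-anything/Depth-Anything-V2-Small",
    filename="depth_anything_v2_vits.pth",
    repo_type="model",
)
model.load_state_dict(torch.load(filepath, map_location="cpu"))
model = model.to("cuda" if torch.cuda.is_available() else "cpu").eval()

rgb = np.array(Image.open("example.jpg").convert("RGB"))  # placeholder input image
depth = model.infer_image(rgb[:, :, ::-1])                # RGB -> BGR, as compute_depth_map_v2 does
print(depth.shape, depth.min(), depth.max())              # float depth map, same H x W as the input
```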
requirements.txt CHANGED
@@ -4,4 +4,7 @@ torch
 torchvision
 opencv-python
 matplotlib
-huggingface_hub
+huggingface_hub
+transformers
+numpy==1.*
+loguru
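requirements.txt now pulls in `transformers` for the Depth Anything V1 pipeline, `loguru` for logging, and a NumPy 1.x pin. As a rough sketch (not part of the commit) of how the new `transformers` dependency is used, the V1 path in app.py is the standard depth-estimation pipeline followed by the same `Spectral_r` colormap applied to V2; the checkpoint name and image path below are placeholder assumptions.

```python
# Minimal sketch of the V1 path used in app.py (not part of the commit).
import numpy as np
import matplotlib
from PIL import Image
from transformers import pipeline

cmap = matplotlib.colormaps.get_cmap("Spectral_r")

# One of the checkpoints listed in depth_anything_v1_name2checkpoint.
pipe = pipeline(task="depth-estimation", model="LiheYoung/depth-anything-small-hf")

image = Image.open("example.jpg")  # placeholder input image
result = pipe(image)               # dict with "predicted_depth" (tensor) and "depth" (PIL image)

depth = np.array(result["depth"]).astype(np.uint8)         # pipeline already scales depth to 0-255
colored = (cmap(depth)[:, :, :3] * 255).astype(np.uint8)    # drop alpha channel, back to uint8 RGB

Image.fromarray(colored).save("depth_v1_colored.png")
```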