import torch
torch.jit.script = lambda f: f  # make torch.jit.script a no-op so the ZoeDepth hub model is not TorchScript-compiled on load
from zoedepth.utils.misc import colorize, save_raw_16bit
from zoedepth.utils.geometry import depth_to_points, create_triangles
from marigold_depth_estimation import MarigoldPipeline
import gradio as gr
import spaces
from PIL import Image
import numpy as np
import trimesh
from functools import partial
import tempfile
css = """
#img-display-container {
max-height: 50vh;
}
#img-display-input {
max-height: 40vh;
}
#img-display-output {
max-height: 40vh;
}
"""
# DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE = 'cuda'
model = torch.hub.load('isl-org/ZoeDepth', "ZoeD_N", pretrained=True).to("cpu").eval()
CHECKPOINT = "prs-eth/marigold-v1-0"
pipe = MarigoldPipeline.from_pretrained(CHECKPOINT)
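# Two models back the two tabs below: the ZoeDepth hub model (used for the 3D mesh tab)
# and the Marigold pipeline (used for the depth-map tab). Both are loaded on the CPU at
# import time; the @spaces.GPU-decorated functions move them to CUDA per request, the
# usual pattern on ZeroGPU Spaces where a GPU is attached only while a call runs.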
# ----------- Depth functions
@spaces.GPU(enable_queue=True)
def save_raw_16bit(depth, fpath="raw.png"):
    # Note: this overrides the save_raw_16bit imported from zoedepth.utils.misc and
    # returns the scaled array rather than writing it to fpath.
    if isinstance(depth, torch.Tensor):
        depth = depth.squeeze().cpu().numpy()
    assert isinstance(depth, np.ndarray), "Depth must be a torch tensor or numpy array"
    assert depth.ndim == 2, "Depth must be 2D"
    depth = depth * 256  # scale for 16-bit png
    depth = depth.astype(np.uint16)
    return depth
@spaces.GPU(enable_queue=True)
def process_image(image: Image.Image):
    global model
    image = image.convert("RGB")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Previous ZoeDepth-based path, kept for reference:
    # model.to(device)
    # depth = model.infer_pil(image)
    # processed_array = save_raw_16bit(colorize(depth)[:, :, 0])
    # return Image.fromarray(processed_array)
    # Marigold inference; the diffusers-style pipeline (not the ZoeDepth model) does the
    # work here, so it is the object moved onto the GPU.
    pipe.to(device)
    processed_array = pipe(image)["depth"]
    return Image.fromarray(processed_array)
# ----------- Depth functions
# ----------- Mesh functions
@spaces.GPU(enable_queue=True)
def depth_edges_mask(depth):
    """Returns a mask of edges in the depth map.

    Args:
        depth: 2D numpy array of shape (H, W) with dtype float32.
    Returns:
        mask: 2D numpy array of shape (H, W) with dtype bool.
    """
    # Compute the x and y gradients of the depth map.
    depth_dx, depth_dy = np.gradient(depth)
    # Compute the gradient magnitude.
    depth_grad = np.sqrt(depth_dx ** 2 + depth_dy ** 2)
    # Compute the edge mask.
    mask = depth_grad > 0.05
    return mask
@spaces.GPU(enable_queue=True)
def predict_depth(image):
    global model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    depth = model.infer_pil(image)
    return depth
@spaces.GPU(enable_queue=True)
def get_mesh(image: Image.Image, keep_edges=True):
    image.thumbnail((1024, 1024))  # limit the size of the input image
    depth = predict_depth(image)
    pts3d = depth_to_points(depth[None])
    # Create a trimesh mesh from the back-projected points.
    # Each pixel is connected to its 4 neighbors; colors are the RGB values of the image.
    verts = pts3d.reshape(-1, 3)
    image = np.array(image)
    if keep_edges:
        triangles = create_triangles(image.shape[0], image.shape[1])
    else:
        # Drop faces that cross depth discontinuities (occlusion edges).
        triangles = create_triangles(image.shape[0], image.shape[1], mask=~depth_edges_mask(depth))
    colors = image.reshape(-1, 3)
    mesh = trimesh.Trimesh(vertices=verts, faces=triangles, vertex_colors=colors)
    # Save as glb
    glb_file = tempfile.NamedTemporaryFile(suffix='.glb', delete=False)
    glb_path = glb_file.name
    mesh.export(glb_path)
    return glb_path
# ----------- Mesh functions
title = "# ZoeDepth"
description = """Unofficial demo for **ZoeDepth: Zero-shot Transfer by Combining Relative and Metric Depth**."""
with gr.Blocks(css=css) as API:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Tab("Depth Prediction"):
        with gr.Row():
            inputs = gr.Image(label="Input Image", type='pil', height=500)  # Input is an image
            outputs = gr.Image(label="Depth Map", type='pil', height=500)   # Output is also an image
        generate_btn = gr.Button(value="Generate")
        # generate_btn.click(partial(process_image, model), inputs=inputs, outputs=outputs, api_name="generate_depth")
        generate_btn.click(process_image, inputs=inputs, outputs=outputs, api_name="generate_depth")
    with gr.Tab("Image to 3D"):
        with gr.Row():
            with gr.Column():
                inputs = [gr.Image(label="Input Image", type='pil', height=500),
                          gr.Checkbox(label="Keep occlusion edges", value=True)]
            outputs = gr.Model3D(label="3D Mesh", clear_color=[1.0, 1.0, 1.0, 1.0], height=500)
        generate_btn = gr.Button(value="Generate")
        # generate_btn.click(partial(get_mesh, model), inputs=inputs, outputs=outputs, api_name="generate_mesh")
        generate_btn.click(get_mesh, inputs=inputs, outputs=outputs, api_name="generate_mesh")

if __name__ == '__main__':
    API.launch()
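
# Example client call (a minimal sketch, not part of the app). The two buttons above are
# registered as named API endpoints, so once the Space is running they can be driven with
# gradio_client; the URL and image path below are placeholders, and on recent gradio_client
# releases image inputs may need to be wrapped with gradio_client.handle_file().
#
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860")  # or the public Space URL
#   depth_map = client.predict("example.jpg", api_name="/generate_depth")
#   mesh_glb = client.predict("example.jpg", True, api_name="/generate_mesh")  # image, keep_edges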