Spaces: Running on Zero

First try to add 3D model creation

app.py CHANGED
@@ -8,100 +8,205 @@ import spaces
 import torch
 import tempfile
 import os
+import trimesh

-# Run the script to
+# Run the script to download pretrained models
 subprocess.run(["bash", "get_pretrained_models.sh"])

+# Set the device to GPU if available, else CPU
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

-# Load model and preprocessing
+# Load the depth prediction model and its preprocessing transforms
 model, transform = depth_pro.create_model_and_transforms()
-model = model.to(device)
-model.eval()
+model = model.to(device)  # Move the model to the selected device
+model.eval()  # Set the model to evaluation mode

 def resize_image(image_path, max_size=1024):
+    """
+    Resize the input image to ensure its largest dimension does not exceed max_size.
+    Maintains the aspect ratio and saves the resized image as a temporary PNG file.
+
+    Args:
+        image_path (str): Path to the input image.
+        max_size (int, optional): Maximum size for the largest dimension. Defaults to 1024.
+
+    Returns:
+        str: Path to the resized temporary image file.
+    """
     with Image.open(image_path) as img:
-        # Calculate the
+        # Calculate the resizing ratio while maintaining aspect ratio
         ratio = max_size / max(img.size)
         new_size = tuple([int(x * ratio) for x in img.size])

-        # Resize the image
+        # Resize the image using LANCZOS filter for high-quality downsampling
         img = img.resize(new_size, Image.LANCZOS)

-        #
+        # Save the resized image to a temporary file
         with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
             img.save(temp_file, format="PNG")
             return temp_file.name

+def generate_3d_model(depth, image_path, focallength_px):
+    """
+    Generate a textured 3D mesh from the depth map and the original image.
+
+    Args:
+        depth (np.ndarray): 2D array representing depth in meters.
+        image_path (str): Path to the resized RGB image.
+        focallength_px (float): Focal length in pixels.
+
+    Returns:
+        str: Path to the exported 3D model file in OBJ format.
+    """
+    # Load the RGB image and convert to a NumPy array
+    image = np.array(Image.open(image_path))
+    height, width = depth.shape
+
+    # Compute camera intrinsic parameters
+    fx = fy = focallength_px  # Assuming square pixels and fx = fy
+    cx, cy = width / 2, height / 2  # Principal point at the image center
+
+    # Create a grid of (u, v) pixel coordinates
+    u = np.arange(0, width)
+    v = np.arange(0, height)
+    uu, vv = np.meshgrid(u, v)
+
+    # Convert pixel coordinates to real-world 3D coordinates using the pinhole camera model
+    Z = depth.flatten()
+    X = ((uu.flatten() - cx) * Z) / fx
+    Y = ((vv.flatten() - cy) * Z) / fy
+
+    # Stack the coordinates to form vertices (X, Y, Z)
+    vertices = np.vstack((X, Y, Z)).T
+
+    # Normalize RGB colors to [0, 1] for vertex coloring
+    colors = image.reshape(-1, 3) / 255.0
+
+    # Generate faces by connecting adjacent vertices to form triangles
+    faces = []
+    for i in range(height - 1):
+        for j in range(width - 1):
+            idx = i * width + j
+            # Triangle 1
+            faces.append([idx, idx + width, idx + 1])
+            # Triangle 2
+            faces.append([idx + 1, idx + width, idx + width + 1])
+    faces = np.array(faces)
+
+    # Create the mesh using Trimesh with vertex colors
+    mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_colors=colors)
+
+    # Export the mesh to an OBJ file
+    model_path = 'output_model.obj'
+    mesh.export(model_path)
+    return model_path
+
 @spaces.GPU(duration=20)
 def predict_depth(input_image):
+    """
+    Predict the depth map from the input image, generate visualizations and a 3D model.
+
+    Args:
+        input_image (str): Path to the input image file.
+
+    Returns:
+        tuple:
+            - str: Path to the depth map image.
+            - str: Focal length in pixels or an error message.
+            - str: Path to the raw depth data CSV file.
+            - str: Path to the generated 3D model file.
+    """
     temp_file = None
     try:
-        # Resize the input image
+        # Resize the input image to a manageable size
         temp_file = resize_image(input_image)

-        # Preprocess the image
+        # Preprocess the image for depth prediction
         result = depth_pro.load_rgb(temp_file)
         image = result[0]
-        f_px = result[-1] #
-        image = transform(image)
-        image = image.to(device)
+        f_px = result[-1]  # Focal length in pixels
+        image = transform(image)  # Apply preprocessing transforms
+        image = image.to(device)  # Move the image tensor to the selected device

-        # Run
+        # Run the depth prediction model
         prediction = model.infer(image, f_px=f_px)
-        depth = prediction["depth"]  # Depth in
+        depth = prediction["depth"]  # Depth map in meters
         focallength_px = prediction["focallength_px"]  # Focal length in pixels

-        # Convert depth to
+        # Convert depth from torch tensor to NumPy array if necessary
         if isinstance(depth, torch.Tensor):
             depth = depth.cpu().numpy()

-        # Ensure depth is a 2D
+        # Ensure the depth map is a 2D array
         if depth.ndim != 2:
             depth = depth.squeeze()

-        #
-        #
+        # **Downsample depth map and image to improve processing speed**
+        downscale_factor = 2  # Factor by which to downscale (e.g., 2 reduces dimensions by half)
+        depth = depth[::downscale_factor, ::downscale_factor]
+        # Convert image tensor to CPU and NumPy for slicing
+        image_np = image.cpu().detach().numpy()[0].transpose(1, 2, 0)
+        image_ds = image_np[::downscale_factor, ::downscale_factor, :]
+        # Update focal length based on downscaling
+        focallength_px = focallength_px / downscale_factor
+
+        # **Note:** The downscaled image is saved back to the temporary file for consistency
+        downscaled_image = Image.fromarray((image_ds * 255).astype(np.uint8))
+        downscaled_image.save(temp_file)
+
+        # No normalization of depth map as it is already in meters
         depth_min = np.min(depth)
         depth_max = np.max(depth)
-        depth_normalized = depth
-
-        # Create a color map
+        depth_normalized = depth  # Depth remains in meters
+
+        # Create a color map for visualization using matplotlib
         plt.figure(figsize=(10, 10))
         plt.imshow(depth_normalized, cmap='gist_rainbow')
         plt.colorbar(label='Depth [m]')
         plt.title(f'Predicted Depth Map - Min: {depth_min:.1f}m, Max: {depth_max:.1f}m')
-        plt.axis('off')
-
-        # Save the
+        plt.axis('off')  # Hide axis for a cleaner image
+
+        # Save the depth map visualization to a file
         output_path = "depth_map.png"
         plt.savefig(output_path)
         plt.close()

-        # Save raw depth data
+        # Save the raw depth data to a CSV file for download
         raw_depth_path = "raw_depth_map.csv"
         np.savetxt(raw_depth_path, depth, delimiter=',')

-
+        # Generate the 3D model from the depth map and resized image
+        model_path = generate_3d_model(depth, temp_file, focallength_px)
+
+        return output_path, f"Focal length: {focallength_px:.2f} pixels", raw_depth_path, model_path
     except Exception as e:
-
+        # Return error messages in case of failures
+        return None, f"An error occurred: {str(e)}", None, None
     finally:
-        # Clean up the temporary file
+        # Clean up by removing the temporary resized image file
         if temp_file and os.path.exists(temp_file):
             os.remove(temp_file)

-# Create Gradio interface
+# Create the Gradio interface with appropriate input and output components
 iface = gr.Interface(
     fn=predict_depth,
     inputs=gr.Image(type="filepath"),
     outputs=[
-        gr.Image(type="filepath", label="Depth Map"),
-        gr.Textbox(label="Focal Length or Error Message"),
-        gr.File(label="Download Raw Depth Map (CSV)")
+        gr.Image(type="filepath", label="Depth Map"),  # Displays the depth map image
+        gr.Textbox(label="Focal Length or Error Message"),  # Shows focal length or error messages
+        gr.File(label="Download Raw Depth Map (CSV)"),  # Allows downloading the raw depth data
+        gr.Model3D(label="3D Model")  # Displays the generated 3D model
     ],
-    title="DepthPro Demo
-    description=
+    title="DepthPro Demo with 3D Visualization",
+    description=(
+        "An enhanced demo that creates a textured 3D model from the input image and depth map.\n\n"
+        "**Instructions:**\n"
+        "1. Upload an image.\n"
+        "2. The app will predict the depth map, display it, and provide the focal length.\n"
+        "3. Download the raw depth data as a CSV file.\n"
+        "4. View the generated 3D model textured with the original image."
+    ),
 )

-# Launch the interface
-iface.launch(share=True)
+# Launch the Gradio interface with sharing enabled
+iface.launch(share=True)  # share=True allows you to share the interface with others.
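A quick way to sanity-check the mesh produced by the new generate_3d_model function is to reload the exported OBJ with trimesh outside the Space. The sketch below is a minimal example, not part of the commit; it assumes the app has already run and written output_model.obj (the file name used in the commit) to the working directory, and that the depth map was HxW pixels after downscaling.

import trimesh

# Load the OBJ file exported by generate_3d_model
mesh = trimesh.load("output_model.obj")

# Basic sanity checks on the reconstructed geometry
print("vertices:", mesh.vertices.shape)  # expected (H * W, 3) for the downscaled grid
print("faces:", mesh.faces.shape)        # two triangles per pixel quad
print("bounds:", mesh.bounds)            # min/max of X, Y, Z (meters, from the depth map)

# The Z coordinates come straight from the predicted depth, so their range
# should match the min/max shown in the depth-map title.
z = mesh.vertices[:, 2]
print(f"depth range: {z.min():.2f} m to {z.max():.2f} m")

# Optional: open an interactive viewer (needs a local display)
# mesh.show()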