import os
import subprocess
import tempfile
import time
from datetime import datetime

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import spaces  # ZeroGPU helper; keep imported before any CUDA use
import timm
import torch
import trimesh
from PIL import Image

import src.depth_pro as depth_pro
# Ensure timm is properly loaded
print(f"Timm version: {timm.__version__}")
# Download the pretrained DepthPro checkpoint (fetched into ./checkpoints by default)
subprocess.run(["bash", "get_pretrained_models.sh"], check=True)
# Set the device to GPU if available, else CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Load the depth prediction model and its preprocessing transforms
model, transform = depth_pro.create_model_and_transforms()
model = model.to(device) # Move the model to the selected device
model.eval() # Set the model to evaluation mode
def resize_image(image_path, max_size=1024):
    """
    Resize the input image so that its largest dimension does not exceed max_size.
    Maintains the aspect ratio and saves the resized image as a temporary PNG file.

    Args:
        image_path (str): Path to the input image.
        max_size (int, optional): Maximum size of the largest dimension. Defaults to 1024.

    Returns:
        str: Path to the resized temporary image file.
    """
    with Image.open(image_path) as img:
        # Force three channels so downstream code can rely on RGB data
        img = img.convert("RGB")
        # Compute the scaling ratio; never upscale images that are already small enough
        ratio = min(1.0, max_size / max(img.size))
        new_size = tuple(int(x * ratio) for x in img.size)
        # Resize with the LANCZOS filter for high-quality downsampling
        img = img.resize(new_size, Image.LANCZOS)
        # Save the resized image to a temporary file
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            img.save(temp_file, format="PNG")
            return temp_file.name
def generate_3d_model(depth, image_path, focallength_px):
"""
Generate a textured 3D mesh from the depth map and the original image.
Args:
depth (np.ndarray): 2D array representing depth in meters.
image_path (str): Path to the resized RGB image.
focallength_px (float): Focal length in pixels.
Returns:
tuple: Paths to the exported 3D model files for viewing and downloading.
"""
# Load the RGB image and convert to a NumPy array
image = np.array(Image.open(image_path))
height, width = depth.shape
# Compute camera intrinsic parameters
fx = fy = focallength_px # Assuming square pixels and fx = fy
cx, cy = width / 2, height / 2 # Principal point at the image center
# Create a grid of (u, v) pixel coordinates
u = np.arange(0, width)
v = np.arange(0, height)
uu, vv = np.meshgrid(u, v)
# Convert pixel coordinates to real-world 3D coordinates using the pinhole camera model
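    # Back-projection: a pixel (u, v) with depth Z maps to
    #   X = (u - cx) * Z / fx,   Y = (v - cy) * Z / fy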
Z = depth.flatten()
X = ((uu.flatten() - cx) * Z) / fx
Y = ((vv.flatten() - cy) * Z) / fy
# Stack the coordinates to form vertices (X, Y, Z)
vertices = np.vstack((X, Y, Z)).T
# Normalize RGB colors to [0, 1] for vertex coloring
colors = image.reshape(-1, 3) / 255.0
# Generate faces by connecting adjacent vertices to form triangles
faces = []
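    # Each pixel and its right/bottom neighbours form a quad that is split
    # into two triangles:
    #   idx ---------- idx+1
    #    |           /   |
    #   idx+width -- idx+width+1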
for i in range(height - 1):
for j in range(width - 1):
idx = i * width + j
# Triangle 1
faces.append([idx, idx + width, idx + 1])
# Triangle 2
faces.append([idx + 1, idx + width, idx + width + 1])
faces = np.array(faces)
# Create the mesh using Trimesh with vertex colors
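    # trimesh accepts float vertex colors in [0, 1] and converts them to RGBA internally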
mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_colors=colors)
# Export the mesh to OBJ files with unique filenames
timestamp = int(time.time())
view_model_path = f'view_model_{timestamp}.obj'
download_model_path = f'download_model_{timestamp}.obj'
mesh.export(view_model_path)
mesh.export(download_model_path)
return view_model_path, download_model_path
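# ZeroGPU: a GPU is attached only while this function runs, for at most `duration` seconds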
@spaces.GPU(duration=20)
def predict_depth(input_image):
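    """
    Predict a depth map for the uploaded image and build all demo outputs.

    Args:
        input_image (str): Filepath of the uploaded image (Gradio "filepath" type).

    Returns:
        tuple: (depth map PNG path, focal length or error message,
                raw depth CSV path, 3D model path for viewing,
                3D model path for download)
    """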
temp_file = None
try:
# Resize the input image to a manageable size
temp_file = resize_image(input_image)
# Preprocess the image for depth prediction
result = depth_pro.load_rgb(temp_file)
image = result[0]
        f_px = result[-1]  # Focal length in pixels (may be None; the model then estimates it)
image = transform(image) # Apply preprocessing transforms
image = image.to(device) # Move the image tensor to the selected device
# Run the depth prediction model
prediction = model.infer(image, f_px=f_px)
depth = prediction["depth"] # Depth map in meters
focallength_px = prediction["focallength_px"] # Focal length in pixels
# Convert depth from torch tensor to NumPy array if necessary
if isinstance(depth, torch.Tensor):
depth = depth.cpu().numpy()
# Ensure the depth map is a 2D array
if depth.ndim != 2:
depth = depth.squeeze()
# Print debug information
print(f"Original depth shape: {depth.shape}")
print(f"Original image shape: {image.shape}")
        # Resize depth to match the input image dimensions. The transformed tensor
        # may be (C, H, W) or (N, C, H, W), so index the spatial dims from the end.
        image_height, image_width = image.shape[-2], image.shape[-1]
        # Note: cv2.resize expects the target size as (width, height)
        depth = cv2.resize(depth, (image_width, image_height), interpolation=cv2.INTER_LINEAR)
print(f"Resized depth shape: {depth.shape}")
print(f"Final image shape: {image.shape}")
        # Depth is already in meters, so no normalization is needed;
        # min/max are only used for the plot title
        depth_min = np.min(depth)
        depth_max = np.max(depth)
# Create a color map for visualization using matplotlib
plt.figure(figsize=(10, 10))
        plt.imshow(depth, cmap='gist_rainbow')
plt.colorbar(label='Depth [m]')
plt.title(f'Predicted Depth Map - Min: {depth_min:.1f}m, Max: {depth_max:.1f}m')
plt.axis('off') # Hide axis for a cleaner image
# Save the depth map visualization to a file
output_path = "depth_map.png"
plt.savefig(output_path)
plt.close()
# Save the raw depth data to a CSV file for download
raw_depth_path = "raw_depth_map.csv"
np.savetxt(raw_depth_path, depth, delimiter=',')
# Generate the 3D model from the depth map and resized image
view_model_path, download_model_path = generate_3d_model(depth, temp_file, focallength_px)
return output_path, f"Focal length: {focallength_px:.2f} pixels", raw_depth_path, view_model_path, download_model_path
except Exception as e:
# Return error messages in case of failures
return None, f"An error occurred: {str(e)}", None, None, None
finally:
# Clean up by removing the temporary resized image file
if temp_file and os.path.exists(temp_file):
os.remove(temp_file)
def get_last_commit_timestamp():
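    """
    Return the timestamp of the latest git commit as "YYYY-MM-DD HH:MM:SS",
    or "Unknown" if the git metadata is unavailable.
    """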
try:
timestamp = subprocess.check_output(['git', 'log', '-1', '--format=%cd', '--date=iso']).decode('utf-8').strip()
return datetime.fromisoformat(timestamp).strftime("%Y-%m-%d %H:%M:%S")
except Exception:
return "Unknown"
# Create the Gradio interface with appropriate input and output components.
last_updated = get_last_commit_timestamp()
iface = gr.Interface(
fn=predict_depth,
inputs=gr.Image(type="filepath"),
outputs=[
gr.Image(type="filepath", label="Depth Map"),
gr.Textbox(label="Focal Length or Error Message"),
gr.File(label="Download Raw Depth Map (CSV)"),
gr.Model3D(label="View 3D Model"),
gr.File(label="Download 3D Model (OBJ)")
],
title="DepthPro Demo with 3D Visualization",
description=(
"An enhanced demo that creates a textured 3D model from the input image and depth map.\n\n"
"Forked from https://huggingface.co/spaces/akhaliq/depth-pro and model from https://huggingface.co/apple/DepthPro\n"
"**Instructions:**\n"
"1. Upload an image.\n"
"2. The app will predict the depth map, display it, and provide the focal length.\n"
"3. Download the raw depth data as a CSV file.\n"
"4. View the generated 3D model textured with the original image.\n"
"5. Download the 3D model as an OBJ file if desired.\n\n"
f"Last updated: {last_updated}"
),
)
# Launch the Gradio interface; share=True creates a public link when run locally
# (it is ignored on Hugging Face Spaces)
iface.launch(share=True)