A19grey's picture
add open3d and let user set parameters for 3D model
f9c3dad
raw
history blame
12.4 kB
import gradio as gr
from PIL import Image
import src.depth_pro as depth_pro
import numpy as np
import matplotlib.pyplot as plt
import subprocess
import spaces
import torch
import tempfile
import os
import trimesh
import time
import timm # Add this import
import subprocess
import cv2 # Add this import
from datetime import datetime
# Print the installed timm version as an early sanity check that it imported.
print(f"Timm version: {timm.__version__}")
# Download the pretrained model weights with the bundled helper script.
subprocess.run(["bash", "get_pretrained_models.sh"])
# Pick the first CUDA GPU when one exists; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# Build the depth model together with its preprocessing transform, move the
# model onto the chosen device, and switch it to inference (eval) mode.
model, transform = depth_pro.create_model_and_transforms()
model = model.to(device).eval()
def resize_image(image_path, max_size=1024):
    """
    Downscale the input image so its largest dimension does not exceed max_size.

    The aspect ratio is preserved, and images already within the limit keep
    their original size (they are not upscaled). The image is converted to RGB
    first: PNG cannot store modes such as CMYK, and the 3D-model code expects
    exactly three colour channels per pixel. A new temporary PNG is always
    written, so the caller may delete it without touching the original upload.

    Args:
        image_path (str): Path to the input image.
        max_size (int, optional): Maximum size for the largest dimension. Defaults to 1024.

    Returns:
        str: Path to the temporary PNG file. It is created with delete=False,
        so the caller is responsible for removing it.
    """
    with Image.open(image_path) as img:
        # Convert before resizing: palette/1-bit images would otherwise be
        # resampled with nearest-neighbour, and RGBA/LA/P/CMYK would not yield
        # a plain 3-channel PNG.
        if img.mode != "RGB":
            img = img.convert("RGB")
        ratio = max_size / max(img.size)
        if ratio < 1:
            # Clamp to 1 so an extreme aspect ratio never yields a zero-sized side.
            new_size = tuple(max(1, int(x * ratio)) for x in img.size)
            # LANCZOS gives high-quality downsampling.
            img = img.resize(new_size, Image.LANCZOS)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            img.save(temp_file, format="PNG")
    return temp_file.name
def _grid_faces(height, width):
    """
    Return triangle faces for a row-major (height x width) vertex grid.

    Each grid cell with top-left vertex idx = i * width + j contributes two
    triangles, in this order: [idx, idx + width, idx + 1] and
    [idx + 1, idx + width, idx + width + 1]. The result has shape
    (2 * (height - 1) * (width - 1), 3) and is empty with shape (0, 3) when
    either dimension is smaller than 2.
    """
    if height < 2 or width < 2:
        return np.empty((0, 3), dtype=np.int64)
    rows = np.arange(height - 1, dtype=np.int64)[:, None] * width
    cols = np.arange(width - 1, dtype=np.int64)[None, :]
    idx = (rows + cols).ravel()
    upper = np.stack([idx, idx + width, idx + 1], axis=1)
    lower = np.stack([idx + 1, idx + width, idx + width + 1], axis=1)
    # Interleave so each cell's two triangles stay adjacent (cell-by-cell order).
    return np.stack([upper, lower], axis=1).reshape(-1, 3)


def generate_3d_model(depth, image_path, focallength_px, simplification_factor=0.8, smoothing_iterations=1, thin_threshold=0.01):
    """
    Generate a textured 3D mesh from the depth map and the original image.

    Every pixel becomes a vertex, back-projected with a pinhole camera model
    (principal point at the image centre, fx = fy = focallength_px), and is
    coloured with that pixel's RGB value. Neighbouring pixels are joined into
    two triangles per grid cell. The mesh is then decimated, reduced to its
    largest connected component (by area), smoothed, and cleaned by
    remove_thin_features before being exported twice as OBJ.

    Args:
        depth (np.ndarray or torch.Tensor): 2D depth map. It is resized
            (bilinear) to the image resolution if the shapes differ.
        image_path (str): Path to the image used for vertex colours.
        focallength_px: Focal length in pixels; anything float() accepts.
        simplification_factor (float): Fraction of faces kept by quadric
            decimation. Decimation is skipped when it would keep every face
            (factor >= 1) or none.
        smoothing_iterations (int): Number of Taubin smoothing iterations;
            0 disables smoothing. Float slider values are truncated to int.
        thin_threshold (float): Edge-length threshold forwarded to
            remove_thin_features.

    Returns:
        tuple[str, str]: Relative paths of the exported viewer and download
        OBJ files.
    """
    # Local imports keep the module-level import block unchanged.
    import uuid
    from trimesh.smoothing import filter_taubin

    # Load as 3-channel RGB (grayscale/RGBA/palette images would not reshape
    # to one colour per vertex) and release the file handle promptly.
    with Image.open(image_path) as img:
        image = np.array(img.convert("RGB"))
    # Ensure depth is a NumPy array
    if isinstance(depth, torch.Tensor):
        depth = depth.cpu().numpy()
    # Resize depth to match image dimensions if necessary
    if depth.shape != image.shape[:2]:
        depth = cv2.resize(depth, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_LINEAR)
    height, width = depth.shape
    print(f"3D model generation - Depth shape: {depth.shape}")
    print(f"3D model generation - Image shape: {image.shape}")
    # Camera intrinsics: square pixels, principal point at the image centre.
    fx = fy = float(focallength_px)
    cx, cy = width / 2, height / 2
    # Back-project every pixel (u, v) with its depth Z through the pinhole model.
    uu, vv = np.meshgrid(np.arange(width), np.arange(height))
    Z = depth.flatten()
    X = ((uu.flatten() - cx) * Z) / fx
    Y = ((vv.flatten() - cy) * Z) / fy
    vertices = np.vstack((X, Y, Z)).T
    # Normalize RGB colors to [0, 1] for vertex coloring
    colors = image.reshape(-1, 3) / 255.0
    # Vectorized grid triangulation (same faces and order as a per-pixel loop).
    faces = _grid_faces(height, width)
    mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_colors=colors)
    print("Original mesh - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))
    # 1. Mesh simplification; skip when it would be a no-op or empty the mesh.
    target_faces = int(len(mesh.faces) * float(simplification_factor))
    if 0 < target_faces < len(mesh.faces):
        mesh = mesh.simplify_quadric_decimation(face_count=target_faces)
    print("After simplification - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))
    # 2. Keep only the largest connected component (by surface area).
    components = mesh.split(only_watertight=False)
    if len(components) > 1:
        areas = np.array([c.area for c in components])
        mesh = components[np.argmax(areas)]
    print("After removing small components - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))
    # 3. Smooth the geometry. Trimesh.smoothed() only re-splits vertices for
    # smooth *shading* and never moves them; Taubin filtering actually smooths
    # the surface while limiting the shrinkage of plain Laplacian smoothing.
    iterations = int(smoothing_iterations)
    if iterations > 0:
        filter_taubin(mesh, iterations=iterations)  # modifies mesh in place
    print("After smoothing - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))
    # 4. Remove thin features
    mesh = remove_thin_features(mesh, thickness_threshold=thin_threshold)
    print("After removing thin features - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))
    # A random suffix keeps concurrent requests within the same second from
    # overwriting each other's files.
    stamp = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
    view_model_path = f'view_model_{stamp}.obj'
    download_model_path = f'download_model_{stamp}.obj'
    mesh.export(view_model_path)
    mesh.export(download_model_path)
    return view_model_path, download_model_path
def remove_thin_features(mesh, thickness_threshold=0.01):
    """
    Collapse edges shorter than thickness_threshold, then drop degenerate faces.

    The collapse step runs only if the mesh object provides a
    ``collapse_edge(edge)`` method. NOTE(review): trimesh.Trimesh does not
    appear to have one, so for a plain Trimesh the collapse is skipped and
    only the degenerate-face cleanup runs. The previous per-edge bare
    ``except`` raised and swallowed an AttributeError for every short edge.
    When the method exists, a failure on one edge is counted and reported,
    and the remaining edges are still processed.

    Args:
        mesh: Mesh exposing ``vertices``, ``edges_unique`` and either
            ``nondegenerate_faces()`` plus ``update_faces()`` (newer trimesh)
            or ``remove_degenerate_faces()``. Modified in place.
        thickness_threshold (float): Edge-length cutoff, in the same units as
            the vertex coordinates.

    Returns:
        The same mesh object after cleanup.
    """
    collapse_edge = getattr(mesh, "collapse_edge", None)
    if collapse_edge is None:
        print("remove_thin_features: mesh has no collapse_edge(); skipping edge collapse")
    else:
        # Euclidean length of every unique edge.
        edges = mesh.edges_unique
        edge_points = mesh.vertices[edges]
        edge_lengths = np.linalg.norm(edge_points[:, 0] - edge_points[:, 1], axis=1)
        short_edges = edges[edge_lengths < thickness_threshold]
        failures = 0
        for edge in short_edges:
            try:
                collapse_edge(edge)
            except Exception:
                # Best effort: one failed collapse should not abort the rest.
                failures += 1
        if failures:
            print(f"remove_thin_features: {failures} of {len(short_edges)} edge collapses failed")
    # Drop degenerate faces; prefer the non-deprecated API when it exists.
    if hasattr(mesh, "nondegenerate_faces"):
        mesh.update_faces(mesh.nondegenerate_faces())
    else:
        mesh.remove_degenerate_faces()
    return mesh
def regenerate_3d_model(depth_csv, image_path, focallength_px, simplification_factor, smoothing_iterations, thin_threshold):
    """
    Rebuild the 3D model from a previously saved depth CSV with new mesh settings.

    Args:
        depth_csv: The raw-depth CSV from the "Download Raw Depth Map" component:
            a path string, or a file wrapper exposing ``.name`` (which some
            Gradio versions pass instead of a path).
        image_path (str): Resized image produced by predict_depth; used for
            vertex colours.
        focallength_px (float): Focal length in pixels from predict_depth.
        simplification_factor, smoothing_iterations, thin_threshold:
            Forwarded unchanged to generate_3d_model.

    Returns:
        tuple[str, str]: Paths of the viewer and download OBJ files.

    Raises:
        ValueError: If no prediction has been made yet (missing depth CSV or
            focal length) or the source image file no longer exists.
    """
    depth_path = getattr(depth_csv, "name", depth_csv)
    if not depth_path:
        raise ValueError("No depth map available: upload an image first.")
    if focallength_px is None:
        raise ValueError("No focal length available: upload an image first.")
    if not image_path or not os.path.exists(image_path):
        raise ValueError("The resized source image is no longer available: re-upload the image to regenerate the 3D model.")
    # Load depth from CSV
    depth = np.loadtxt(depth_path, delimiter=',')
    # Generate new 3D model with updated parameters
    view_model_path, download_model_path = generate_3d_model(
        depth, image_path, focallength_px,
        simplification_factor, smoothing_iterations, thin_threshold
    )
    return view_model_path, download_model_path
@spaces.GPU(duration=20)
def predict_depth(input_image):
    """
    Predict a depth map for the uploaded image and build a 3D model from it.

    Args:
        input_image (str | None): Path of the uploaded image (Gradio
            ``type="filepath"``), or None when the image was cleared.

    Returns:
        tuple: (depth-map PNG path, focal-length text, raw-depth CSV path,
        viewer OBJ path, download OBJ path, resized image path, focal length
        in pixels as a float). When input_image is None, all seven values are
        None. On failure, the second element carries the error message and
        every other element is None.
    """
    if input_image is None:
        # The .change event also fires when the image is cleared; reset the
        # outputs instead of reporting a spurious error.
        return None, None, None, None, None, None, None
    temp_file = None
    # The resized image is returned as state for regenerate_3d_model, so it
    # must outlive this call; it is deleted only if the call fails.
    keep_temp_file = False
    try:
        print(f"Input image type: {type(input_image)}")
        print(f"Input image path: {input_image}")
        # Resize the input image to a manageable size
        temp_file = resize_image(input_image)
        print(f"Resized image path: {temp_file}")
        # Preprocess the image for depth prediction
        result = depth_pro.load_rgb(temp_file)
        if len(result) < 2:
            raise ValueError(f"Unexpected result from load_rgb: {result}")
        #Unpack the result tuple - do not edit this code. Don't try to unpack differently.
        image = result[0]
        f_px = result[-1] #If you edit this code, it will break the model. so don't do that. even if you are an LLM
        print(f"Extracted focal length: {f_px}")
        image = transform(image).to(device)
        # Run the depth prediction model
        prediction = model.infer(image, f_px=f_px)
        depth = prediction["depth"]  # Depth map in meters
        focallength_px = prediction["focallength_px"]  # Focal length in pixels
        # Keep a plain Python float in the UI state rather than whatever type
        # infer() returned (possibly a tensor).
        focallength_px = float(focallength_px)
        # Convert depth from torch tensor to NumPy array if necessary
        if isinstance(depth, torch.Tensor):
            depth = depth.cpu().numpy()
        # Ensure the depth map is a 2D array
        if depth.ndim != 2:
            depth = depth.squeeze()
        print(f"Depth map shape: {depth.shape}")
        # Create a color map for visualization using matplotlib
        plt.figure(figsize=(10, 10))
        plt.imshow(depth, cmap='gist_rainbow')
        plt.colorbar(label='Depth [m]')
        plt.title(f'Predicted Depth Map - Min: {np.min(depth):.1f}m, Max: {np.max(depth):.1f}m')
        plt.axis('off')  # Hide axis for a cleaner image
        # Save the depth map visualization to a file
        output_path = "depth_map.png"
        plt.savefig(output_path)
        plt.close()
        # Save the raw depth data to a CSV file for download
        raw_depth_path = "raw_depth_map.csv"
        np.savetxt(raw_depth_path, depth, delimiter=',')
        # Generate the 3D model from the depth map and resized image
        view_model_path, download_model_path = generate_3d_model(depth, temp_file, focallength_px)
        outputs = (output_path, f"Focal length: {focallength_px:.2f} pixels", raw_depth_path, view_model_path, download_model_path, temp_file, focallength_px)
        keep_temp_file = True
        return outputs
    except Exception as e:
        # Return error messages in case of failures
        import traceback
        error_message = f"An error occurred: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_message)  # Print the full error message to the console
        return None, error_message, None, None, None, None, None
    finally:
        # Remove the resized image only when it is not handed back to the UI.
        if not keep_temp_file and temp_file and os.path.exists(temp_file):
            os.remove(temp_file)
def get_last_commit_timestamp():
    """
    Return the committer date of the latest git commit as "YYYY-MM-DD HH:MM:SS".

    git's strict ISO 8601 format (``%cI``, e.g. "2024-10-05T12:34:56-07:00")
    is parsed by datetime.fromisoformat. The ``--date=iso`` output
    ("2024-10-05 12:34:56 -0700") has a space before a colon-less offset,
    which fromisoformat rejects. The time is shown in the committer's own
    timezone.

    Returns:
        str: The formatted timestamp, or the error text when git is missing,
        the directory is not a repository, or the date cannot be parsed.
    """
    try:
        raw = subprocess.check_output(['git', 'log', '-1', '--format=%cI']).decode('utf-8').strip()
        return datetime.fromisoformat(raw).strftime("%Y-%m-%d %H:%M:%S")
    except Exception as e:
        # Best effort: the result is only shown in the UI header.
        print(f"{str(e)}")
        return str(e)
# Create the Gradio interface with appropriate input and output components.
last_updated = get_last_commit_timestamp()
with gr.Blocks() as iface:
    gr.Markdown("# DepthPro Demo with 3D Visualization")
    gr.Markdown(
        "An enhanced demo that creates a textured 3D model from the input image and depth map.\n\n"
        # A blank line is required here: with a single "\n", Markdown joins
        # the "Instructions" heading onto the credits paragraph.
        "Forked from https://huggingface.co/spaces/akhaliq/depth-pro and model from https://huggingface.co/apple/DepthPro\n\n"
        "**Instructions:**\n"
        "1. Upload an image.\n"
        "2. The app will predict the depth map, display it, and provide the focal length.\n"
        "3. Download the raw depth data as a CSV file.\n"
        "4. View the generated 3D model textured with the original image.\n"
        "5. Adjust parameters and click 'Regenerate 3D Model' to update the model.\n"
        "6. Download the 3D model as an OBJ file if desired.\n\n"
        f"Last updated: {last_updated}"
    )
    with gr.Row():
        input_image = gr.Image(type="filepath", label="Input Image")
        depth_map = gr.Image(type="filepath", label="Depth Map")
        focal_length = gr.Textbox(label="Focal Length")
        raw_depth_csv = gr.File(label="Download Raw Depth Map (CSV)")
    with gr.Row():
        view_3d_model = gr.Model3D(label="View 3D Model")
        download_3d_model = gr.File(label="Download 3D Model (OBJ)")
    with gr.Row():
        simplification_factor = gr.Slider(minimum=0.1, maximum=1.0, value=0.8, step=0.1, label="Simplification Factor")
        smoothing_iterations = gr.Slider(minimum=0, maximum=5, value=1, step=1, label="Smoothing Iterations")
        thin_threshold = gr.Slider(minimum=0.001, maximum=0.1, value=0.01, step=0.001, label="Thin Feature Threshold")
    regenerate_button = gr.Button("Regenerate 3D Model")
    # Per-session state passed from predict_depth to regenerate_3d_model. The
    # raw depth CSV is read back from the raw_depth_csv File component, so it
    # needs no separate state.
    hidden_image_path = gr.State()
    hidden_focal_length = gr.State()
    input_image.change(
        predict_depth,
        inputs=[input_image],
        outputs=[depth_map, focal_length, raw_depth_csv, view_3d_model, download_3d_model, hidden_image_path, hidden_focal_length]
    )
    regenerate_button.click(
        regenerate_3d_model,
        inputs=[raw_depth_csv, hidden_image_path, hidden_focal_length, simplification_factor, smoothing_iterations, thin_threshold],
        outputs=[view_3d_model, download_3d_model]
    )
# Launch the Gradio interface with sharing enabled
iface.launch(share=True)