A19grey committed on
Commit
7ce37c4
·
1 Parent(s): 83e6e59

more debugging to address possible GPU OOM or timeout for 3D generation

Browse files
Files changed (1) hide show
  1. app.py +96 -85
app.py CHANGED
@@ -70,95 +70,106 @@ def resize_image(image_path, max_size=1024):
70
  img.save(temp_file, format="PNG")
71
  return temp_file.name
72
 
73
@spaces.GPU(duration=20)
def generate_3d_model(depth, image_path, focallength_px, simplification_factor=0.8, smoothing_iterations=1, thin_threshold=0.01):
    """
    Generate a vertex-colored 3D mesh from the depth map and the original image.

    Every pixel is back-projected to a 3D vertex with a pinhole camera model
    (principal point at the image center), neighbouring pixels are joined into
    two triangles per grid cell, and the resulting mesh is simplified, reduced
    to its largest connected component, smoothed and passed through
    ``remove_thin_features`` before being exported.

    Args:
        depth: 2D depth map as a NumPy array or a ``torch.Tensor``. It is
            resized to the image resolution when the shapes differ.
        image_path: Path of the image that supplies the vertex colors.
        focallength_px: Focal length in pixels; used for both fx and fy.
        simplification_factor: Fraction of the original face count requested
            from the quadric decimation step.
        smoothing_iterations: Number of times ``mesh.smoothed()`` is applied.
        thin_threshold: Forwarded to ``remove_thin_features`` as
            ``thickness_threshold``.

    Returns:
        Tuple ``(view_model_path, download_model_path)`` of the two exported
        OBJ file paths.
    """
    # Force three channels: RGBA or grayscale input would otherwise make the
    # reshape(-1, 3) below fail or yield a color array of the wrong length.
    with Image.open(image_path) as img:
        image = np.array(img.convert("RGB"))

    # Ensure depth is a NumPy array (detach so tensors tracking grad convert too)
    if isinstance(depth, torch.Tensor):
        depth = depth.detach().cpu().numpy()

    # Resize depth to match image dimensions if necessary
    if depth.shape != image.shape[:2]:
        depth = cv2.resize(depth, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_LINEAR)

    height, width = depth.shape

    print(f"3D model generation - Depth shape: {depth.shape}")
    print(f"3D model generation - Image shape: {image.shape}")

    # Compute camera intrinsic parameters
    fx = fy = float(focallength_px)  # Ensure focallength_px is a float
    cx, cy = width / 2, height / 2  # Principal point at the image center

    # Create a grid of (u, v) pixel coordinates
    uu, vv = np.meshgrid(np.arange(0, width), np.arange(0, height))

    # Convert pixel coordinates to 3D coordinates using the pinhole camera model
    Z = depth.flatten()
    X = ((uu.flatten() - cx) * Z) / fx
    Y = ((vv.flatten() - cy) * Z) / fy

    # Stack the coordinates to form vertices (X, Y, Z)
    vertices = np.vstack((X, Y, Z)).T

    # Normalize RGB colors to [0, 1] for vertex coloring
    colors = image.reshape(-1, 3) / 255.0

    # Generate faces by connecting adjacent vertices to form triangles.
    # Built with array operations instead of a per-pixel Python loop; the
    # stack/reshape keeps the two triangles of each cell adjacent, in the
    # same row-major order the loop produced.
    top_left = (np.arange(height - 1)[:, None] * width + np.arange(width - 1)[None, :]).ravel()
    top_right = top_left + 1
    bottom_left = top_left + width
    bottom_right = bottom_left + 1
    faces = np.stack(
        (
            np.stack((top_left, bottom_left, top_right), axis=1),
            np.stack((top_right, bottom_left, bottom_right), axis=1),
        ),
        axis=1,
    ).reshape(-1, 3)

    # Create the mesh using Trimesh with vertex colors
    mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_colors=colors)

    # Mesh cleaning and improvement steps
    print("Original mesh - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))

    # 1. Mesh simplification
    target_faces = int(len(mesh.faces) * simplification_factor)
    mesh = mesh.simplify_quadric_decimation(face_count=target_faces)
    print("After simplification - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))

    # 2. Remove small disconnected components (keep the one with the largest area)
    components = mesh.split(only_watertight=False)
    if len(components) > 1:
        areas = np.array([c.area for c in components])
        mesh = components[np.argmax(areas)]
    print("After removing small components - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))

    # 3. Smooth the mesh
    for _ in range(smoothing_iterations):
        mesh = mesh.smoothed()
    print("After smoothing - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))

    # 4. Remove thin features
    mesh = remove_thin_features(mesh, thickness_threshold=thin_threshold)
    print("After removing thin features - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))

    # Export the mesh to OBJ files named after the current Unix timestamp
    timestamp = int(time.time())
    view_model_path = f'view_model_{timestamp}.obj'
    download_model_path = f'download_model_{timestamp}.obj'
    print("gonna export to view!")
    mesh.export(view_model_path)
    print("gonna export to download!")
    mesh.export(download_model_path)
    print("exported!")
    return view_model_path, download_model_path
 
 
 
 
 
 
 
 
 
 
 
162
 
163
  def remove_thin_features(mesh, thickness_threshold=0.01):
164
  """
 
70
  img.save(temp_file, format="PNG")
71
  return temp_file.name
72
 
73
@spaces.GPU(duration=30)  # Increased duration to 30 seconds
def generate_3d_model(depth, image_path, focallength_px, simplification_factor=0.8, smoothing_iterations=1, thin_threshold=0.01):
    """
    Generate a vertex-colored 3D mesh from the depth map and the original image.

    Every pixel is back-projected to a 3D vertex with a pinhole camera model
    (principal point at the image center), neighbouring pixels are joined into
    two triangles per grid cell, and the resulting mesh is simplified, reduced
    to its largest connected component, smoothed and passed through
    ``remove_thin_features`` before being exported. Progress is printed before
    each stage so a timeout or out-of-memory failure can be located in the logs.

    Args:
        depth: 2D depth map as a NumPy array or a ``torch.Tensor``. It is
            resized to the image resolution when the shapes differ.
        image_path: Path of the image that supplies the vertex colors.
        focallength_px: Focal length in pixels; used for both fx and fy.
        simplification_factor: Fraction of the original face count requested
            from the quadric decimation step.
        smoothing_iterations: Number of times ``mesh.smoothed()`` is applied.
        thin_threshold: Forwarded to ``remove_thin_features`` as
            ``thickness_threshold``.

    Returns:
        Tuple ``(view_model_path, download_model_path)`` of the two exported
        OBJ file paths.

    Raises:
        Exception: Any error raised by a stage is printed and re-raised
            unchanged.
    """
    try:
        print("Starting 3D model generation")
        # Force three channels: RGBA or grayscale input would otherwise make
        # the reshape(-1, 3) below fail or yield a color array of the wrong
        # length.
        with Image.open(image_path) as img:
            image = np.array(img.convert("RGB"))

        # Ensure depth is a NumPy array (detach so tensors tracking grad convert too)
        if isinstance(depth, torch.Tensor):
            depth = depth.detach().cpu().numpy()

        # Resize depth to match image dimensions if necessary
        if depth.shape != image.shape[:2]:
            depth = cv2.resize(depth, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_LINEAR)

        height, width = depth.shape

        print(f"3D model generation - Depth shape: {depth.shape}")
        print(f"3D model generation - Image shape: {image.shape}")

        # Compute camera intrinsic parameters
        fx = fy = float(focallength_px)  # Ensure focallength_px is a float
        cx, cy = width / 2, height / 2  # Principal point at the image center

        # Create a grid of (u, v) pixel coordinates
        uu, vv = np.meshgrid(np.arange(0, width), np.arange(0, height))

        # Convert pixel coordinates to 3D coordinates using the pinhole camera model
        Z = depth.flatten()
        X = ((uu.flatten() - cx) * Z) / fx
        Y = ((vv.flatten() - cy) * Z) / fy

        # Stack the coordinates to form vertices (X, Y, Z)
        vertices = np.vstack((X, Y, Z)).T

        # Normalize RGB colors to [0, 1] for vertex coloring
        colors = image.reshape(-1, 3) / 255.0

        print("Generating faces")
        # Generate faces by connecting adjacent vertices to form triangles.
        # Built with array operations instead of a per-pixel Python loop (the
        # loop is a likely source of the timeouts being debugged); the
        # stack/reshape keeps the two triangles of each cell adjacent, in the
        # same row-major order the loop produced.
        top_left = (np.arange(height - 1)[:, None] * width + np.arange(width - 1)[None, :]).ravel()
        top_right = top_left + 1
        bottom_left = top_left + width
        bottom_right = bottom_left + 1
        faces = np.stack(
            (
                np.stack((top_left, bottom_left, top_right), axis=1),
                np.stack((top_right, bottom_left, bottom_right), axis=1),
            ),
            axis=1,
        ).reshape(-1, 3)

        print("Creating mesh")
        # Create the mesh using Trimesh with vertex colors
        mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_colors=colors)

        # Mesh cleaning and improvement steps
        print("Original mesh - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))

        print("Simplifying mesh")
        # 1. Mesh simplification
        target_faces = int(len(mesh.faces) * simplification_factor)
        mesh = mesh.simplify_quadric_decimation(face_count=target_faces)
        print("After simplification - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))

        print("Removing small components")
        # 2. Remove small disconnected components (keep the one with the largest area)
        components = mesh.split(only_watertight=False)
        if len(components) > 1:
            areas = np.array([c.area for c in components])
            mesh = components[np.argmax(areas)]
        print("After removing small components - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))

        print("Smoothing mesh")
        # 3. Smooth the mesh
        for _ in range(smoothing_iterations):
            mesh = mesh.smoothed()
        print("After smoothing - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))

        print("Removing thin features")
        # 4. Remove thin features
        mesh = remove_thin_features(mesh, thickness_threshold=thin_threshold)
        print("After removing thin features - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))

        # Export the mesh to OBJ files named after the current Unix timestamp
        timestamp = int(time.time())
        view_model_path = f'view_model_{timestamp}.obj'
        download_model_path = f'download_model_{timestamp}.obj'
        print("Exporting to view")
        mesh.export(view_model_path)
        print("Exporting to download")
        mesh.export(download_model_path)
        print("Export completed")
        return view_model_path, download_model_path
    except Exception as e:
        # Log at this boundary for debugging, then let the caller see the
        # original exception and traceback.
        print(f"Error in generate_3d_model: {str(e)}")
        raise
173
 
174
  def remove_thin_features(mesh, thickness_threshold=0.01):
175
  """