A19grey commited on
Commit
b3b839e
·
1 Parent(s): 29a026e

moved resizing to 3D model code out of depth generatino to clean architecture

Browse files
Files changed (1) hide show
  1. app.py +17 -32
app.py CHANGED
@@ -60,7 +60,7 @@ def generate_3d_model(depth, image_path, focallength_px):
60
 
61
  Args:
62
  depth (np.ndarray): 2D array representing depth in meters.
63
- image_path (str): Path to the resized RGB image.
64
  focallength_px (float): Focal length in pixels.
65
 
66
  Returns:
@@ -68,8 +68,16 @@ def generate_3d_model(depth, image_path, focallength_px):
68
  """
69
  # Load the RGB image and convert to a NumPy array
70
  image = np.array(Image.open(image_path))
 
 
 
 
 
71
  height, width = depth.shape
72
 
 
 
 
73
  # Compute camera intrinsic parameters
74
  fx = fy = focallength_px # Assuming square pixels and fx = fy
75
  cx, cy = width / 2, height / 2 # Principal point at the image center
@@ -126,17 +134,13 @@ def predict_depth(input_image):
126
  # Preprocess the image for depth prediction
127
  result = depth_pro.load_rgb(temp_file)
128
 
129
- # Add error checking for the result tuple
130
  if len(result) < 2:
131
  raise ValueError(f"Unexpected result from load_rgb: {result}")
132
 
133
- image = result[0] # Unpack the result tuple correctly
134
- f_px = result[-1] # Extract focal length
135
-
136
  print(f"Extracted focal length: {f_px}")
137
 
138
- image = transform(image) # Apply preprocessing transforms
139
- image = image.to(device) # Move the image tensor to the selected device
140
 
141
  # Run the depth prediction model
142
  prediction = model.infer(image, f_px=f_px)
@@ -151,33 +155,13 @@ def predict_depth(input_image):
151
  if depth.ndim != 2:
152
  depth = depth.squeeze()
153
 
154
- # Print debug information
155
- print(f"Original depth shape: {depth.shape}")
156
- print(f"Original image shape: {image.shape}")
157
-
158
- # Resize depth to match image dimensions
159
- image_height, image_width = image.shape[2], image.shape[3]
160
- depth = cv2.resize(depth, (image_width, image_height), interpolation=cv2.INTER_LINEAR)
161
-
162
- print(f"Resized depth shape: {depth.shape}")
163
- print(f"Final image shape: {image.shape}")
164
-
165
- # No downsampling
166
- downscale_factor = 1
167
-
168
- # Convert image tensor to CPU and NumPy
169
- image_np = image.cpu().detach().numpy()[0].transpose(1, 2, 0)
170
-
171
- # No normalization of depth map as it is already in meters
172
- depth_min = np.min(depth)
173
- depth_max = np.max(depth)
174
- depth_normalized = depth # Depth remains in meters
175
 
176
  # Create a color map for visualization using matplotlib
177
  plt.figure(figsize=(10, 10))
178
- plt.imshow(depth_normalized, cmap='gist_rainbow')
179
  plt.colorbar(label='Depth [m]')
180
- plt.title(f'Predicted Depth Map - Min: {depth_min:.1f}m, Max: {depth_max:.1f}m')
181
  plt.axis('off') # Hide axis for a cleaner image
182
 
183
  # Save the depth map visualization to a file
@@ -208,8 +192,9 @@ def get_last_commit_timestamp():
208
  try:
209
  timestamp = subprocess.check_output(['git', 'log', '-1', '--format=%cd', '--date=iso']).decode('utf-8').strip()
210
  return datetime.fromisoformat(timestamp).strftime("%Y-%m-%d %H:%M:%S")
211
- except Exception:
212
- return "Unknown"
 
213
 
214
  # Create the Gradio interface with appropriate input and output components.
215
  last_updated = get_last_commit_timestamp()
 
60
 
61
  Args:
62
  depth (np.ndarray): 2D array representing depth in meters.
63
+ image_path (str): Path to the RGB image.
64
  focallength_px (float): Focal length in pixels.
65
 
66
  Returns:
 
68
  """
69
  # Load the RGB image and convert to a NumPy array
70
  image = np.array(Image.open(image_path))
71
+
72
+ # Resize depth to match image dimensions if necessary
73
+ if depth.shape != image.shape[:2]:
74
+ depth = cv2.resize(depth, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_LINEAR)
75
+
76
  height, width = depth.shape
77
 
78
+ print(f"3D model generation - Depth shape: {depth.shape}")
79
+ print(f"3D model generation - Image shape: {image.shape}")
80
+
81
  # Compute camera intrinsic parameters
82
  fx = fy = focallength_px # Assuming square pixels and fx = fy
83
  cx, cy = width / 2, height / 2 # Principal point at the image center
 
134
  # Preprocess the image for depth prediction
135
  result = depth_pro.load_rgb(temp_file)
136
 
 
137
  if len(result) < 2:
138
  raise ValueError(f"Unexpected result from load_rgb: {result}")
139
 
140
+ image, _, _, _, f_px = result
 
 
141
  print(f"Extracted focal length: {f_px}")
142
 
143
+ image = transform(image).to(device)
 
144
 
145
  # Run the depth prediction model
146
  prediction = model.infer(image, f_px=f_px)
 
155
  if depth.ndim != 2:
156
  depth = depth.squeeze()
157
 
158
+ print(f"Depth map shape: {depth.shape}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
  # Create a color map for visualization using matplotlib
161
  plt.figure(figsize=(10, 10))
162
+ plt.imshow(depth, cmap='gist_rainbow')
163
  plt.colorbar(label='Depth [m]')
164
+ plt.title(f'Predicted Depth Map - Min: {np.min(depth):.1f}m, Max: {np.max(depth):.1f}m')
165
  plt.axis('off') # Hide axis for a cleaner image
166
 
167
  # Save the depth map visualization to a file
 
192
  try:
193
  timestamp = subprocess.check_output(['git', 'log', '-1', '--format=%cd', '--date=iso']).decode('utf-8').strip()
194
  return datetime.fromisoformat(timestamp).strftime("%Y-%m-%d %H:%M:%S")
195
+ except Exception as e:
196
+ print(f"{str(e)}")
197
+ return str(e)
198
 
199
  # Create the Gradio interface with appropriate input and output components.
200
  last_updated = get_last_commit_timestamp()