lyndonzheng commited on
Commit
a19de17
Β·
1 Parent(s): 4bba5c3

fixed version bug

Browse files
app.py CHANGED
@@ -26,9 +26,9 @@ def main():
26
 
27
  cfg = OmegaConf.load(model_cfg_path)
28
  model = GaussianPredictor(cfg)
29
- device = torch.device("cuda:0")
30
- model.to(device)
31
  model.load_model(model_path)
 
32
 
33
  pad_border_fn = TT.Pad((cfg.dataset.pad_border_aug, cfg.dataset.pad_border_aug))
34
  to_tensor = TT.ToTensor()
@@ -150,7 +150,7 @@ def main():
150
  )
151
 
152
  demo.queue(max_size=1)
153
- demo.launch(share=True)
154
 
155
 
156
  if __name__ == "__main__":
 
26
 
27
  cfg = OmegaConf.load(model_cfg_path)
28
  model = GaussianPredictor(cfg)
29
+ device = torch.device(device)
 
30
  model.load_model(model_path)
31
+ model.to(device)
32
 
33
  pad_border_fn = TT.Pad((cfg.dataset.pad_border_aug, cfg.dataset.pad_border_aug))
34
  to_tensor = TT.ToTensor()
 
150
  )
151
 
152
  demo.queue(max_size=1)
153
+ demo.launch()
154
 
155
 
156
  if __name__ == "__main__":
flash3d/networks/gaussian_predictor.py CHANGED
@@ -235,7 +235,7 @@ class GaussianPredictor(nn.Module):
235
  for ckpt in ckpts[num_ckpts:]:
236
  ckpt.unlink()
237
 
238
- def load_model(self, weights_path, optimizer=None):
239
  """Load model(s) from disk
240
  """
241
  weights_path = Path(weights_path)
@@ -246,7 +246,7 @@ class GaussianPredictor(nn.Module):
246
  return
247
 
248
  logging.info(f"Loading weights from {weights_path}...")
249
- state_dict = torch.load(weights_path)
250
  if "version" in state_dict and state_dict["version"] == "1.0":
251
  new_dict = {}
252
  for k, v in state_dict["model"].items():
 
235
  for ckpt in ckpts[num_ckpts:]:
236
  ckpt.unlink()
237
 
238
+ def load_model(self, weights_path, optimizer=None, device='cpu'):
239
  """Load model(s) from disk
240
  """
241
  weights_path = Path(weights_path)
 
246
  return
247
 
248
  logging.info(f"Loading weights from {weights_path}...")
249
+ state_dict = torch.load(weights_path, map_location=torch.device(device))
250
  if "version" in state_dict and state_dict["version"] == "1.0":
251
  new_dict = {}
252
  for k, v in state_dict["model"].items():
flash3d/unidepth/utils/geometric.py CHANGED
@@ -48,7 +48,6 @@ def generate_rays(
48
  return ray_directions, angles
49
 
50
 
51
- @torch.jit.script
52
  def spherical_zbuffer_to_euclidean(spherical_tensor: torch.Tensor) -> torch.Tensor:
53
  theta = spherical_tensor[..., 0] # Extract polar angle
54
  phi = spherical_tensor[..., 1] # Extract azimuthal angle
@@ -68,7 +67,6 @@ def spherical_zbuffer_to_euclidean(spherical_tensor: torch.Tensor) -> torch.Tens
68
  return euclidean_tensor
69
 
70
 
71
- @torch.jit.script
72
  def spherical_to_euclidean(spherical_tensor: torch.Tensor) -> torch.Tensor:
73
  theta = spherical_tensor[..., 0] # Extract polar angle
74
  phi = spherical_tensor[..., 1] # Extract azimuthal angle
@@ -84,7 +82,6 @@ def spherical_to_euclidean(spherical_tensor: torch.Tensor) -> torch.Tensor:
84
  return euclidean_tensor
85
 
86
 
87
- @torch.jit.script
88
  def euclidean_to_spherical(spherical_tensor: torch.Tensor) -> torch.Tensor:
89
  x = spherical_tensor[..., 0] # Extract polar angle
90
  y = spherical_tensor[..., 1] # Extract azimuthal angle
@@ -100,7 +97,6 @@ def euclidean_to_spherical(spherical_tensor: torch.Tensor) -> torch.Tensor:
100
  return euclidean_tensor
101
 
102
 
103
- @torch.jit.script
104
  def euclidean_to_spherical_zbuffer(euclidean_tensor: torch.Tensor) -> torch.Tensor:
105
  pitch = torch.asin(euclidean_tensor[..., 1])
106
  yaw = torch.atan2(euclidean_tensor[..., 0], euclidean_tensor[..., -1])
@@ -109,7 +105,6 @@ def euclidean_to_spherical_zbuffer(euclidean_tensor: torch.Tensor) -> torch.Tens
109
  return euclidean_tensor
110
 
111
 
112
- @torch.jit.script
113
  def unproject_points(
114
  depth: torch.Tensor, camera_intrinsics: torch.Tensor
115
  ) -> torch.Tensor:
@@ -152,7 +147,6 @@ def unproject_points(
152
  return unprojected_points
153
 
154
 
155
- @torch.jit.script
156
  def project_points(
157
  points_3d: torch.Tensor,
158
  intrinsic_matrix: torch.Tensor,
@@ -200,7 +194,6 @@ def project_points(
200
  return mean_depth_maps.reshape(-1, 1, *image_shape) # (B, 1, H, W)
201
 
202
 
203
- @torch.jit.script
204
  def downsample(data: torch.Tensor, downsample_factor: int = 2):
205
  N, _, H, W = data.shape
206
  data = data.view(
@@ -220,7 +213,6 @@ def downsample(data: torch.Tensor, downsample_factor: int = 2):
220
  return data
221
 
222
 
223
- @torch.jit.script
224
  def flat_interpolate(
225
  flat_tensor: torch.Tensor,
226
  old: Tuple[int, int],
 
48
  return ray_directions, angles
49
 
50
 
 
51
  def spherical_zbuffer_to_euclidean(spherical_tensor: torch.Tensor) -> torch.Tensor:
52
  theta = spherical_tensor[..., 0] # Extract polar angle
53
  phi = spherical_tensor[..., 1] # Extract azimuthal angle
 
67
  return euclidean_tensor
68
 
69
 
 
70
  def spherical_to_euclidean(spherical_tensor: torch.Tensor) -> torch.Tensor:
71
  theta = spherical_tensor[..., 0] # Extract polar angle
72
  phi = spherical_tensor[..., 1] # Extract azimuthal angle
 
82
  return euclidean_tensor
83
 
84
 
 
85
  def euclidean_to_spherical(spherical_tensor: torch.Tensor) -> torch.Tensor:
86
  x = spherical_tensor[..., 0] # Extract polar angle
87
  y = spherical_tensor[..., 1] # Extract azimuthal angle
 
97
  return euclidean_tensor
98
 
99
 
 
100
  def euclidean_to_spherical_zbuffer(euclidean_tensor: torch.Tensor) -> torch.Tensor:
101
  pitch = torch.asin(euclidean_tensor[..., 1])
102
  yaw = torch.atan2(euclidean_tensor[..., 0], euclidean_tensor[..., -1])
 
105
  return euclidean_tensor
106
 
107
 
 
108
  def unproject_points(
109
  depth: torch.Tensor, camera_intrinsics: torch.Tensor
110
  ) -> torch.Tensor:
 
147
  return unprojected_points
148
 
149
 
 
150
  def project_points(
151
  points_3d: torch.Tensor,
152
  intrinsic_matrix: torch.Tensor,
 
194
  return mean_depth_maps.reshape(-1, 1, *image_shape) # (B, 1, H, W)
195
 
196
 
 
197
  def downsample(data: torch.Tensor, downsample_factor: int = 2):
198
  N, _, H, W = data.shape
199
  data = data.view(
 
213
  return data
214
 
215
 
 
216
  def flat_interpolate(
217
  flat_tensor: torch.Tensor,
218
  old: Tuple[int, int],
pre-requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
  --extra-index-url https://download.pytorch.org/whl/cu118
2
- torch==2.2.2
3
  torchvision
4
  torchaudio
5
  xformers==0.0.25.post1
 
1
  --extra-index-url https://download.pytorch.org/whl/cu118
2
+ torch
3
  torchvision
4
  torchaudio
5
  xformers==0.0.25.post1
requirements.txt CHANGED
@@ -13,4 +13,5 @@ plyfile
13
  omegaconf
14
  jaxtyping
15
  gradio
16
- spaces
 
 
13
  omegaconf
14
  jaxtyping
15
  gradio
16
+ spaces
17
+ opencv-python