aiqtech commited on
Commit
4ec8a28
β€’
1 Parent(s): 0ad83b8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -10
app.py CHANGED
@@ -347,7 +347,7 @@ def extract_glb(state: dict, mesh_simplify: float, texture_size: int) -> Tuple[s
347
  # GLB λ³€ν™˜
348
  with torch.inference_mode():
349
  try:
350
- # λͺ¨λ“  ν…μ„œλ₯Ό CUDA둜 이동
351
  device = torch.device('cuda:0')
352
 
353
  # Gaussian ν…μ„œλ“€μ„ λ³€ν™˜
@@ -355,12 +355,13 @@ def extract_glb(state: dict, mesh_simplify: float, texture_size: int) -> Tuple[s
355
  if hasattr(gs, attr_name):
356
  tensor = getattr(gs, attr_name)
357
  if torch.is_tensor(tensor):
358
- new_tensor = tensor.detach().clone().float().to(device)
 
359
  setattr(gs, attr_name, new_tensor)
360
 
361
  # Mesh ν…μ„œλ“€μ„ λ³€ν™˜
362
  if hasattr(mesh, 'vertices') and torch.is_tensor(mesh.vertices):
363
- mesh.vertices = mesh.vertices.detach().clone().float().to(device)
364
  if hasattr(mesh, 'faces') and torch.is_tensor(mesh.faces):
365
  mesh.faces = mesh.faces.detach().clone().long().to(device)
366
 
@@ -370,11 +371,14 @@ def extract_glb(state: dict, mesh_simplify: float, texture_size: int) -> Tuple[s
370
  continue
371
  attr = getattr(mesh, attr_name)
372
  if torch.is_tensor(attr):
373
- setattr(mesh, attr_name, attr.to(device))
 
 
 
374
 
375
- print("Device check before GLB conversion:")
376
- print(f"Gaussian xyz device: {gs._xyz.device}")
377
- print(f"Mesh vertices device: {mesh.vertices.device}")
378
 
379
  # GLB λ³€ν™˜
380
  glb = postprocessing_utils.to_glb(
@@ -387,11 +391,11 @@ def extract_glb(state: dict, mesh_simplify: float, texture_size: int) -> Tuple[s
387
 
388
  except Exception as e:
389
  print(f"Error during GLB conversion: {str(e)}")
390
- # λ””λ°”μ΄μŠ€ 정보 좜λ ₯
391
  if hasattr(gs, '_xyz'):
392
- print(f"Gaussian xyz device: {gs._xyz.device}")
393
  if hasattr(mesh, 'vertices'):
394
- print(f"Mesh vertices device: {mesh.vertices.device}")
395
  return None, None
396
 
397
  if glb is None:
 
347
  # GLB λ³€ν™˜
348
  with torch.inference_mode():
349
  try:
350
+ # λͺ¨λ“  ν…μ„œλ₯Ό CUDA둜 μ΄λ™ν•˜κ³  gradient ν™œμ„±ν™”
351
  device = torch.device('cuda:0')
352
 
353
  # Gaussian ν…μ„œλ“€μ„ λ³€ν™˜
 
355
  if hasattr(gs, attr_name):
356
  tensor = getattr(gs, attr_name)
357
  if torch.is_tensor(tensor):
358
+ # gradient 계산이 ν•„μš”ν•œ ν…μ„œλ‘œ λ³€ν™˜
359
+ new_tensor = tensor.detach().clone().float().to(device).requires_grad_(True)
360
  setattr(gs, attr_name, new_tensor)
361
 
362
  # Mesh ν…μ„œλ“€μ„ λ³€ν™˜
363
  if hasattr(mesh, 'vertices') and torch.is_tensor(mesh.vertices):
364
+ mesh.vertices = mesh.vertices.detach().clone().float().to(device).requires_grad_(True)
365
  if hasattr(mesh, 'faces') and torch.is_tensor(mesh.faces):
366
  mesh.faces = mesh.faces.detach().clone().long().to(device)
367
 
 
371
  continue
372
  attr = getattr(mesh, attr_name)
373
  if torch.is_tensor(attr):
374
+ if attr.dtype in [torch.float32, torch.float64]:
375
+ setattr(mesh, attr_name, attr.to(device).requires_grad_(True))
376
+ else:
377
+ setattr(mesh, attr_name, attr.to(device))
378
 
379
+ print("Device and gradient check before GLB conversion:")
380
+ print(f"Gaussian xyz device: {gs._xyz.device}, requires_grad: {gs._xyz.requires_grad}")
381
+ print(f"Mesh vertices device: {mesh.vertices.device}, requires_grad: {mesh.vertices.requires_grad}")
382
 
383
  # GLB λ³€ν™˜
384
  glb = postprocessing_utils.to_glb(
 
391
 
392
  except Exception as e:
393
  print(f"Error during GLB conversion: {str(e)}")
394
+ # λ””λ°”μ΄μŠ€μ™€ gradient 정보 좜λ ₯
395
  if hasattr(gs, '_xyz'):
396
+ print(f"Gaussian xyz device: {gs._xyz.device}, requires_grad: {gs._xyz.requires_grad}")
397
  if hasattr(mesh, 'vertices'):
398
+ print(f"Mesh vertices device: {mesh.vertices.device}, requires_grad: {mesh.vertices.requires_grad}")
399
  return None, None
400
 
401
  if glb is None: