aiqtech commited on
Commit
b31b828
β€’
1 Parent(s): 4ec8a28

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -347,7 +347,7 @@ def extract_glb(state: dict, mesh_simplify: float, texture_size: int) -> Tuple[s
347
  # GLB λ³€ν™˜
348
  with torch.inference_mode():
349
  try:
350
- # λͺ¨λ“  ν…μ„œλ₯Ό CUDA둜 μ΄λ™ν•˜κ³  gradient ν™œμ„±ν™”
351
  device = torch.device('cuda:0')
352
 
353
  # Gaussian ν…μ„œλ“€μ„ λ³€ν™˜
@@ -355,24 +355,23 @@ def extract_glb(state: dict, mesh_simplify: float, texture_size: int) -> Tuple[s
355
  if hasattr(gs, attr_name):
356
  tensor = getattr(gs, attr_name)
357
  if torch.is_tensor(tensor):
358
- # gradient 계산이 ν•„μš”ν•œ ν…μ„œλ‘œ λ³€ν™˜
359
- new_tensor = tensor.detach().clone().float().to(device).requires_grad_(True)
360
  setattr(gs, attr_name, new_tensor)
361
 
362
  # Mesh ν…μ„œλ“€μ„ λ³€ν™˜
363
  if hasattr(mesh, 'vertices') and torch.is_tensor(mesh.vertices):
364
- mesh.vertices = mesh.vertices.detach().clone().float().to(device).requires_grad_(True)
365
  if hasattr(mesh, 'faces') and torch.is_tensor(mesh.faces):
366
  mesh.faces = mesh.faces.detach().clone().long().to(device)
367
 
368
- # μΆ”κ°€ 속성 확인 및 λ³€ν™˜
369
  for attr_name in dir(mesh):
370
  if attr_name.startswith('_'):
371
  continue
372
  attr = getattr(mesh, attr_name)
373
  if torch.is_tensor(attr):
374
  if attr.dtype in [torch.float32, torch.float64]:
375
- setattr(mesh, attr_name, attr.to(device).requires_grad_(True))
376
  else:
377
  setattr(mesh, attr_name, attr.to(device))
378
 
@@ -425,6 +424,7 @@ def extract_glb(state: dict, mesh_simplify: float, texture_size: int) -> Tuple[s
425
  clear_gpu_memory()
426
 
427
 
 
428
  def activate_button() -> gr.Button:
429
  return gr.Button(interactive=True)
430
 
 
347
  # GLB λ³€ν™˜
348
  with torch.inference_mode():
349
  try:
350
+ # λͺ¨λ“  ν…μ„œλ₯Ό CUDA둜 이동 (gradient λΆˆν•„μš”)
351
  device = torch.device('cuda:0')
352
 
353
  # Gaussian ν…μ„œλ“€μ„ λ³€ν™˜
 
355
  if hasattr(gs, attr_name):
356
  tensor = getattr(gs, attr_name)
357
  if torch.is_tensor(tensor):
358
+ new_tensor = tensor.detach().clone().float().to(device)
 
359
  setattr(gs, attr_name, new_tensor)
360
 
361
  # Mesh ν…μ„œλ“€μ„ λ³€ν™˜
362
  if hasattr(mesh, 'vertices') and torch.is_tensor(mesh.vertices):
363
+ mesh.vertices = mesh.vertices.detach().clone().float().to(device)
364
  if hasattr(mesh, 'faces') and torch.is_tensor(mesh.faces):
365
  mesh.faces = mesh.faces.detach().clone().long().to(device)
366
 
367
+ # μΆ”κ°€ 속성 확인 및 λ³€ν™˜ (gradient λΆˆν•„μš”)
368
  for attr_name in dir(mesh):
369
  if attr_name.startswith('_'):
370
  continue
371
  attr = getattr(mesh, attr_name)
372
  if torch.is_tensor(attr):
373
  if attr.dtype in [torch.float32, torch.float64]:
374
+ setattr(mesh, attr_name, attr.to(device))
375
  else:
376
  setattr(mesh, attr_name, attr.to(device))
377
 
 
424
  clear_gpu_memory()
425
 
426
 
427
+
428
  def activate_button() -> gr.Button:
429
  return gr.Button(interactive=True)
430