aiqtech commited on
Commit
59ca7f8
·
verified ·
1 Parent(s): ab95920

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -11
app.py CHANGED
@@ -38,7 +38,6 @@ def initialize_models(device):
38
  g.trellis_pipeline = TrellisImageTo3DPipeline.from_pretrained(
39
  "JeffreyXiang/TRELLIS-image-large"
40
  )
41
- g.trellis_pipeline.to(device)
42
 
43
  # 이미지 생성 파이프라인
44
  g.flux_pipe = FluxPipeline.from_pretrained(
@@ -69,7 +68,6 @@ torch.backends.cuda.matmul.allow_tf32 = True
69
  torch.backends.cudnn.benchmark = True
70
 
71
  # 환경 변수 설정
72
- os.environ["CUDA_VISIBLE_DEVICES"] = "0"
73
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
74
  os.environ['SPCONV_ALGO'] = 'native'
75
  os.environ['SPARSE_BACKEND'] = 'native'
@@ -422,19 +420,14 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
422
  )
423
 
424
  if __name__ == "__main__":
425
- # CUDA 사용 가능 여부 확인
426
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
427
- print(f"Using device: {device}")
428
-
429
  try:
 
 
 
 
430
  # 모델 초기화
431
  initialize_models(device)
432
 
433
- # CUDA 메모리 초기화
434
- if torch.cuda.is_available():
435
- torch.cuda.empty_cache()
436
- torch.cuda.synchronize()
437
-
438
  # 초기 이미지 전처리 테스트
439
  try:
440
  test_image = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))
 
38
  g.trellis_pipeline = TrellisImageTo3DPipeline.from_pretrained(
39
  "JeffreyXiang/TRELLIS-image-large"
40
  )
 
41
 
42
  # 이미지 생성 파이프라인
43
  g.flux_pipe = FluxPipeline.from_pretrained(
 
68
  torch.backends.cudnn.benchmark = True
69
 
70
  # 환경 변수 설정
 
71
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
72
  os.environ['SPCONV_ALGO'] = 'native'
73
  os.environ['SPARSE_BACKEND'] = 'native'
 
420
  )
421
 
422
  if __name__ == "__main__":
 
 
 
 
423
  try:
424
+ # CPU 모드로 초기화
425
+ device = "cpu"
426
+ print(f"Using device: {device}")
427
+
428
  # 모델 초기화
429
  initialize_models(device)
430
 
 
 
 
 
 
431
  # 초기 이미지 전처리 테스트
432
  try:
433
  test_image = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))