DawnC commited on
Commit
8e90922
1 Parent(s): 14ee6e4

Update device_manager.py

Browse files
Files changed (1) hide show
  1. device_manager.py +25 -49
device_manager.py CHANGED
@@ -2,6 +2,7 @@ from functools import wraps
2
  import torch
3
  import os
4
  import logging
 
5
 
6
  logging.basicConfig(level=logging.INFO)
7
  logger = logging.getLogger(__name__)
@@ -21,60 +22,35 @@ class DeviceManager:
21
 
22
  self._initialized = True
23
  self._current_device = None
24
- self.initialize_zero_gpu()
25
-
26
- def initialize_zero_gpu(self):
27
- """初始化 ZeroGPU"""
28
  try:
29
- # 檢查是否在 Hugging Face Spaces 環境中
30
  if os.environ.get('SPACE_ID'):
31
- # 嘗試初始化 ZeroGPU
32
- os.environ['CUDA_VISIBLE_DEVICES'] = '0'
33
- # 設置必要的環境變數
34
- os.environ['ZERO_GPU'] = '1'
35
- logger.info("ZeroGPU environment initialized")
36
- except Exception as e:
37
- logger.warning(f"Failed to initialize ZeroGPU environment: {e}")
38
-
39
- def check_zero_gpu_availability(self):
40
- """檢查 ZeroGPU 是否可用"""
41
- try:
42
- if os.environ.get('SPACE_ID') and os.environ.get('ZERO_GPU') == '1':
43
- # 確保 CUDA 運行時環境正確設置
44
- if torch.cuda.is_available():
45
- torch.cuda.init()
46
- return True
47
  except Exception as e:
48
- logger.warning(f"ZeroGPU check failed: {e}")
49
- return False
50
 
51
  def get_optimal_device(self):
52
- """獲取最佳可用設備"""
53
- if self._current_device is None:
54
- if self.check_zero_gpu_availability():
55
- try:
56
- self._current_device = torch.device('cuda')
57
- logger.info("Using ZeroGPU")
58
- # 嘗試進行一次小規模的 CUDA 操作來驗證
59
- torch.zeros(1).cuda()
60
- except Exception as e:
61
- logger.warning(f"Failed to use ZeroGPU: {e}")
62
- self._current_device = torch.device('cpu')
63
- logger.info("Fallback to CPU")
64
- else:
65
- self._current_device = torch.device('cpu')
66
- logger.info("Using CPU (ZeroGPU not available)")
67
  return self._current_device
68
 
69
- def move_to_device(self, tensor_or_model):
70
- """將張量或模型移動到最佳設備"""
71
- device = self.get_optimal_device()
 
 
72
  try:
73
- if hasattr(tensor_or_model, 'to'):
74
- return tensor_or_model.to(device)
75
- except Exception as e:
76
- logger.warning(f"Failed to move to {device}, falling back to CPU: {e}")
77
- self._current_device = torch.device('cpu')
78
- if hasattr(tensor_or_model, 'to'):
79
- return tensor_or_model.to('cpu')
80
- return tensor_or_model
 
 
2
  import torch
3
  import os
4
  import logging
5
+ import spaces
6
 
7
  logging.basicConfig(level=logging.INFO)
8
  logger = logging.getLogger(__name__)
 
22
 
23
  self._initialized = True
24
  self._current_device = None
 
 
 
 
25
  try:
 
26
  if os.environ.get('SPACE_ID'):
27
+ # 使用 spaces 的 GPU wrapper 進行初始化
28
+ @spaces.GPU
29
+ def init_gpu():
30
+ return torch.device('cuda')
31
+ self._current_device = init_gpu()
32
+ logger.info("ZeroGPU initialized successfully")
33
+ else:
34
+ self._current_device = torch.device('cpu')
 
 
 
 
 
 
 
 
35
  except Exception as e:
36
+ logger.warning(f"Failed to initialize ZeroGPU: {e}")
37
+ self._current_device = torch.device('cpu')
38
 
39
  def get_optimal_device(self):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  return self._current_device
41
 
42
+ def device_handler(func):
43
+ """Decorator for handling device placement with ZeroGPU support"""
44
+ @spaces.GPU
45
+ @wraps(func)
46
+ async def wrapper(*args, **kwargs):
47
  try:
48
+ return await func(*args, **kwargs)
49
+ except RuntimeError as e:
50
+ if "out of memory" in str(e) or "CUDA" in str(e):
51
+ logger.warning("ZeroGPU unavailable, falling back to CPU")
52
+ device_mgr = DeviceManager()
53
+ device_mgr._current_device = torch.device('cpu')
54
+ return await func(*args, **kwargs)
55
+ raise e
56
+ return wrapper