DawnC commited on
Commit
1959409
1 Parent(s): 43bd720

Update device_manager.py

Browse files
Files changed (1) hide show
  1. device_manager.py +25 -25
device_manager.py CHANGED
@@ -1,8 +1,6 @@
1
- from functools import wraps
2
  import torch
3
  import os
4
  import logging
5
- import spaces
6
 
7
  logging.basicConfig(level=logging.INFO)
8
  logger = logging.getLogger(__name__)
@@ -22,35 +20,37 @@ class DeviceManager:
22
 
23
  self._initialized = True
24
  self._current_device = None
 
 
 
25
  try:
26
  if os.environ.get('SPACE_ID'):
27
- # 使用 spaces 的 GPU wrapper 進行初始化
28
- @spaces.GPU
29
- def init_gpu():
30
- return torch.device('cuda')
31
- self._current_device = init_gpu()
32
- logger.info("ZeroGPU initialized successfully")
 
 
33
  else:
34
- self._current_device = torch.device('cpu')
35
  except Exception as e:
36
- logger.warning(f"Failed to initialize ZeroGPU: {e}")
37
  self._current_device = torch.device('cpu')
38
 
39
  def get_optimal_device(self):
 
 
40
  return self._current_device
41
 
42
- def device_handler(func):
43
- """Decorator for handling device placement with ZeroGPU support"""
44
- @spaces.GPU
45
- @wraps(func)
46
- async def wrapper(*args, **kwargs):
47
- try:
48
- return await func(*args, **kwargs)
49
- except RuntimeError as e:
50
- if "out of memory" in str(e) or "CUDA" in str(e):
51
- logger.warning("ZeroGPU unavailable, falling back to CPU")
52
- device_mgr = DeviceManager()
53
- device_mgr._current_device = torch.device('cpu')
54
- return await func(*args, **kwargs)
55
- raise e
56
- return wrapper
 
 
1
  import torch
2
  import os
3
  import logging
 
4
 
5
  logging.basicConfig(level=logging.INFO)
6
  logger = logging.getLogger(__name__)
 
20
 
21
  self._initialized = True
22
  self._current_device = None
23
+ self.initialize_device()
24
+
25
+ def initialize_device(self):
26
  try:
27
  if os.environ.get('SPACE_ID'):
28
+ # 嘗試初始化 CUDA 設備
29
+ if torch.cuda.is_available():
30
+ self._current_device = torch.device('cuda')
31
+ # 設置 CUDA 設備為可見
32
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
33
+ logger.info("CUDA device initialized successfully")
34
+ else:
35
+ raise RuntimeError("CUDA not available")
36
  else:
37
+ raise RuntimeError("Not in Spaces environment")
38
  except Exception as e:
39
+ logger.warning(f"Using CPU due to: {e}")
40
  self._current_device = torch.device('cpu')
41
 
42
  def get_optimal_device(self):
43
+ if self._current_device is None:
44
+ self.initialize_device()
45
  return self._current_device
46
 
47
+ def to_device(tensor_or_model, device=None):
48
+ """Helper function to move tensors or models to the appropriate device"""
49
+ if device is None:
50
+ device = DeviceManager().get_optimal_device()
51
+
52
+ try:
53
+ return tensor_or_model.to(device)
54
+ except Exception as e:
55
+ logger.warning(f"Failed to move to {device}, using CPU: {e}")
56
+ return tensor_or_model.to('cpu')