DawnC committed
Commit c16c12e
1 Parent(s): 7a79c1c

Update device_manager.py

Files changed (1)
  1. device_manager.py +11 -54
device_manager.py CHANGED
@@ -27,20 +27,12 @@ class DeviceManager:
 
     def check_zero_gpu_availability(self):
         try:
-            # Check whether we are running in a Spaces environment
-            if 'SPACE_ID' not in os.environ:
-                return False
-
-            # Check whether this is a Pro user (ZeroGPU available)
-            api = HfApi()
-            space_info = api.get_space_runtime(os.environ['SPACE_ID'])
-
-            # Check whether ZeroGPU resources are available
-            if (hasattr(space_info, 'hardware') and
-                space_info.hardware.get('zerogpu', False)):
-                self._zero_gpu_available = True
-                return True
-
+            if 'SPACE_ID' in os.environ:
+                api = HfApi()
+                space_info = api.get_space_runtime(os.environ['SPACE_ID'])
+                if hasattr(space_info, 'hardware') and space_info.hardware.get('zerogpu', False):
+                    self._zero_gpu_available = True
+                    return True
         except Exception as e:
             logger.warning(f"Error checking ZeroGPU availability: {e}")
 
@@ -51,63 +43,28 @@ class DeviceManager:
         if self._current_device is None:
             if self.check_zero_gpu_availability():
                 try:
-                    # Explicitly flag that this is a ZeroGPU environment
-                    os.environ['ZERO_GPU'] = '1'
                     self._current_device = torch.device('cuda')
                     logger.info("Using ZeroGPU")
                 except Exception as e:
                     logger.warning(f"Failed to initialize ZeroGPU: {e}")
                     self._current_device = torch.device('cpu')
-                    logger.info("Fallback to CPU due to ZeroGPU initialization failure")
             else:
                 self._current_device = torch.device('cpu')
-                logger.info("Using CPU (ZeroGPU not available)")
+                logger.info("Using CPU")
         return self._current_device
-
-    def move_to_device(self, tensor_or_model):
-        device = self.get_optimal_device()
-        if hasattr(tensor_or_model, 'to'):
-            try:
-                return tensor_or_model.to(device)
-            except Exception as e:
-                logger.warning(f"Failed to move tensor/model to {device}: {e}")
-                self._current_device = torch.device('cpu')
-                return tensor_or_model.to('cpu')
-        return tensor_or_model
 
 def device_handler(func):
-    """Decorator for handling device placement with ZeroGPU support"""
+    """Simplified device handler"""
     @wraps(func)
     async def wrapper(*args, **kwargs):
         device_mgr = DeviceManager()
-
-        def process_arg(arg):
-            if torch.is_tensor(arg) or hasattr(arg, 'to'):
-                return device_mgr.move_to_device(arg)
-            return arg
-
-        processed_args = [process_arg(arg) for arg in args]
-        processed_kwargs = {k: process_arg(v) for k, v in kwargs.items()}
-
         try:
-            # If the wrapped function is a coroutine, await it
-            if asyncio.iscoroutinefunction(func):
-                result = await func(*processed_args, **processed_kwargs)
-            else:
-                result = func(*processed_args, **processed_kwargs)
-
-            # Post-process the output
-            if torch.is_tensor(result):
-                return device_mgr.move_to_device(result)
-            elif isinstance(result, tuple):
-                return tuple(device_mgr.move_to_device(r) if torch.is_tensor(r) else r for r in result)
+            result = await func(*args, **kwargs)
             return result
-
         except RuntimeError as e:
             if "out of memory" in str(e) or "CUDA" in str(e):
-                logger.warning("ZeroGPU resources unavailable, falling back to CPU")
+                logger.warning("ZeroGPU unavailable, falling back to CPU")
                 device_mgr._current_device = torch.device('cpu')
-                return await wrapper(*args, **kwargs)
+                return await func(*args, **kwargs)
             raise e
-
     return wrapper
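
For reference, a minimal usage sketch of the simplified decorator above, assuming device_manager.py exposes DeviceManager and device_handler at module level, that get_optimal_device is still the public accessor for the chosen device, and that DeviceManager behaves as a shared instance. The names run_inference and main are hypothetical and not part of this commit.

import asyncio

import torch

from device_manager import DeviceManager, device_handler


@device_handler
async def run_inference(batch: torch.Tensor) -> torch.Tensor:
    # Hypothetical async inference step; the decorator retries once on CPU
    # if a CUDA/ZeroGPU RuntimeError is raised inside this call.
    device = DeviceManager().get_optimal_device()
    # Placeholder for a real model forward pass.
    return (batch.to(device) * 2).cpu()


async def main() -> None:
    out = await run_inference(torch.zeros(1, 3, 224, 224))
    print(out.shape)


if __name__ == "__main__":
    asyncio.run(main())

Per the new except branch, a CUDA or out-of-memory RuntimeError pins _current_device to CPU and re-awaits the original function once instead of re-entering the wrapper recursively.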