Spaces:
Running
on
Zero
Running
on
Zero
Update device_manager.py
Browse files- device_manager.py +11 -54
device_manager.py
CHANGED
@@ -27,20 +27,12 @@ class DeviceManager:
|
|
27 |
|
28 |
def check_zero_gpu_availability(self):
|
29 |
try:
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
space_info = api.get_space_runtime(os.environ['SPACE_ID'])
|
37 |
-
|
38 |
-
# 檢查是否有 ZeroGPU 資源
|
39 |
-
if (hasattr(space_info, 'hardware') and
|
40 |
-
space_info.hardware.get('zerogpu', False)):
|
41 |
-
self._zero_gpu_available = True
|
42 |
-
return True
|
43 |
-
|
44 |
except Exception as e:
|
45 |
logger.warning(f"Error checking ZeroGPU availability: {e}")
|
46 |
|
@@ -51,63 +43,28 @@ class DeviceManager:
|
|
51 |
if self._current_device is None:
|
52 |
if self.check_zero_gpu_availability():
|
53 |
try:
|
54 |
-
# 特別標記這是 ZeroGPU 環境
|
55 |
-
os.environ['ZERO_GPU'] = '1'
|
56 |
self._current_device = torch.device('cuda')
|
57 |
logger.info("Using ZeroGPU")
|
58 |
except Exception as e:
|
59 |
logger.warning(f"Failed to initialize ZeroGPU: {e}")
|
60 |
self._current_device = torch.device('cpu')
|
61 |
-
logger.info("Fallback to CPU due to ZeroGPU initialization failure")
|
62 |
else:
|
63 |
self._current_device = torch.device('cpu')
|
64 |
-
logger.info("Using CPU")
|
65 |
return self._current_device
|
66 |
-
|
67 |
-
def move_to_device(self, tensor_or_model):
|
68 |
-
device = self.get_optimal_device()
|
69 |
-
if hasattr(tensor_or_model, 'to'):
|
70 |
-
try:
|
71 |
-
return tensor_or_model.to(device)
|
72 |
-
except Exception as e:
|
73 |
-
logger.warning(f"Failed to move tensor/model to {device}: {e}")
|
74 |
-
self._current_device = torch.device('cpu')
|
75 |
-
return tensor_or_model.to('cpu')
|
76 |
-
return tensor_or_model
|
77 |
|
78 |
def device_handler(func):
|
79 |
-
"""
|
80 |
@wraps(func)
|
81 |
async def wrapper(*args, **kwargs):
|
82 |
device_mgr = DeviceManager()
|
83 |
-
|
84 |
-
def process_arg(arg):
|
85 |
-
if torch.is_tensor(arg) or hasattr(arg, 'to'):
|
86 |
-
return device_mgr.move_to_device(arg)
|
87 |
-
return arg
|
88 |
-
|
89 |
-
processed_args = [process_arg(arg) for arg in args]
|
90 |
-
processed_kwargs = {k: process_arg(v) for k, v in kwargs.items()}
|
91 |
-
|
92 |
try:
|
93 |
-
|
94 |
-
if asyncio.iscoroutinefunction(func):
|
95 |
-
result = await func(*processed_args, **processed_kwargs)
|
96 |
-
else:
|
97 |
-
result = func(*processed_args, **processed_kwargs)
|
98 |
-
|
99 |
-
# 處理輸出
|
100 |
-
if torch.is_tensor(result):
|
101 |
-
return device_mgr.move_to_device(result)
|
102 |
-
elif isinstance(result, tuple):
|
103 |
-
return tuple(device_mgr.move_to_device(r) if torch.is_tensor(r) else r for r in result)
|
104 |
return result
|
105 |
-
|
106 |
except RuntimeError as e:
|
107 |
if "out of memory" in str(e) or "CUDA" in str(e):
|
108 |
-
logger.warning("ZeroGPU unavailable, falling back to CPU")
|
109 |
device_mgr._current_device = torch.device('cpu')
|
110 |
-
return await func(*processed_args, **processed_kwargs)
|
111 |
raise e
|
112 |
-
|
113 |
return wrapper
|
|
|
27 |
|
28 |
def check_zero_gpu_availability(self):
|
29 |
try:
|
30 |
+
if 'SPACE_ID' in os.environ:
|
31 |
+
api = HfApi()
|
32 |
+
space_info = api.get_space_runtime(os.environ['SPACE_ID'])
|
33 |
+
if hasattr(space_info, 'hardware') and space_info.hardware.get('zerogpu', False):
|
34 |
+
self._zero_gpu_available = True
|
35 |
+
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
except Exception as e:
|
37 |
logger.warning(f"Error checking ZeroGPU availability: {e}")
|
38 |
|
|
|
43 |
if self._current_device is None:
|
44 |
if self.check_zero_gpu_availability():
|
45 |
try:
|
|
|
|
|
46 |
self._current_device = torch.device('cuda')
|
47 |
logger.info("Using ZeroGPU")
|
48 |
except Exception as e:
|
49 |
logger.warning(f"Failed to initialize ZeroGPU: {e}")
|
50 |
self._current_device = torch.device('cpu')
|
|
|
51 |
else:
|
52 |
self._current_device = torch.device('cpu')
|
53 |
+
logger.info("Using CPU")
|
54 |
return self._current_device
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
def device_handler(func):
|
57 |
+
"""簡化版的 device handler"""
|
58 |
@wraps(func)
|
59 |
async def wrapper(*args, **kwargs):
|
60 |
device_mgr = DeviceManager()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
try:
|
62 |
+
result = await func(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
return result
|
|
|
64 |
except RuntimeError as e:
|
65 |
if "out of memory" in str(e) or "CUDA" in str(e):
|
66 |
+
logger.warning("ZeroGPU unavailable, falling back to CPU")
|
67 |
device_mgr._current_device = torch.device('cpu')
|
68 |
+
return await func(*args, **kwargs)
|
69 |
raise e
|
|
|
70 |
return wrapper
|