DawnC committed on
Commit
818a6a6
1 Parent(s): 5499a5d

Create device_manager.py

Browse files
Files changed (1) hide show
  1. device_manager.py +108 -0
device_manager.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import wraps
2
+ import torch
3
+ from huggingface_hub import HfApi
4
+ import os
5
+ import logging
6
+
7
# Configure root logging once at import time; module-level logger for this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
9
+
10
class DeviceManager:
    """Singleton that selects and caches the optimal torch device.

    Prefers HuggingFace ZeroGPU when the hosting Space provides it and
    falls back to CPU otherwise. All state is process-wide because the
    class is a singleton.
    """

    _instance = None

    def __new__(cls):
        # Classic singleton: every DeviceManager() call returns the same
        # instance; the _initialized flag lets __init__ run only once.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ executes on every DeviceManager() call; guard so the
        # cached device/availability state survives repeated construction.
        if self._initialized:
            return

        self._initialized = True
        self._current_device = None        # torch.device, resolved lazily
        self._zero_gpu_available = None    # tri-state: None = not checked yet

    def check_zero_gpu_availability(self):
        """Return True if this process runs in a HF Space with ZeroGPU.

        The result is cached after the first check. (Fix: previously the
        cached value was written but never read, so every call re-queried
        the Hub API over the network.)
        """
        if self._zero_gpu_available is not None:
            return self._zero_gpu_available

        try:
            # Outside a Spaces environment there is nothing to check.
            if 'SPACE_ID' not in os.environ:
                self._zero_gpu_available = False
                return False

            # Ask the Hub whether this Space exposes ZeroGPU hardware
            # (available to Pro users).
            api = HfApi()
            space_info = api.get_space_runtime(os.environ['SPACE_ID'])

            # NOTE(review): assumes `space_info.hardware` behaves like a
            # mapping with a 'zerogpu' key — confirm against the
            # huggingface_hub version in use.
            if (hasattr(space_info, 'hardware') and
                    space_info.hardware.get('zerogpu', False)):
                self._zero_gpu_available = True
                return True

        except Exception as e:
            # Best effort: any failure (network, auth, API shape) simply
            # means we fall back to CPU.
            logger.warning("Error checking ZeroGPU availability: %s", e)

        self._zero_gpu_available = False
        return False

    def get_optimal_device(self):
        """Resolve (once) and return the torch.device this process uses."""
        if self._current_device is None:
            if self.check_zero_gpu_availability():
                try:
                    # Mark the process so downstream code can detect the
                    # ZeroGPU environment.
                    os.environ['ZERO_GPU'] = '1'
                    self._current_device = torch.device('cuda')
                    logger.info("Using ZeroGPU")
                except Exception as e:
                    logger.warning("Failed to initialize ZeroGPU: %s", e)
                    self._current_device = torch.device('cpu')
                    logger.info("Fallback to CPU due to ZeroGPU initialization failure")
            else:
                self._current_device = torch.device('cpu')
                logger.info("Using CPU (ZeroGPU not available)")
        return self._current_device

    def move_to_device(self, tensor_or_model):
        """Move a tensor/model to the optimal device.

        Objects without a ``.to`` method pass through unchanged. On a
        failed move the manager permanently falls back to CPU so later
        calls go straight there.
        """
        device = self.get_optimal_device()
        if hasattr(tensor_or_model, 'to'):
            try:
                return tensor_or_model.to(device)
            except Exception as e:
                logger.warning("Failed to move tensor/model to %s: %s", device, e)
                self._current_device = torch.device('cpu')
                return tensor_or_model.to('cpu')
        return tensor_or_model
76
+
77
def device_handler(func):
    """Decorator for handling device placement with ZeroGPU support.

    Wraps an *async* function: moves tensor-like arguments to the
    optimal device before awaiting *func*, routes tensor results back
    through the device manager, and retries once on CPU when a
    CUDA/out-of-memory RuntimeError occurs.
    """
    @wraps(func)
    async def wrapper(*args, **kwargs):
        device_mgr = DeviceManager()

        def process_arg(arg):
            # Anything tensor-like (a tensor, or any object exposing .to)
            # is moved; everything else passes through untouched.
            if torch.is_tensor(arg) or hasattr(arg, 'to'):
                return device_mgr.move_to_device(arg)
            return arg

        processed_args = [process_arg(arg) for arg in args]
        processed_kwargs = {k: process_arg(v) for k, v in kwargs.items()}

        try:
            result = await func(*processed_args, **processed_kwargs)

            # Route tensor outputs (including tuples containing tensors)
            # back through the device manager.
            if torch.is_tensor(result):
                return device_mgr.move_to_device(result)
            elif isinstance(result, tuple):
                return tuple(
                    device_mgr.move_to_device(r) if torch.is_tensor(r) else r
                    for r in result
                )
            return result

        except RuntimeError as e:
            # Fix: the original retried by calling wrapper() again whenever
            # the message matched, which recurses forever if the error
            # persists on CPU. Retry at most once: only when we were not
            # already on CPU.
            already_on_cpu = device_mgr._current_device == torch.device('cpu')
            if (("out of memory" in str(e) or "CUDA" in str(e))
                    and not already_on_cpu):
                logger.warning("ZeroGPU resources unavailable, falling back to CPU")
                device_mgr._current_device = torch.device('cpu')
                return await wrapper(*args, **kwargs)
            # Bare `raise` preserves the original traceback (vs `raise e`).
            raise

    return wrapper