DawnC commited on
Commit
62ed182
1 Parent(s): 6e2b666

Create device_manager.py

Browse files
Files changed (1) hide show
  1. device_manager.py +89 -0
device_manager.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import wraps
2
+ import torch
3
+ from huggingface_hub import HfApi
4
+ import os
5
+ import logging
6
+
7
+ logging.basicConfig(level=logging.INFO)
8
+ logger = logging.getLogger(__name__)
9
+
10
+ class DeviceManager:
11
+ _instance = None
12
+
13
+ def __new__(cls):
14
+ if cls._instance is None:
15
+ cls._instance = super(DeviceManager, cls).__new__(cls)
16
+ cls._instance._initialized = False
17
+ return cls._instance
18
+
19
+ def __init__(self):
20
+ if self._initialized:
21
+ return
22
+
23
+ self._initialized = True
24
+ self._current_device = None
25
+ self._zero_gpu_available = None
26
+
27
+ def check_zero_gpu_availability(self):
28
+ try:
29
+ api = HfApi()
30
+ # 檢查環境變數或其他方式確認是否在 Spaces 環境
31
+ if 'SPACE_ID' in os.environ:
32
+ # 這裡可以添加更多具體的 ZeroGPU 可用性檢查
33
+ self._zero_gpu_available = True
34
+ return True
35
+ except Exception as e:
36
+ logger.warning(f"Error checking ZeroGPU availability: {e}")
37
+
38
+ self._zero_gpu_available = False
39
+ return False
40
+
41
+ def get_optimal_device(self):
42
+ if self._current_device is None:
43
+ if self.check_zero_gpu_availability():
44
+ self._current_device = torch.device('cuda')
45
+ logger.info("Using ZeroGPU")
46
+ else:
47
+ self._current_device = torch.device('cpu')
48
+ logger.info("Using CPU")
49
+ return self._current_device
50
+
51
+ def move_to_device(self, tensor_or_model):
52
+ device = self.get_optimal_device()
53
+ if hasattr(tensor_or_model, 'to'):
54
+ return tensor_or_model.to(device)
55
+ return tensor_or_model
56
+
57
+ def device_handler(func):
58
+ """Decorator for handling device placement"""
59
+ @wraps(func)
60
+ async def wrapper(*args, **kwargs):
61
+ device_mgr = DeviceManager()
62
+
63
+ # 處理輸入參數的設備轉換
64
+ def process_arg(arg):
65
+ if torch.is_tensor(arg) or hasattr(arg, 'to'):
66
+ return device_mgr.move_to_device(arg)
67
+ return arg
68
+
69
+ processed_args = [process_arg(arg) for arg in args]
70
+ processed_kwargs = {k: process_arg(v) for k, v in kwargs.items()}
71
+
72
+ try:
73
+ result = await func(*processed_args, **processed_kwargs)
74
+
75
+ # 處理輸出結果的設備轉換
76
+ if torch.is_tensor(result):
77
+ return device_mgr.move_to_device(result)
78
+ elif isinstance(result, tuple):
79
+ return tuple(device_mgr.move_to_device(r) if torch.is_tensor(r) else r for r in result)
80
+ return result
81
+
82
+ except RuntimeError as e:
83
+ if "out of memory" in str(e):
84
+ logger.warning("GPU memory exceeded, falling back to CPU")
85
+ device_mgr._current_device = torch.device('cpu')
86
+ return await wrapper(*args, **kwargs)
87
+ raise e
88
+
89
+ return wrapper