Spaces:
Running
on
Zero
Running
on
Zero
from functools import wraps | |
import torch | |
import os | |
import logging | |
# This module's logger, plus root logging configured at INFO level so the
# informational device-selection messages are emitted.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class DeviceManager:
    """Process-wide singleton that selects and caches the torch device to use.

    On a Hugging Face Space (detected via the ``SPACE_ID`` environment
    variable) it prepares the environment for ZeroGPU and prefers CUDA when a
    quick allocation probe succeeds; in every other case it uses the CPU.
    """

    _instance = None

    def __new__(cls):
        # Singleton: every DeviceManager() call returns the same object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # __init__ runs on every DeviceManager() call; this flag makes the
            # real initialization happen only once.
            cls._instance._initialized = False
        return cls._instance

    def __init__(self) -> None:
        if self._initialized:
            return
        self._initialized = True
        # Resolved lazily by get_optimal_device() and cached afterwards.
        self._current_device = None
        self.initialize_zero_gpu()

    def initialize_zero_gpu(self) -> None:
        """Set the ZeroGPU environment variables when running on a HF Space.

        Does nothing when ``SPACE_ID`` is not set. Any failure is logged as a
        warning instead of being raised.
        """
        try:
            # Only touch the environment inside a Hugging Face Space.
            if os.environ.get('SPACE_ID'):
                os.environ['CUDA_VISIBLE_DEVICES'] = '0'
                # Flag read back by check_zero_gpu_availability().
                os.environ['ZERO_GPU'] = '1'
                logger.info("ZeroGPU environment initialized")
        except Exception as e:
            logger.warning("Failed to initialize ZeroGPU environment: %s", e)

    def check_zero_gpu_availability(self) -> bool:
        """Return True if ZeroGPU is enabled and the CUDA runtime initializes.

        Requires ``SPACE_ID`` to be set, ``ZERO_GPU == '1'`` and
        ``torch.cuda.is_available()``. Errors are logged and reported as False.
        """
        try:
            if os.environ.get('SPACE_ID') and os.environ.get('ZERO_GPU') == '1':
                if torch.cuda.is_available():
                    # Make sure the CUDA runtime is set up before it is used.
                    torch.cuda.init()
                    return True
        except Exception as e:
            logger.warning("ZeroGPU check failed: %s", e)
        return False

    def get_optimal_device(self) -> torch.device:
        """Return the best available device, resolving and caching it once.

        CUDA is chosen only when ZeroGPU is available AND a tiny allocation on
        the GPU succeeds; otherwise the CPU is chosen. The result is cached in
        ``self._current_device`` and returned on later calls.
        """
        if self._current_device is not None:
            return self._current_device

        if not self.check_zero_gpu_availability():
            self._current_device = torch.device('cpu')
            logger.info("Using CPU (ZeroGPU not available)")
            return self._current_device

        cuda = torch.device('cuda')
        try:
            # Probe BEFORE committing: the cached device and the log must never
            # claim a GPU that cannot actually be used.
            torch.zeros(1, device=cuda)
        except Exception as e:
            logger.warning("Failed to use ZeroGPU: %s", e)
            self._current_device = torch.device('cpu')
            logger.info("Fallback to CPU")
        else:
            self._current_device = cuda
            logger.info("Using ZeroGPU")
        return self._current_device

    def move_to_device(self, tensor_or_model):
        """Move a tensor or model to the optimal device.

        Args:
            tensor_or_model: Any object; moved only if it exposes ``.to(device)``
                (e.g. a tensor or a model).

        Returns:
            The result of ``.to(...)``, or the input unchanged when it has no
            ``.to`` attribute.

        Raises:
            Whatever ``.to`` raises when moving to the CPU fails. If moving to a
            non-CPU device fails, the manager permanently switches to the CPU
            and moves the object there instead.
        """
        device = self.get_optimal_device()
        if not hasattr(tensor_or_model, 'to'):
            return tensor_or_model
        try:
            return tensor_or_model.to(device)
        except Exception as e:
            if device.type == 'cpu':
                # Already on the fallback device: retrying the same move would
                # just fail again and the "falling back" warning would mislead.
                raise
            logger.warning("Failed to move to %s, falling back to CPU: %s", device, e)
            self._current_device = torch.device('cpu')
            return tensor_or_model.to('cpu')