PawMatchAI / device_manager.py
DawnC's picture
Update device_manager.py
8e90922 verified
raw
history blame
1.81 kB
from functools import wraps
import torch
import os
import logging
import spaces
# Configure the root logger at INFO level as a side effect of importing this module.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module, used by the code below.
logger = logging.getLogger(__name__)
class DeviceManager:
    """Singleton that records which torch device this process should use.

    Every ``DeviceManager()`` call yields the same object, and device
    selection runs only on the first construction. When the ``SPACE_ID``
    environment variable is set, a ``spaces.GPU``-wrapped probe is called to
    obtain a CUDA device; otherwise, or if the probe raises, the CPU device
    is recorded instead.
    """

    # The single shared instance; stays None until the first construction.
    _instance = None

    def __new__(cls):
        """Return the shared instance, allocating it on first use."""
        shared = cls._instance
        if shared is None:
            shared = super().__new__(cls)
            # Flag read by __init__ so that selection happens only once.
            shared._initialized = False
            cls._instance = shared
        return shared

    def __init__(self):
        """Choose the device on first construction; later calls are no-ops."""
        if self._initialized:
            return

        self._initialized = True
        self._current_device = None

        try:
            if not os.environ.get('SPACE_ID'):
                # SPACE_ID is unset or empty: use the CPU.
                self._current_device = torch.device('cpu')
            else:
                # Initialise through the GPU wrapper provided by `spaces`.
                @spaces.GPU
                def init_gpu():
                    return torch.device('cuda')

                self._current_device = init_gpu()
                logger.info("ZeroGPU initialized successfully")
        except Exception as e:
            # Any failure during selection degrades to the CPU device.
            logger.warning(f"Failed to initialize ZeroGPU: {e}")
            self._current_device = torch.device('cpu')

    def get_optimal_device(self):
        """Return the device currently recorded on the manager."""
        return self._current_device
def device_handler(func):
    """Wrap an async callable in ``spaces.GPU`` with a one-shot CPU retry.

    The returned coroutine function awaits ``func``. If that raises a
    ``RuntimeError`` whose message contains ``"out of memory"`` or
    ``"CUDA"``, the shared ``DeviceManager`` is switched to the CPU device
    and ``func`` is awaited once more with the same arguments.

    Args:
        func: Coroutine function to wrap; the value it returns is awaited.

    Returns:
        An async wrapper that carries ``func``'s metadata through
        ``functools.wraps`` and is itself decorated with ``spaces.GPU``.

    Raises:
        RuntimeError: Re-raised unchanged when the message matches neither
            marker; the retry may also raise, and that error propagates.
    """
    @spaces.GPU
    @wraps(func)
    async def wrapper(*args, **kwargs):
        try:
            return await func(*args, **kwargs)
        except RuntimeError as e:
            message = str(e)
            if "out of memory" not in message and "CUDA" not in message:
                # Not a GPU-related failure: a bare raise propagates it with
                # the original traceback, without adding a re-raise frame.
                raise
            logger.warning("ZeroGPU unavailable, falling back to CPU")
            # The retry only helps if func reads the device from
            # DeviceManager on each call — presumably it does; verify callers.
            DeviceManager()._current_device = torch.device('cpu')
            return await func(*args, **kwargs)
    return wrapper