# Spaces: Running on Zero
import asyncio
import logging
import os
from functools import wraps

import torch
from huggingface_hub import HfApi
# Module-wide logging setup: ask the root logger for INFO-level output and
# create the named logger used throughout this module.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
class DeviceManager:
    """Process-wide singleton that chooses and caches the torch device to use.

    CUDA is selected when the Hugging Face Space this process runs in reports
    ZeroGPU hardware; otherwise CPU is used. The chosen device is cached in
    ``_current_device``, and a failed move to that device switches the cached
    device to CPU.
    """

    _instance = None

    def __new__(cls):
        # Singleton: every DeviceManager() call returns the same object.
        if cls._instance is None:
            cls._instance = super(DeviceManager, cls).__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every DeviceManager() call; only set up state once.
        if self._initialized:
            return
        self._initialized = True
        self._current_device = None      # torch.device picked by get_optimal_device()
        self._zero_gpu_available = None  # bool result of the last availability check

    @staticmethod
    def _hardware_is_zero_gpu(hardware):
        """Return True if a Space runtime ``hardware`` value denotes ZeroGPU.

        ``huggingface_hub`` reports ``SpaceRuntime.hardware`` as a string-valued
        ``SpaceHardware`` enum member (ZeroGPU flavors start with ``"zero"``,
        e.g. ``"zero-a10g"``) or ``None``. A mapping carrying a ``'zerogpu'``
        flag, which the previous implementation expected, is still accepted.
        """
        if hardware is None:
            return False
        if isinstance(hardware, str):
            # SpaceHardware subclasses str, so enum members are handled here too.
            return hardware.lower().startswith('zero')
        getter = getattr(hardware, 'get', None)
        if callable(getter):
            return bool(getter('zerogpu', False))
        return False

    def check_zero_gpu_availability(self):
        """Return True if this process runs in a Hugging Face Space on ZeroGPU.

        Returns False immediately when ``SPACE_ID`` is not set (not running in a
        Space). Otherwise queries the Hub for that Space's runtime. Any error
        during the query is logged and treated as "not available". The result
        is always stored in ``self._zero_gpu_available``.
        """
        available = False
        space_id = os.environ.get('SPACE_ID')
        if space_id is not None:
            try:
                space_info = HfApi().get_space_runtime(space_id)
                available = self._hardware_is_zero_gpu(
                    getattr(space_info, 'hardware', None))
            except Exception as e:
                # Best effort: a failed Hub query simply means "no ZeroGPU".
                logger.warning(f"Error checking ZeroGPU availability: {e}")
        self._zero_gpu_available = available
        return available

    def get_optimal_device(self):
        """Return the cached device, choosing it on the first call.

        In a ZeroGPU Space this sets the ``ZERO_GPU=1`` environment variable
        and selects ``cuda``; otherwise (or if that setup fails) it selects
        ``cpu``.
        """
        if self._current_device is None:
            if self.check_zero_gpu_availability():
                try:
                    # Flag the process as running in a ZeroGPU environment.
                    os.environ['ZERO_GPU'] = '1'
                    self._current_device = torch.device('cuda')
                    logger.info("Using ZeroGPU")
                except Exception as e:
                    logger.warning(f"Failed to initialize ZeroGPU: {e}")
                    self._current_device = torch.device('cpu')
                    logger.info("Fallback to CPU due to ZeroGPU initialization failure")
            else:
                self._current_device = torch.device('cpu')
                logger.info("Using CPU (ZeroGPU not available)")
        return self._current_device

    def move_to_device(self, tensor_or_model):
        """Move ``tensor_or_model`` to the optimal device if it has a ``to`` method.

        Objects without a ``to`` attribute are returned unchanged. If the move
        fails, the cached device is switched to CPU and the object is moved to
        CPU instead; an error during that second move propagates.
        """
        device = self.get_optimal_device()
        if not hasattr(tensor_or_model, 'to'):
            return tensor_or_model
        try:
            return tensor_or_model.to(device)
        except Exception as e:
            logger.warning(f"Failed to move tensor/model to {device}: {e}")
            self._current_device = torch.device('cpu')
            return tensor_or_model.to('cpu')
def device_handler(func):
    """Decorator for handling device placement with ZeroGPU support.

    The returned wrapper is always a coroutine function (callers must await
    it), whether ``func`` is sync or async. On each call it:

    * moves every positional/keyword argument that is a tensor or has a
      ``to`` attribute to the ``DeviceManager`` device before calling ``func``;
    * moves a tensor result, or the tensor elements of a tuple result, to that
      device as well;
    * on a ``RuntimeError`` whose message contains "out of memory" or "CUDA",
      switches the manager to CPU and retries once with the original
      arguments. If the retry fails too, the error is re-raised. Earlier
      versions retried by unbounded recursion, which ended in RecursionError.
    """
    is_async = asyncio.iscoroutinefunction(func)

    @wraps(func)
    async def wrapper(*args, **kwargs):
        device_mgr = DeviceManager()

        def process_arg(arg):
            if torch.is_tensor(arg) or hasattr(arg, 'to'):
                return device_mgr.move_to_device(arg)
            return arg

        def process_result(result):
            if torch.is_tensor(result):
                return device_mgr.move_to_device(result)
            if isinstance(result, tuple):
                return tuple(
                    device_mgr.move_to_device(r) if torch.is_tensor(r) else r
                    for r in result
                )
            return result

        fell_back_to_cpu = False
        while True:
            # Re-place the ORIGINAL arguments on every attempt so a CPU retry
            # does not reuse tensors already moved to the failing device.
            processed_args = [process_arg(arg) for arg in args]
            processed_kwargs = {k: process_arg(v) for k, v in kwargs.items()}
            try:
                if is_async:
                    result = await func(*processed_args, **processed_kwargs)
                else:
                    result = func(*processed_args, **processed_kwargs)
                return process_result(result)
            except RuntimeError as e:
                message = str(e)
                is_device_error = "out of memory" in message or "CUDA" in message
                if fell_back_to_cpu or not is_device_error:
                    raise
                logger.warning("ZeroGPU resources unavailable, falling back to CPU")
                device_mgr._current_device = torch.device('cpu')
                fell_back_to_cpu = True

    return wrapper