DawnC committed on
Commit
58bf731
1 Parent(s): c16c12e

Update device_manager.py

Browse files
Files changed (1) hide show
  1. device_manager.py +3 -24
device_manager.py CHANGED
@@ -3,7 +3,6 @@ import torch
3
  from huggingface_hub import HfApi
4
  import os
5
  import logging
6
- import asyncio
7
 
8
  logging.basicConfig(level=logging.INFO)
9
  logger = logging.getLogger(__name__)
@@ -23,7 +22,6 @@ class DeviceManager:
23
 
24
  self._initialized = True
25
  self._current_device = None
26
- self._zero_gpu_available = None
27
 
28
  def check_zero_gpu_availability(self):
29
  try:
@@ -31,12 +29,9 @@ class DeviceManager:
31
  api = HfApi()
32
  space_info = api.get_space_runtime(os.environ['SPACE_ID'])
33
  if hasattr(space_info, 'hardware') and space_info.hardware.get('zerogpu', False):
34
- self._zero_gpu_available = True
35
  return True
36
  except Exception as e:
37
  logger.warning(f"Error checking ZeroGPU availability: {e}")
38
-
39
- self._zero_gpu_available = False
40
  return False
41
 
42
  def get_optimal_device(self):
@@ -45,26 +40,10 @@ class DeviceManager:
45
  try:
46
  self._current_device = torch.device('cuda')
47
  logger.info("Using ZeroGPU")
48
- except Exception as e:
49
- logger.warning(f"Failed to initialize ZeroGPU: {e}")
50
  self._current_device = torch.device('cpu')
 
51
  else:
52
  self._current_device = torch.device('cpu')
53
  logger.info("Using CPU")
54
- return self._current_device
55
-
56
- def device_handler(func):
57
- """簡化版的 device handler"""
58
- @wraps(func)
59
- async def wrapper(*args, **kwargs):
60
- device_mgr = DeviceManager()
61
- try:
62
- result = await func(*args, **kwargs)
63
- return result
64
- except RuntimeError as e:
65
- if "out of memory" in str(e) or "CUDA" in str(e):
66
- logger.warning("ZeroGPU unavailable, falling back to CPU")
67
- device_mgr._current_device = torch.device('cpu')
68
- return await func(*args, **kwargs)
69
- raise e
70
- return wrapper
 
3
  from huggingface_hub import HfApi
4
  import os
5
  import logging
 
6
 
7
  logging.basicConfig(level=logging.INFO)
8
  logger = logging.getLogger(__name__)
 
22
 
23
  self._initialized = True
24
  self._current_device = None
 
25
 
26
  def check_zero_gpu_availability(self):
27
  try:
 
29
  api = HfApi()
30
  space_info = api.get_space_runtime(os.environ['SPACE_ID'])
31
  if hasattr(space_info, 'hardware') and space_info.hardware.get('zerogpu', False):
 
32
  return True
33
  except Exception as e:
34
  logger.warning(f"Error checking ZeroGPU availability: {e}")
 
 
35
  return False
36
 
37
  def get_optimal_device(self):
 
40
  try:
41
  self._current_device = torch.device('cuda')
42
  logger.info("Using ZeroGPU")
43
+ except Exception:
 
44
  self._current_device = torch.device('cpu')
45
+ logger.info("Failed to use ZeroGPU, falling back to CPU")
46
  else:
47
  self._current_device = torch.device('cpu')
48
  logger.info("Using CPU")
49
+ return self._current_device