import os
import sys

# PYTORCH_ENABLE_MPS_FALLBACK must be in the environment *before* torch is
# imported: PyTorch reads it when registering the MPS fallback kernel at
# library load, so setting it after `import torch` has no effect.
# NOTE(review): this still cannot help if another module imported torch
# earlier in the same process.
if sys.platform == "darwin":
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch  # noqa: E402  (must come after the env setup above)

now_dir = os.getcwd()
sys.path.append(now_dir)

from .logger.log import get_logger  # noqa: E402

logger = get_logger("gpu")


def select_device(min_memory=2047, experimental=False) -> torch.device:
    """Choose the torch device to run on.

    Preference order: the CUDA GPU with the most estimated free memory, then
    Apple MPS (only when ``experimental`` is true), then CPU.

    Args:
        min_memory: Minimum estimated free memory, in MiB, that the best CUDA
            GPU must have; if it has less, CPU is returned instead.
        experimental: Allow the Apple MPS backend. Off by default because MPS
            is currently slower than CPU here while needing more memory.

    Returns:
        The selected ``torch.device``.
    """
    if torch.cuda.is_available():
        selected_gpu = 0
        max_free_memory = -1
        for i in range(torch.cuda.device_count()):
            props = torch.cuda.get_device_properties(i)
            # Rough estimate: only memory reserved by *this* process's caching
            # allocator is subtracted; usage by other processes is not seen.
            free_memory = props.total_memory - torch.cuda.memory_reserved(i)
            # Strict comparison keeps the lowest index on ties.
            if max_free_memory < free_memory:
                selected_gpu = i
                max_free_memory = free_memory
        free_memory_mb = max_free_memory / (1024 * 1024)
        if free_memory_mb < min_memory:
            logger.warning(
                f"GPU {selected_gpu} has {round(free_memory_mb, 2)} MB memory left. Switching to CPU."
            )
            return torch.device("cpu")
        return torch.device(f"cuda:{selected_gpu}")

    # Older torch builds have no `torch.backends.mps`; treat that as "no MPS".
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        # MPS is currently slower than CPU while needing more memory and core
        # utilisation, so it is only enabled for experimental use.
        if experimental:
            # Apple M1/M2 chips with Metal Performance Shaders.
            logger.warning("experimental: found apple GPU, using MPS.")
            return torch.device("mps")
        logger.info("found Apple GPU, but use CPU.")
        return torch.device("cpu")

    logger.warning("no GPU found, use CPU instead")
    return torch.device("cpu")