Spaces:
Build error
Build error
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Author: Alexandre Defossez (adefossez)
""" | |
Start multiple process locally for DDP. | |
""" | |
import logging | |
import subprocess as sp | |
import sys | |
from hydra import utils | |
# Logger for this module; its name is the importing module path.
logger = logging.getLogger(name=__name__)
class ChildrenManager:
    """Context manager that supervises a group of worker subprocesses.

    Workers (``subprocess.Popen`` objects) are registered with :meth:`add`
    inside the ``with`` block. On exit, the manager polls them until every
    one has finished. If any worker exits with a non-zero code, if an
    exception escaped the ``with`` block, or if a ``KeyboardInterrupt`` is
    received while polling, :attr:`failed` is set and every worker still
    running is terminated and then reaped. A worker that does not stop within
    ``terminate_timeout`` seconds is forcibly killed.

    Attributes:
        children: Popen handles still being tracked. Workers that exited
            while being polled are removed from this list.
        failed: True once any of the failure conditions above occurred.
        terminate_timeout: Seconds to wait for each terminated worker to
            exit before calling ``kill()`` on it.
    """

    def __init__(self, terminate_timeout=10.0):
        self.children = []
        self.failed = False
        self.terminate_timeout = terminate_timeout

    def add(self, child):
        """Track ``child`` and tag it with ``child.rank``.

        The rank is the child's registration index.
        """
        child.rank = len(self.children)
        self.children.append(child)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_value is not None:
            logger.error(
                "An exception happened while starting workers %r", exc_value)
            self.failed = True
        try:
            self._wait_for_children()
        except KeyboardInterrupt:
            logger.error(
                "Received keyboard interrupt, trying to kill all workers.")
            self.failed = True
        self._shutdown_children()
        if not self.failed:
            logger.info("All workers completed successfully")
        # Returning None: an exception raised in the ``with`` body propagates.

    def _wait_for_children(self):
        """Poll workers until all exit cleanly or any failure is recorded."""
        while self.children and not self.failed:
            # Iterate over a copy because finished workers are removed.
            for child in list(self.children):
                try:
                    exitcode = child.wait(0.1)
                except sp.TimeoutExpired:
                    continue
                self.children.remove(child)
                if exitcode:
                    # Non-zero, including negative codes (killed by a signal).
                    logger.error(
                        "Worker %s died, killing all workers", child.rank)
                    self.failed = True

    def _shutdown_children(self):
        """Terminate every still-tracked worker and reap it, killing stragglers."""
        # Signal all workers first so they shut down in parallel.
        for child in self.children:
            child.terminate()
        # Reap them so no worker outlives the launcher or lingers as a zombie.
        for child in self.children:
            try:
                child.wait(self.terminate_timeout)
            except sp.TimeoutExpired:
                logger.error(
                    "Worker %s did not stop after terminate(), killing it",
                    child.rank)
                child.kill()
                child.wait()
def start_ddp_workers():
    """Launch one worker process per visible CUDA device and supervise them.

    Each worker re-runs the current command line (``sys.argv``) under
    ``sys.executable``, with the extra arguments ``world_size=<N>`` and
    ``rank=<r>`` appended. It is started in Hydra's original working directory.

    Rank 0 inherits this process's stdin/stdout/stderr. Every other rank has
    them redirected to ``subprocess.DEVNULL``, and Hydra is told to write its
    file log to ``<current log filename>.<rank>`` instead.

    Exits the interpreter with status 1 when no CUDA device is visible.
    Otherwise it waits for the workers and exits with 1 if any worker failed,
    0 if all succeeded. An exception raised while starting workers propagates
    after the already-started workers have been handled by ChildrenManager.
    """
    import torch as th

    n_gpus = th.cuda.device_count()
    if not n_gpus:
        logger.error(
            "DDP is only available on GPU. Make sure GPUs are properly configured with cuda.")
        sys.exit(1)

    logger.info(f"Starting {n_gpus} worker processes for DDP.")
    with ChildrenManager() as manager:
        for rank in range(n_gpus):
            cmd = [sys.executable, *sys.argv,
                   f"world_size={n_gpus}", f"rank={rank}"]
            popen_kwargs = {}
            if rank:
                # Only rank 0 talks to the terminal; the others are silenced
                # and each gets its own Hydra log file suffixed by its rank.
                popen_kwargs.update(
                    stdin=sp.DEVNULL, stdout=sp.DEVNULL, stderr=sp.DEVNULL)
                rank_log = (
                    utils.HydraConfig().cfg.hydra.job_logging.handlers.file.filename
                    + f".{rank}")
                cmd.append("hydra.job_logging.handlers.file.filename=" + rank_log)
            manager.add(
                sp.Popen(cmd, cwd=utils.get_original_cwd(), **popen_kwargs))
    sys.exit(int(manager.failed))