File size: 2,808 Bytes
8235b4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Author: Alexandre Defossez (adefossez)

"""
Start multiple process locally for DDP.
"""

import logging
import subprocess as sp
import sys

from hydra import utils

logger = logging.getLogger(__name__)


class ChildrenManager:
    """Supervise a set of worker processes until they all finish.

    Use as a context manager: spawn the workers inside the ``with`` body and
    register each one with :meth:`add`. On exit, the manager polls the
    workers until every one of them has exited. As soon as one worker exits
    with a non-zero code, an exception escapes the ``with`` body, or a
    ``KeyboardInterrupt`` arrives, :attr:`failed` is set. Every worker that
    is still running is then terminated. It is waited for, and killed if it
    has not exited within ``terminate_timeout`` seconds.

    Exceptions raised inside the ``with`` body are not suppressed.
    """

    def __init__(self, terminate_timeout=10.):
        # Registered processes that have not been seen exiting yet.
        self.children = []
        # Set once a worker failed or supervision was aborted.
        self.failed = False
        # Seconds to wait for a terminated worker before killing it.
        self.terminate_timeout = terminate_timeout

    def add(self, child):
        """Register ``child`` (a ``subprocess.Popen``).

        The child is tagged with a ``rank`` attribute equal to the number of
        children registered before it.
        """
        child.rank = len(self.children)
        self.children.append(child)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Wait for the workers, tear down survivors on failure, never suppress."""
        if exc_value is not None:
            logger.error(
                "An exception happened while starting workers %r", exc_value)
            self.failed = True
        try:
            self._wait_for_children()
        except KeyboardInterrupt:
            logger.error(
                "Received keyboard interrupt, trying to kill all workers.")
            self.failed = True
        self._shutdown_children()
        if not self.failed:
            logger.info("All workers completed successfully")

    def _wait_for_children(self):
        """Poll the workers until all have exited or one of them failed."""
        while self.children and not self.failed:
            for child in list(self.children):
                # Short per-child timeout so that a failure of any rank is
                # noticed quickly, not only once earlier ranks have exited.
                try:
                    exitcode = child.wait(0.1)
                except sp.TimeoutExpired:
                    continue
                self.children.remove(child)
                if exitcode:
                    logger.error(
                        "Worker %d died, killing all workers", child.rank)
                    self.failed = True

    def _shutdown_children(self):
        """Terminate and reap every worker still registered, killing stragglers."""
        # Signal all first so the workers shut down concurrently, then reap.
        for child in self.children:
            child.terminate()
        for child in self.children:
            try:
                child.wait(self.terminate_timeout)
            except sp.TimeoutExpired:
                logger.error(
                    "Worker %d did not exit after being terminated, killing it",
                    child.rank)
                child.kill()
                child.wait()


def start_ddp_workers():
    """Relaunch the current command once per CUDA device and supervise the workers.

    The number of workers is ``torch.cuda.device_count()``. If no device is
    visible, an error is logged and the process exits with status 1.

    The worker for ``rank`` runs ``sys.executable`` followed by the original
    ``sys.argv`` plus ``world_size=<n>`` and ``rank=<rank>`` arguments. Its
    cwd is Hydra's original working directory. Every worker except rank 0:

    - has its stdin/stdout/stderr set to ``subprocess.DEVNULL``;
    - writes to its own log file, the configured Hydra job-logging filename
      with ``.<rank>`` appended.

    Supervision is delegated to ``ChildrenManager``. This function does not
    return normally: it calls ``sys.exit`` with 1 if any worker failed and 0
    otherwise. An exception raised while spawning propagates instead.
    """
    import torch as th

    n_gpus = th.cuda.device_count()
    if n_gpus == 0:
        logger.error(
            "DDP is only available on GPU. Make sure GPUs are properly configured with cuda.")
        sys.exit(1)
    logger.info(f"Starting {n_gpus} worker processes for DDP.")

    def launch(rank):
        # Same interpreter and command line, with rank / world size appended.
        command = [sys.executable, *sys.argv,
                   f"world_size={n_gpus}", f"rank={rank}"]
        popen_kwargs = {}
        if rank != 0:
            # Only rank 0 keeps the terminal; the others are silenced and
            # redirected to a per-rank log file.
            popen_kwargs.update(stdin=sp.DEVNULL,
                                stdout=sp.DEVNULL,
                                stderr=sp.DEVNULL)
            base_log = utils.HydraConfig().cfg.hydra.job_logging.handlers.file.filename
            per_rank_log = base_log + f".{rank}"
            command.append("hydra.job_logging.handlers.file.filename=" + per_rank_log)
        return sp.Popen(command, cwd=utils.get_original_cwd(), **popen_kwargs)

    with ChildrenManager() as manager:
        for rank in range(n_gpus):
            manager.add(launch(rank))
    sys.exit(int(manager.failed))