File size: 6,453 Bytes
e0f66ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Entry point for dora to launch solvers for running training loops.
See more info on how to use dora: https://github.com/facebookresearch/dora
"""

import logging
import multiprocessing
import os
import sys
import typing as tp

from dora import git_save, hydra_main, XP
import flashy
import hydra
import omegaconf

from .environment import AudioCraftEnvironment
from .utils.cluster import get_slurm_parameters

logger = logging.getLogger(__name__)


def resolve_config_dset_paths(cfg):
    """Make the datasource manifest paths absolute so Dora can load them from a git clone.

    Every string value of `cfg.datasource` (the manifest files for the different splits)
    is passed through `git_save.to_absolute_path` and written back under the same key.
    Values that are not strings are left as they are. The config is modified in place.
    """
    datasource = cfg.datasource
    for entry_name, entry in datasource.items():
        if not isinstance(entry, str):
            continue
        datasource[entry_name] = git_save.to_absolute_path(entry)


def get_solver(cfg):
    """Instantiate the solver for `cfg`, after adapting the config in place.

    The batch sizes found in `cfg.dataset` (the top-level one and the optional
    per-split ones) are divided by `flashy.distrib.world_size()` so that they
    become the batch size for each GPU, and the datasource paths are resolved
    with `resolve_config_dset_paths`.

    Raises:
        AssertionError: if one of the batch sizes is not a multiple of the world size.
    """
    from . import solvers

    n_workers = flashy.distrib.world_size()
    dataset_cfg = cfg.dataset

    # Global batch size first, then every split that overrides it.
    assert dataset_cfg.batch_size % n_workers == 0
    dataset_cfg.batch_size //= n_workers
    for split in ('train', 'valid', 'evaluate', 'generate'):
        if not hasattr(dataset_cfg, split):
            continue
        split_cfg = dataset_cfg[split]
        if not hasattr(split_cfg, 'batch_size'):
            continue
        assert split_cfg.batch_size % n_workers == 0
        split_cfg.batch_size //= n_workers

    resolve_config_dset_paths(cfg)
    return solvers.get_solver(cfg)


def get_solver_from_xp(xp: XP, override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None,
                       restore: bool = True, load_best: bool = True,
                       ignore_state_keys: tp.Optional[tp.List[str]] = None, disable_fsdp: bool = True):
    """Given a XP, return the Solver object.

    Args:
        xp (XP): Dora experiment for which to retrieve the solver.
        override_cfg (dict or None): If not None, should be a dict used to
            override some values in the config of `xp`. This will not impact
            the XP signature or folder. The format is different
            than the one used in Dora grids, nested keys should actually be nested dicts,
            not flattened, e.g. `{'optim': {'batch_size': 32}}`.
        restore (bool): If `True` (the default), restore state from the last checkpoint.
        load_best (bool): If `True` (the default), load the best state from the checkpoint.
        ignore_state_keys (list[str] or None): List of sources to ignore when loading the state,
            e.g. `optimizer`. `None` (the default) means no source is ignored. The provided list
            is copied and never modified.
        disable_fsdp (bool): if True, disables FSDP entirely. This will
            also automatically skip loading the EMA. For solver specific
            state sources, like the optimizer, you might want to
            use it along with `ignore_state_keys=['optimizer']`. Must be used with `load_best=True`.

    Returns:
        The solver built from the (possibly overridden) config of `xp`, restored
        from its checkpoint when `restore` is True.

    Raises:
        AssertionError: if FSDP gets disabled while `load_best` is not True.
    """
    logger.info("Loading solver from XP %s. Overrides used: %s", xp.sig, xp.argv)
    # Build a fresh list on every call: a `[]` default would be one single object shared
    # by all calls, and it is handed over to `solver.restore` which we do not control.
    state_keys_to_ignore: tp.List[str] = [] if ignore_state_keys is None else list(ignore_state_keys)

    cfg = xp.cfg
    if override_cfg is not None:
        cfg = omegaconf.OmegaConf.merge(cfg, omegaconf.DictConfig(override_cfg))
    if disable_fsdp and cfg.fsdp.use:
        cfg.fsdp.use = False
        assert load_best is True
        # ignoring some keys that were FSDP sharded like model, ema, and best_state.
        # fsdp_best_state will be used in that case. When using a specific solver,
        # one is responsible for adding the relevant keys, e.g. 'optimizer'.
        # We could make something to automatically register those inside the solver, but that
        # seems overkill at this point.
        state_keys_to_ignore = state_keys_to_ignore + ['model', 'ema', 'best_state']

    try:
        with xp.enter():
            solver = get_solver(cfg)
            if restore:
                solver.restore(load_best=load_best, ignore_state_keys=state_keys_to_ignore)
        return solver
    finally:
        # Always reset the global Hydra state, whether or not the solver could be built.
        hydra.core.global_hydra.GlobalHydra.instance().clear()


def get_solver_from_sig(sig: str, *args, **kwargs):
    """Return the Solver object matching a Dora signature, e.g. to play with it from a notebook.

    The XP is looked up with `main.get_xp_from_sig(sig)`; all the other arguments
    are forwarded unchanged to `get_solver_from_xp`, see its documentation.
    """
    return get_solver_from_xp(main.get_xp_from_sig(sig), *args, **kwargs)


def init_seed_and_system(cfg):
    """Apply the process-wide settings found in `cfg`.

    Reads `cfg.mp_start_method`, `cfg.seed`, `cfg.num_threads` and
    `cfg.efficient_attention_backend` to configure, in that order: the
    multiprocessing start method, the `random` / NumPy / torch seeds, the number
    of threads (torch and the MKL / OMP environment variables), and the
    efficient attention backend.
    """
    import numpy as np
    import torch
    import random
    from audiocraft.modules.transformer import set_efficient_attention_backend

    start_method = cfg.mp_start_method
    multiprocessing.set_start_method(start_method)
    logger.debug('Setting mp start method to %s', start_method)

    seed = cfg.seed
    random.seed(seed)
    np.random.seed(seed)
    # torch also initialize cuda seed if available
    torch.manual_seed(seed)

    n_threads = cfg.num_threads
    torch.set_num_threads(n_threads)
    for env_var in ('MKL_NUM_THREADS', 'OMP_NUM_THREADS'):
        os.environ[env_var] = str(n_threads)
    logger.debug('Setting num threads to %d', n_threads)

    attention_backend = cfg.efficient_attention_backend
    set_efficient_attention_backend(attention_backend)
    logger.debug('Setting efficient attention backend to %s', attention_backend)


@hydra_main(config_path='../config', config_name='config', version_base='1.1')
def main(cfg):
    """Dora / Hydra entry point: build the solver for `cfg` and run it.

    Depending on the config, the solver is either only shown (`cfg.show`),
    restored and run for the single stage `cfg.execute_only`, or run entirely,
    in which case the result of `solver.run()` is returned.
    """
    init_seed_and_system(cfg)

    # Setup logging both to XP specific folder, and to stderr.
    if cfg.execute_only:
        log_name = '%s.log.{rank}' % cfg.execute_only
    else:
        log_name = 'solver.log.{rank}'
    log_level = str(cfg.logging.level).upper()
    flashy.setup_logging(level=log_level, log_name=log_name)

    # Initialize distributed training, no need to specify anything when using Dora.
    flashy.distrib.init()
    solver = get_solver(cfg)

    if cfg.show:
        solver.show()
        return None

    if not cfg.execute_only:
        return solver.run()

    # Single stage execution: the checkpoint to start from must be explicit.
    assert cfg.execute_inplace or cfg.continue_from is not None, (
        "Please explicitly specify the checkpoint to continue from with continue_from=<sig_or_path> "
        "when running with execute_only or set execute_inplace to True."
    )
    solver.restore(replay_metrics=False)  # load checkpoint
    solver.run_one_stage(cfg.execute_only)
    return None


# Module-level Dora configuration, applied at import time.
main.dora.dir = AudioCraftEnvironment.get_dora_dir()
main._base_cfg.slurm = get_slurm_parameters(main._base_cfg.slurm)

# Drop the shared folder when it is configured but cannot be read.
if not (main.dora.shared is None or os.access(main.dora.shared, os.R_OK)):
    print("No read permission on dora.shared folder, ignoring it.", file=sys.stderr)
    main.dora.shared = None

if __name__ == '__main__':
    main()