Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging | |
import os | |
import sys | |
from fairseq.dataclass.initialize import hydra_init | |
from fairseq_cli.train import main as pre_main | |
from fairseq import distributed_utils, metrics | |
from fairseq.dataclass.configs import FairseqConfig | |
import hydra | |
import torch | |
from omegaconf import OmegaConf | |
# Module-level logger; the name is fixed (not __name__) so records are
# labelled "fairseq_cli.hydra_train" however this file is executed.
logger: logging.Logger = logging.getLogger("fairseq_cli.hydra_train")
def hydra_main(cfg: FairseqConfig) -> float:
    """Launch training for ``cfg`` and return the best validation metric.

    The config is first converted to a plain container (resolving
    interpolations and turning enums into strings), rebuilt as an OmegaConf
    object, and put into struct mode. Training is then started through
    ``distributed_utils.call_main(cfg, pre_main)``, optionally wrapped in the
    CUDA / NVTX profilers when ``cfg.common.profile`` is set.

    Args:
        cfg: The composed fairseq configuration.

    Returns:
        The smoothed "valid" value of ``cfg.checkpoint.best_checkpoint_metric``,
        or ``float("inf")`` when that value is unavailable.

    Raises:
        BaseException: Whatever training raised, re-raised unchanged unless
            ``cfg.common.suppress_crashes`` is true.
    """
    cfg = OmegaConf.create(
        OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
    )
    OmegaConf.set_struct(cfg, True)

    if cfg.common.reset_logging:
        reset_logging()  # Hydra hijacks logging, fix that

    try:
        if cfg.common.profile:
            with torch.cuda.profiler.profile():
                with torch.autograd.profiler.emit_nvtx():
                    distributed_utils.call_main(cfg, pre_main)
        else:
            distributed_utils.call_main(cfg, pre_main)
    except BaseException as e:
        # NOTE(review): BaseException also covers KeyboardInterrupt/SystemExit,
        # so with suppress_crashes enabled those are logged and swallowed too.
        # Kept as-is to preserve existing behavior; confirm this is intended.
        if not cfg.common.suppress_crashes:
            raise
        logger.error("Crashed! %s", e)

    # get best val and return - useful for sweepers
    try:
        best_val = metrics.get_smoothed_value(
            "valid", cfg.checkpoint.best_checkpoint_metric
        )
    except Exception:
        # Best-effort lookup: any ordinary failure means "no value available".
        # Unlike a bare ``except:``, this lets Ctrl-C / SystemExit propagate.
        best_val = None

    if best_val is None:
        best_val = float("inf")

    return best_val
def reset_logging():
    """Replace every root-logger handler with a single stdout handler.

    All handlers currently attached to the root logger are removed, the root
    level is set from the ``LOGLEVEL`` environment variable (default
    ``"INFO"``, upper-cased), and one ``StreamHandler`` writing to
    ``sys.stdout`` with a ``time | level | name | message`` format is attached.
    """
    root = logging.getLogger()

    # Iterate over a snapshot: removeHandler() mutates root.handlers, and
    # removing from the list being iterated would skip every other handler.
    for handler in list(root.handlers):
        root.removeHandler(handler)

    root.setLevel(os.environ.get("LOGLEVEL", "INFO").upper())

    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(
        logging.Formatter(
            fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    )
    root.addHandler(handler)
def cli_main():
    """Command-line entry point: pick the config name, init Hydra, run training.

    The config name is read from Hydra's parsed command-line arguments; if
    that fails for any reason, or no name was given, ``"config"`` is used.
    ``hydra_init`` is then called with that name, followed by ``hydra_main``.
    """
    try:
        # Private Hydra helper; imported lazily so a failure here only
        # triggers the fallback below.
        from hydra._internal.utils import get_args

        cfg_name = get_args().config_name or "config"
    except (Exception, SystemExit):
        # Best-effort lookup. SystemExit is still caught (presumably argument
        # parsing can exit on bad input -- TODO confirm) to keep the original
        # fallback behavior, but KeyboardInterrupt is no longer swallowed.
        logger.warning("Failed to get config name from hydra args")
        cfg_name = "config"

    hydra_init(cfg_name)
    # NOTE(review): hydra_main is declared with a required ``cfg`` parameter
    # but is called here without arguments. Presumably it is meant to be
    # wrapped by a Hydra decorator that supplies ``cfg``; no such decorator is
    # visible in this file -- confirm.
    hydra_main()
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    cli_main()