# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import os

import torch
import torch.distributed as dist
from torch import nn


def setup_for_distributed(is_master):
    """Override print so that only the master process prints (pass force=True to always print)."""
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def init_distributed_mode(args):
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        # Launched with torchrun / torch.distributed.launch.
        args["rank"] = int(os.environ["RANK"])
        args["world_size"] = int(os.environ["WORLD_SIZE"])
        args["gpu"] = int(os.environ["LOCAL_RANK"])
    elif "SLURM_PROCID" in os.environ:
        # Launched under SLURM: derive rank, world size, and local GPU from SLURM variables.
        args["rank"] = int(os.environ["SLURM_PROCID"])
        args["world_size"] = int(os.environ["SLURM_NTASKS"])
        args["gpu"] = args["rank"] % torch.cuda.device_count()
    else:
        print("Not using distributed mode")
        args["distributed"] = False
        return

    args["distributed"] = True

    torch.cuda.set_device(args["gpu"])
    args["dist_backend"] = "nccl"
    print(
        "| distributed init (rank {}): {}".format(args["rank"], args["dist_url"]),
        flush=True,
    )
    dist.init_process_group(
        backend=args["dist_backend"],
        init_method=args["dist_url"],
        world_size=args["world_size"],
        rank=args["rank"],
    )
    dist.barrier()
    setup_for_distributed(args["rank"] == 0)


def save_result(result, directory, file_name):
    """Write per-rank results to JSON, then merge them into a single file on the main process."""
    rank_path = os.path.join(directory, "{}_rank_{}.json".format(file_name, get_rank()))
    main_path = os.path.join(directory, "{}.json".format(file_name))

    with open(rank_path, "w") as f:
        json.dump(result, f)

    if is_dist_avail_and_initialized():
        dist.barrier()

    if is_main_process():
        result = []
        for rank in range(get_world_size()):
            rank_path = os.path.join(
                directory, "{}_rank_{}.json".format(file_name, rank)
            )
            with open(rank_path, "r") as f:
                result += json.load(f)
        with open(main_path, "w") as f:
            json.dump(result, f)

    if is_dist_avail_and_initialized():
        dist.barrier()


def add_weight_decay(model: nn.Module, weight_decay: float) -> list:
    """Split trainable parameters into optimizer param groups with and without weight decay."""
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            # Skip frozen parameters (e.g., momentum/EMA models).
            continue
        if len(param.shape) == 1 or name.endswith(".bias"):
            # Biases and 1-D parameters (e.g., norm layers) get no weight decay.
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": no_decay, "weight_decay": 0.0},
        {"params": decay, "weight_decay": weight_decay},
    ]
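

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: the config keys
    # ("dist_url", "distributed") and the hyperparameter values below are
    # illustrative assumptions about how a caller might use these helpers.
    args = {"dist_url": "env://", "distributed": False}
    init_distributed_mode(args)  # fills in "rank", "world_size", "gpu" under torchrun or SLURM

    # Build optimizer param groups so biases/norm weights skip weight decay.
    model = nn.Linear(16, 4)
    param_groups = add_weight_decay(model, weight_decay=0.05)
    optimizer = torch.optim.AdamW(param_groups, lr=1e-4)
    print("params per group:", [len(g["params"]) for g in param_groups])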