ndhieunguyen's picture
feat: remove mpi4py
77180e4
raw
history blame
2.19 kB
"""
Helpers for distributed training.
"""
import io
import os
import socket
import blobfile as bf
import torch as th
import torch.distributed as dist
# Cluster layout knobs -- edit these to match your hardware.
# Number of GPUs per node (previously 8); a rank's local device index is
# intended to be rank % GPUS_PER_NODE.
GPUS_PER_NODE: int = 1
# Presumably the number of attempts allowed when setting up the process
# group; note that setup_dist in this file does not read it.
SETUP_RETRY_COUNT: int = 3
def setup_dist(rank, world_size, port="12145", backend=None):
    """
    Setup a distributed process group.

    Initialises the default torch.distributed process group for a single-node
    run in which every rank rendezvouses on ``localhost`` via the
    ``MASTER_ADDR`` / ``MASTER_PORT`` environment variables. Returns
    immediately if a process group is already initialised in this process.

    :param rank: rank of the calling process, in ``[0, world_size)``.
    :param world_size: total number of participating processes.
    :param port: rendezvous port on ``localhost``; may be a str or an int.
    :param backend: torch.distributed backend name. When None, uses
        ``"nccl"`` if CUDA is available and ``"gloo"`` otherwise.
    """
    if dist.is_initialized():
        return
    if backend is None:
        # NCCL needs CUDA; fall back to gloo so CPU-only runs still work.
        backend = "nccl" if th.cuda.is_available() else "gloo"
    os.environ["MASTER_ADDR"] = "localhost"
    # os.environ only accepts str values, so coerce int ports.
    os.environ["MASTER_PORT"] = str(port)
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
# def dev():
# """
# Get the device to use for torch.distributed.
# """
# if th.cuda.is_available():
# return th.device(f"cuda:{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}")
# return th.device("cpu")
# def load_state_dict(path, **kwargs):
# """
# Load a PyTorch file without redundant fetches across MPI ranks.
# """
# if MPI.COMM_WORLD.Get_rank() == 0:
# with bf.BlobFile(path, "rb") as f:
# data = f.read()
# else:
# data = None
# data = MPI.COMM_WORLD.bcast(data)
# return th.load(io.BytesIO(data), **kwargs)
def sync_params(params):
    """
    Broadcast each tensor in ``params`` from rank 0 to every other rank.

    ``dist.broadcast`` sends the tensor's data from the source rank and
    overwrites it in place on all other ranks. Each broadcast is wrapped in
    ``torch.no_grad()``, presumably so autograd does not track the in-place
    write into parameters that require grad.
    """
    src_rank = 0
    for tensor in params:
        with th.no_grad():
            dist.broadcast(tensor, src=src_rank)
def _find_free_port():
try:
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(("", 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return s.getsockname()[1]
finally:
s.close()