import os
import torch

from datetime import timedelta

# Distributed rank / world size, as set by the launcher (e.g. torchrun);
# the defaults cover single-process runs.
RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))

# Fraction of each GPU's memory this process is allowed to allocate.
MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0"))


class FakeBarrier:
    def wait(self):
        pass
|
|
class FakeGroup:
    """Minimal stand-in for a torch.distributed process group when running on a single process."""

    def __init__(self, rank, size):
        self._rank = rank
        self._size = size

    def allreduce(self, *args, **kwargs):
        return FakeBarrier()

    def allgather(self, inputs, local_tensor, **kwargs):
        assert (
            len(inputs[0]) == len(local_tensor) == 1
        ), f"{len(inputs[0])} != {len(local_tensor)} != 1, and the FakeGroup is supposed to join on simple tensors"
        # With a single rank, "gathering" just copies the local tensor into every output slot.
        for input_ in inputs:
            input_[0].data = local_tensor[0].data
        return FakeBarrier()

    def barrier(self, *args, **kwargs):
        return FakeBarrier()

    def size(self):
        return self._size

    def rank(self):
        return self._rank
|
|
def initialize_torch_distributed():
    if torch.cuda.is_available():
        from torch.distributed import ProcessGroupNCCL

        # Set the device id: one process per GPU.
        assert WORLD_SIZE <= torch.cuda.device_count(), "Each process is one gpu"
        device = RANK % torch.cuda.device_count()
        torch.cuda.set_device(device)
        torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device)
        backend = "nccl"
        options = ProcessGroupNCCL.Options()
        options.is_high_priority_stream = True
        options._timeout = timedelta(seconds=60)
    else:
        backend = "gloo"
        options = None
|
    if WORLD_SIZE == 1:
        return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE
    else:
        # In debug mode, use the fake group even when several processes are configured.
        if os.getenv("DEBUG", None) == "1":
            return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE

        if not torch.distributed.is_initialized():
            # Call the init process.
            torch.distributed.init_process_group(
                backend=backend,
                world_size=WORLD_SIZE,
                rank=RANK,
                timeout=timedelta(seconds=60),
                pg_options=options,
            )
        else:
            print("torch.distributed is already initialized.")

        return torch.distributed.group.WORLD, RANK, WORLD_SIZE
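

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: running this file directly
    # with the default environment (RANK=0, WORLD_SIZE=1) exercises the FakeGroup path,
    # so no NCCL/gloo initialization is needed.
    group, rank, world_size = initialize_torch_distributed()
    print(f"rank={rank} world_size={world_size} group={type(group).__name__}")

    # The fake group exposes the same collective-call surface: allgather copies the local
    # single-element tensor into every output slot and returns a barrier that can be waited on.
    local = [torch.zeros(1) + rank]
    outputs = [[torch.empty(1)] for _ in range(world_size)]
    group.allgather(outputs, local)
    group.barrier().wait()
    print(outputs)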