# Copyright (c) OpenMMLab. All rights reserved.
import torch.distributed as dist

_SEQUENCE_PARALLEL_GROUP = None
_SEQUENCE_PARALLEL_WORLD_SIZE = None
_SEQUENCE_PARALLEL_RANK = None

_INNER_SEQUENCE_PARALLEL_GROUP = None
_INNER_SEQUENCE_PARALLEL_WORLD_SIZE = None
_INNER_SEQUENCE_PARALLEL_RANK = None

_DATA_PARALLEL_GROUP = None
_DATA_PARALLEL_WORLD_SIZE = None
_DATA_PARALLEL_RANK = None
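
# The globals above act as a process-local cache: the init functions store
# the created process groups here, and the getters below memoize world sizes
# and ranks so that repeated lookups do not re-query torch.distributed.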


def init_sequence_parallel(sequence_parallel_size: int = 1):
    assert dist.is_initialized()
    world_size: int = dist.get_world_size()

    # enable_ds_sequence_parallel = sequence_parallel_size > 1
    # if enable_ds_sequence_parallel:
    if world_size % sequence_parallel_size != 0:
        raise RuntimeError(f'world_size ({world_size}) is not divisible by '
                           f'sequence_parallel_size {sequence_parallel_size}')

    num_sequence_parallel_groups: int = world_size // sequence_parallel_size
    rank = dist.get_rank()

    # Build the sequence parallel groups.
    global _SEQUENCE_PARALLEL_GROUP
    assert _SEQUENCE_PARALLEL_GROUP is None, \
        'sequence parallel group is already initialized'
    for i in range(num_sequence_parallel_groups):
        ranks = range(i * sequence_parallel_size,
                      (i + 1) * sequence_parallel_size)
        group = dist.new_group(ranks)
        if rank in ranks:
            _SEQUENCE_PARALLEL_GROUP = group

    global _DATA_PARALLEL_GROUP
    assert _DATA_PARALLEL_GROUP is None, \
        'data parallel group is already initialized'
    all_data_parallel_group_ranks = []
    start_rank = 0
    end_rank = world_size
    for j in range(sequence_parallel_size):
        ranks = range(start_rank + j, end_rank, sequence_parallel_size)
        all_data_parallel_group_ranks.append(list(ranks))
        group = dist.new_group(ranks)
        if rank in ranks:
            _DATA_PARALLEL_GROUP = group
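
# Worked example, derived from the two loops above: with world_size = 8 and
# sequence_parallel_size = 4, the sequence parallel groups are [0, 1, 2, 3]
# and [4, 5, 6, 7], while the data parallel groups are [0, 4], [1, 5],
# [2, 6] and [3, 7], i.e. the ranks holding the same sequence-shard position
# across sequence parallel groups.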


def init_inner_sequence_parallel(inner_sequence_parallel_size: int = 1):
    """Build the inner sequence parallel groups.

    They are helpful when the number of attention heads is not evenly
    divisible by the sequence parallel size.
    """
    assert _SEQUENCE_PARALLEL_GROUP is not None, \
        ('Please call `init_inner_sequence_parallel` after calling '
         '`init_sequence_parallel`.')

    rank = dist.get_rank()
    world_size: int = dist.get_world_size()

    # Guard against silently dropping ranks: every rank must land in exactly
    # one inner group.
    if world_size % inner_sequence_parallel_size != 0:
        raise RuntimeError(
            f'world_size ({world_size}) is not divisible by '
            f'inner_sequence_parallel_size {inner_sequence_parallel_size}')

    n_inner_group = world_size // inner_sequence_parallel_size

    global _INNER_SEQUENCE_PARALLEL_GROUP
    assert _INNER_SEQUENCE_PARALLEL_GROUP is None, \
        'inner sequence parallel group is already initialized'
    for i in range(n_inner_group):
        ranks = range(i * inner_sequence_parallel_size,
                      (i + 1) * inner_sequence_parallel_size)
        group = dist.new_group(ranks)
        if rank in ranks:
            _INNER_SEQUENCE_PARALLEL_GROUP = group
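
# Worked example, derived from the loop above: with world_size = 8 and
# inner_sequence_parallel_size = 2, the inner groups are [0, 1], [2, 3],
# [4, 5] and [6, 7]. One hypothetical use (an assumption, not stated in this
# file): with sequence_parallel_size = 4 but only 2 attention heads, inner
# groups of size 2 give each inner group a whole head to work on.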


def is_inner_sequence_parallel_initialized():
    """Check whether the inner sequence parallel group has been built."""
    return _INNER_SEQUENCE_PARALLEL_GROUP is not None


def get_inner_sequence_parallel_group():
    """Get the inner sequence parallel group the caller rank belongs to."""
    return _INNER_SEQUENCE_PARALLEL_GROUP


def get_inner_sequence_parallel_world_size():
    """Return world size for the inner sequence parallel group."""
    global _INNER_SEQUENCE_PARALLEL_WORLD_SIZE
    if _INNER_SEQUENCE_PARALLEL_WORLD_SIZE is not None:
        return _INNER_SEQUENCE_PARALLEL_WORLD_SIZE
    if not dist.is_initialized() or (_INNER_SEQUENCE_PARALLEL_GROUP is None):
        _INNER_SEQUENCE_PARALLEL_WORLD_SIZE = 1
    else:
        _INNER_SEQUENCE_PARALLEL_WORLD_SIZE = dist.get_world_size(
            group=get_inner_sequence_parallel_group())
    return _INNER_SEQUENCE_PARALLEL_WORLD_SIZE


def get_inner_sequence_parallel_rank():
    """Return my rank for the inner sequence parallel group."""
    global _INNER_SEQUENCE_PARALLEL_RANK
    if _INNER_SEQUENCE_PARALLEL_RANK is not None:
        return _INNER_SEQUENCE_PARALLEL_RANK
    if not dist.is_initialized() or (_INNER_SEQUENCE_PARALLEL_GROUP is None):
        _INNER_SEQUENCE_PARALLEL_RANK = 0
    else:
        _INNER_SEQUENCE_PARALLEL_RANK = dist.get_rank(
            group=get_inner_sequence_parallel_group())
    return _INNER_SEQUENCE_PARALLEL_RANK
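
# Note on fallbacks: the inner getters above and the sequence parallel
# getters below return a world size of 1 and a rank of 0 whenever
# torch.distributed or the corresponding group is uninitialized, so
# single-process runs work without any setup. The data parallel getters, by
# contrast, assert that their group exists once torch.distributed is up.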


def get_sequence_parallel_group():
    """Get the sequence parallel group the caller rank belongs to."""
    return _SEQUENCE_PARALLEL_GROUP


def get_sequence_parallel_world_size():
    """Return world size for the sequence parallel group."""
    global _SEQUENCE_PARALLEL_WORLD_SIZE
    if _SEQUENCE_PARALLEL_WORLD_SIZE is not None:
        return _SEQUENCE_PARALLEL_WORLD_SIZE
    if not dist.is_initialized() or (_SEQUENCE_PARALLEL_GROUP is None):
        _SEQUENCE_PARALLEL_WORLD_SIZE = 1
    else:
        _SEQUENCE_PARALLEL_WORLD_SIZE = dist.get_world_size(
            group=get_sequence_parallel_group())
    return _SEQUENCE_PARALLEL_WORLD_SIZE


def get_sequence_parallel_rank():
    """Return my rank for the sequence parallel group."""
    global _SEQUENCE_PARALLEL_RANK
    if _SEQUENCE_PARALLEL_RANK is not None:
        return _SEQUENCE_PARALLEL_RANK
    if not dist.is_initialized() or (_SEQUENCE_PARALLEL_GROUP is None):
        _SEQUENCE_PARALLEL_RANK = 0
    else:
        _SEQUENCE_PARALLEL_RANK = dist.get_rank(
            group=get_sequence_parallel_group())
    return _SEQUENCE_PARALLEL_RANK


def get_data_parallel_group():
    """Get the data parallel group the caller rank belongs to."""
    assert _DATA_PARALLEL_GROUP is not None, \
        'data parallel group is not initialized'
    return _DATA_PARALLEL_GROUP


def get_data_parallel_world_size():
    """Return world size for the data parallel group."""
    global _DATA_PARALLEL_WORLD_SIZE
    if _DATA_PARALLEL_WORLD_SIZE is not None:
        return _DATA_PARALLEL_WORLD_SIZE
    if not dist.is_initialized():
        _DATA_PARALLEL_WORLD_SIZE = 1
    else:
        _DATA_PARALLEL_WORLD_SIZE = dist.get_world_size(
            group=get_data_parallel_group())
    return _DATA_PARALLEL_WORLD_SIZE


def get_data_parallel_rank():
    """Return my rank for the data parallel group."""
    global _DATA_PARALLEL_RANK
    if _DATA_PARALLEL_RANK is not None:
        return _DATA_PARALLEL_RANK
    if not dist.is_initialized():
        _DATA_PARALLEL_RANK = 0
    else:
        _DATA_PARALLEL_RANK = dist.get_rank(group=get_data_parallel_group())
    return _DATA_PARALLEL_RANK
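

# Usage sketch (not part of the original module): a minimal example of how
# these helpers might be wired together, assuming the script is launched with
# `torchrun --nproc-per-node=8 this_file.py`. The 'gloo' backend and
# sequence_parallel_size=4 are illustrative assumptions, not requirements.
if __name__ == '__main__':
    dist.init_process_group(backend='gloo')
    init_sequence_parallel(sequence_parallel_size=4)
    print(f'global rank {dist.get_rank()}: '
          f'sp rank {get_sequence_parallel_rank()}/'
          f'{get_sequence_parallel_world_size()}, '
          f'dp rank {get_data_parallel_rank()}/'
          f'{get_data_parallel_world_size()}')
    dist.destroy_process_group()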