Spaces:
Runtime error
Runtime error
File size: 1,228 Bytes
476ac07 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Optional, Sized
from mmengine.dataset import DefaultSampler
from mmengine.dist import sync_random_seed
from .setup_distributed import (get_data_parallel_rank,
get_data_parallel_world_size)
class SequenceParallelSampler(DefaultSampler):
    """Sampler whose sharding is keyed on the data-parallel rank.

    The rank and world size are taken from ``get_data_parallel_rank()`` and
    ``get_data_parallel_world_size()``. The parent ``__init__`` is not
    invoked; every attribute is assigned directly in this constructor.
    NOTE(review): presumably this is so that the parent does not derive the
    rank/world size from the global process group -- confirm against
    ``DefaultSampler``.
    """

    def __init__(
        self,
        dataset: Sized,
        shuffle: bool = True,
        seed: Optional[int] = None,
        round_up: bool = True,
    ) -> None:
        """Record the sampling configuration and per-rank sample counts.

        Args:
            dataset: Sized collection to sample from; only ``len()`` is
                used here.
            shuffle: Stored as ``self.shuffle`` (consumed by iteration
                logic that is not part of this constructor).
            seed: Stored as ``self.seed``. When ``None``, the value
                returned by ``sync_random_seed()`` is used instead.
            round_up: When true, each rank gets
                ``ceil(len(dataset) / world_size)`` samples and
                ``total_size`` is that count times the world size. When
                false, ``total_size`` is ``len(dataset)`` and each rank
                gets ``ceil((len(dataset) - rank) / world_size)`` samples.
        """
        dp_rank = get_data_parallel_rank()
        dp_world_size = get_data_parallel_world_size()

        self.rank = dp_rank
        self.world_size = dp_world_size
        self.dataset = dataset
        self.shuffle = shuffle
        # Only ask for a synchronized seed when the caller gave none.
        self.seed = sync_random_seed() if seed is None else seed
        self.epoch = 0
        self.round_up = round_up

        dataset_len = len(dataset)
        if not round_up:
            # No padding: total size is the dataset length, and the
            # per-rank count depends on this rank's offset.
            self.num_samples = math.ceil(
                (dataset_len - dp_rank) / dp_world_size)
            self.total_size = dataset_len
            return

        # Padded case: every rank receives the same number of samples.
        self.num_samples = math.ceil(dataset_len / dp_world_size)
        self.total_size = self.num_samples * dp_world_size
|