import random
import sys
from multiprocessing import Value

import braceexpand
import webdataset as wds
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, IterableDataset, get_worker_info


def pytorch_worker_seed():
    """Get dataloader worker seed from pytorch."""
    worker_info = get_worker_info()
    if worker_info is not None:
        # favour the seed already created for pytorch dataloader workers if it exists
        return worker_info.seed
    # fallback to wds rank based seed
    return wds.utils.pytorch_worker_seed()


class SharedEpoch:
    def __init__(self, epoch: int = 0):
        self.shared_epoch = Value('i', epoch)

    def set_value(self, epoch):
        self.shared_epoch.value = epoch

    def get_value(self):
        return self.shared_epoch.value


class ResampledShards2(IterableDataset):
    """An iterable dataset yielding a list of urls."""

    def __init__(
        self,
        urls,
        nshards=sys.maxsize,
        worker_seed=None,
        deterministic=False,
        epoch=-1,
    ):
        """Sample shards from the shard list with replacement.

        :param urls: a list of URLs as a Python list or brace notation string
        """
        super().__init__()
        # urls = wds.shardlists.expand_urls(urls)
        if not isinstance(urls, list):
            urls = list(braceexpand.braceexpand(urls))
        self.urls = urls
        assert isinstance(self.urls[0], str)
        self.nshards = nshards
        self.rng = random.Random()
        self.worker_seed = pytorch_worker_seed if worker_seed is None else worker_seed
        self.deterministic = deterministic
        self.epoch = epoch

    def __iter__(self):
        """Return an iterator over the shards."""
        if isinstance(self.epoch, SharedEpoch):
            epoch = self.epoch.get_value()
        else:
            # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
            # situation as different workers may wrap at different times (or not at all).
            self.epoch += 1
            epoch = self.epoch
        if self.deterministic:
            # reset seed w/ epoch if deterministic, worker seed should be deterministic due to arg.seed
            self.rng.seed(self.worker_seed() + epoch)
        for _ in range(self.nshards):
            yield dict(url=self.rng.choice(self.urls))
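

# --- Usage sketch (not part of the original module) ---
# A minimal illustration of how ResampledShards2 and SharedEpoch might be wired
# together. The shard pattern "shards/train-{000000..000009}.tar", the nshards
# value, and the fixed worker_seed lambda are hypothetical placeholders; real
# pipelines would typically feed the yielded shard dicts into a webdataset
# pipeline (e.g. tarfile-to-samples stages) rather than printing them.
if __name__ == "__main__":
    shared_epoch = SharedEpoch(epoch=0)
    shard_sampler = ResampledShards2(
        "shards/train-{000000..000009}.tar",  # hypothetical brace-notation shard list
        nshards=4,                 # draw 4 shards (with replacement) per pass
        worker_seed=lambda: 42,    # hypothetical fixed seed so the sketch runs outside a DataLoader worker
        deterministic=True,        # reseed rng from worker seed + epoch each epoch
        epoch=shared_epoch,        # share the epoch counter across dataloader workers
    )

    for epoch in range(2):
        shared_epoch.set_value(epoch)  # advance the epoch seen by all workers
        for shard in shard_sampler:
            print(epoch, shard["url"])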