import random
import sys
from multiprocessing import Value

import braceexpand
import webdataset as wds
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, IterableDataset, get_worker_info


def pytorch_worker_seed():
    """get dataloader worker seed from pytorch"""
    worker_info = get_worker_info()
    if worker_info is not None:
        # favour the seed already created for pytorch dataloader workers if it exists
        return worker_info.seed
    # fallback to wds rank based seed
    return wds.utils.pytorch_worker_seed()


class SharedEpoch:
    def __init__(self, epoch: int = 0):
        self.shared_epoch = Value('i', epoch)

    def set_value(self, epoch):
        self.shared_epoch.value = epoch

    def get_value(self):
        return self.shared_epoch.value
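
# Usage sketch (illustrative): SharedEpoch is built once in the parent process,
# and the training loop bumps it at the start of each epoch so that dataloader
# workers holding the same object see the new value through the shared
# multiprocessing.Value. `train_one_epoch` below is a hypothetical stand-in.
#
#     shared_epoch = SharedEpoch(0)
#     for epoch in range(num_epochs):
#         shared_epoch.set_value(epoch)
#         train_one_epoch(model, loader, epoch)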


class ResampledShards2(IterableDataset):
    """An iterable dataset yielding a list of urls."""

    def __init__(
        self,
        urls,
        nshards=sys.maxsize,
        worker_seed=None,
        deterministic=False,
        epoch=-1,
    ):
        """Sample shards from the shard list with replacement.

        :param urls: a list of URLs as a Python list or brace notation string
        """
        super().__init__()
        # urls = wds.shardlists.expand_urls(urls)
        if not isinstance(urls, list):
            urls = list(braceexpand.braceexpand(urls))
        self.urls = urls
        assert isinstance(self.urls[0], str)
        self.nshards = nshards
        self.rng = random.Random()
        self.worker_seed = pytorch_worker_seed if worker_seed is None else worker_seed
        self.deterministic = deterministic
        self.epoch = epoch

    def __iter__(self):
        """Return an iterator over the shards."""
        if isinstance(self.epoch, SharedEpoch):
            epoch = self.epoch.get_value()
        else:
            # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
            # situation as different workers may wrap at different times (or not at all).
            self.epoch += 1
            epoch = self.epoch
        if self.deterministic:
            # reset seed w/ epoch if deterministic; worker seed should be deterministic due to args.seed
            self.rng.seed(self.worker_seed() + epoch)
        for _ in range(self.nshards):
            yield dict(url=self.rng.choice(self.urls))
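

# Minimal usage sketch (illustrative, with hypothetical shard paths): wiring
# ResampledShards2 into a pipeline. Each yielded item is a dict with a single
# "url" key, chosen with replacement from the expanded shard list; a webdataset
# pipeline would normally follow this with tar-reading and decoding stages.
if __name__ == "__main__":
    shared_epoch = SharedEpoch(epoch=0)
    shards = ResampledShards2(
        "/data/shards-{000000..000009}.tar",  # hypothetical brace-notation shard list
        nshards=4,
        deterministic=True,
        epoch=shared_epoch,
    )
    for sample in shards:
        print(sample)  # e.g. {'url': '/data/shards-000003.tar'}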