import sys
import io
import os
import re
import json
import tarfile
from functools import partial

import webdataset as wds
from webdataset import ResampledShards, DataPipeline, tarfile_to_samples
from webdataset.filters import pipelinefilter
from webdataset.tariterators import group_by_keys
from webdataset.handlers import reraise_exception
from webdataset.gopen import gopen_schemes, gopen

def pytorch_worker_info(group=None):  # sourcery skip: use-contextlib-suppress
    """Return node and worker info for PyTorch and some distributed environments."""
    rank = 0
    world_size = 1
    worker = 0
    num_workers = 1
    try:
        import torch.distributed

        if torch.distributed.is_available() and torch.distributed.is_initialized():
            group = group or torch.distributed.group.WORLD
            rank = torch.distributed.get_rank(group=group)
            world_size = torch.distributed.get_world_size(group=group)
    except ModuleNotFoundError:
        pass
    try:
        import torch.utils.data

        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            worker = worker_info.id
            num_workers = worker_info.num_workers
    except ModuleNotFoundError:
        pass
    return rank, world_size, worker, num_workers

def pytorch_worker_seed(group=None):
    """Compute a distinct, deterministic RNG seed for each worker and node."""
    rank, world_size, worker, num_workers = pytorch_worker_info(group=group)
    # assumes fewer than 1000 dataloader workers per rank, so seeds stay distinct
    return rank * 1000 + worker


def worker_seed_sat(group=None, seed=0):
    return pytorch_worker_seed(group=group) + seed * 23
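
# Hedged illustration (not part of the original module): the helpers above give
# every (rank, worker) pair a distinct seed. The numbers assume a plain
# single-process run, where pytorch_worker_info() returns (0, 1, 0, 1);
# `_demo_worker_seeds` is a made-up name.
def _demo_worker_seeds():
    assert pytorch_worker_seed() == 0  # rank * 1000 + worker
    assert worker_seed_sat(seed=5) == 115  # 0 + 5 * 23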

class ConfiguredResampledShards(ResampledShards):
    def __init__(self, urls, seed, nshards=sys.maxsize, deterministic=True):
        from sat.helpers import print_rank0

        try:
            from megatron.core.parallel_state import get_data_parallel_group

            group = get_data_parallel_group()
            print_rank0("Using megatron data parallel group.")
        except Exception:
            from sat.mpu import get_data_parallel_group

            try:
                group = get_data_parallel_group()
                print_rank0("Using sat data parallel group.")
            except AssertionError:
                group = None
                print_rank0("No data parallel group is specified!")
        worker_seed_sat_this = partial(worker_seed_sat, group=group, seed=seed)
        super().__init__(urls, nshards, worker_seed_sat_this, deterministic)

class SimpleDistributedWebDataset(DataPipeline):
    def __init__(self, path, process_fn, seed, *, shuffle_buffer=1000):
        # Set shuffle_buffer = 1 to disable shuffling; under model parallelism
        # shuffling is disabled so that all model-parallel ranks see the same order.
        try:
            from sat.mpu import get_model_parallel_world_size

            if get_model_parallel_world_size() > 1:
                shuffle_buffer = 1
        except Exception:
            pass
        super().__init__(
            ConfiguredResampledShards(path, seed),  # many shards are recommended, otherwise data may not split evenly
            tarfile_to_samples(),
            wds.shuffle(shuffle_buffer),
            process_fn,
        )
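
# Hedged usage sketch (not part of the original module): wiring the pipeline into
# a PyTorch DataLoader. The shard pattern and the pass-through process_fn are
# illustrative assumptions; a real process_fn would decode/tokenize each sample.
def _demo_simple_pipeline():
    from torch.utils.data import DataLoader

    def process_fn(source):
        # a pipeline stage receives the upstream sample iterator and yields samples
        for sample in source:
            yield sample

    ds = SimpleDistributedWebDataset("shards/data-{000000..000099}.tar", process_fn, seed=42)
    return DataLoader(ds, batch_size=None, num_workers=4)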

def tar_file_iterator_with_meta(
    fileobj, meta_names, skip_meta=r"__[^/]*__($|/)", suffix=None, handler=reraise_exception, meta_stream=None
):
    """Iterate over a tar stream, yielding dicts of filename and content.

    :param fileobj: byte stream suitable for tarfile
    :param meta_names: keys of the items to read from the meta file
    :param skip_meta: regexp for keys that are skipped entirely (Default value = r"__[^/]*__($|/)")
    :param suffix: optional string appended to the content of .txt members
    :param handler: exception handler
    :param meta_stream: optional already-open stream of the .meta.jsonl file
    """
    stream = tarfile.open(fileobj=fileobj, mode="r|*")
    data_dir, filename = fileobj.name.rsplit("/", 1)
    meta_data = {}  # {id: {meta_name: meta_value, meta_name2: meta_value2, ...}}
    if meta_stream is None:
        meta_file_name = filename.split(".")[0] + ".meta.jsonl"
        meta_path = os.path.join(data_dir, meta_file_name)
        if os.path.exists(meta_path):
            meta_stream = open(meta_path, "r")
    else:
        meta_file_name = meta_stream.name
    if meta_stream is not None:
        for lineno, line in enumerate(meta_stream):
            meta_list = []
            try:
                meta_list.append(json.loads(line))
            except Exception:
                from sat.helpers import print_rank0

                print_rank0(f"Error in loading jsonl {meta_file_name}, lineno {lineno}: {line}", level="DEBUG")
                continue
            for item in meta_list:
                if item["key"] not in meta_data:
                    meta_data[item["key"]] = {}
                for meta_name in meta_names:
                    if meta_name in item:
                        meta_data[item["key"]][meta_name] = item[meta_name]
        meta_stream.close()
    try:
        for tarinfo in stream:
            fname = tarinfo.name
            try:
                if not tarinfo.isreg():
                    continue
                if fname is None:
                    continue
                if "/" not in fname and fname.startswith("__") and fname.endswith("__"):
                    # skipping metadata for now
                    continue
                if skip_meta is not None and re.match(skip_meta, fname):
                    continue
                if fname.endswith(".txt") and suffix is not None:
                    data = (stream.extractfile(tarinfo).read().decode() + suffix).encode()
                else:
                    data = stream.extractfile(tarinfo).read()
                result = dict(fname=fname, data=data)
                yield result
                if fname.endswith(".id"):
                    # emit one extra pseudo-member per requested meta field,
                    # keyed so that group_by_keys merges it into the same sample
                    fid = fname.split(".")[0]
                    if "-$#%@&" in fid:
                        sfid = fid.split("-$#%@&")[0]
                    else:
                        sfid = fid
                    meta_data_fid = meta_data.get(sfid, {})
                    for meta_name in meta_names:
                        meta_fname = fid + "." + meta_name
                        meta = meta_data_fid.get(meta_name, None)
                        yield dict(fname=meta_fname, data=meta)
                stream.members = []  # free memory: don't accumulate TarInfo objects
            except Exception as exn:
                if hasattr(exn, "args") and len(exn.args) > 0:
                    exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:]
                if handler(exn):
                    continue
                else:
                    break
    except Exception as exn:
        print(exn)
    del stream
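
# Hedged illustration (not part of the original module): a self-contained check
# of the iterator above using an in-memory one-sample shard. The sample id "42"
# and the "caption" meta field are made up; `_demo_iterator_with_meta` is a
# made-up name.
def _demo_iterator_with_meta():
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w") as tf:
        for name, payload in [("42.id", b"42"), ("42.txt", b"hello")]:
            info = tarfile.TarInfo(name)
            info.size = len(payload)
            tf.addfile(info, io.BytesIO(payload))
    buf.seek(0)
    buf.name = "memory/000000.tar"  # the iterator splits this name on "/"
    meta = io.StringIO(json.dumps({"key": "42", "caption": "a red bicycle"}) + "\n")
    meta.name = "000000.meta.jsonl"
    # yields {"fname": "42.id", ...}, then {"fname": "42.caption", "data": "a red bicycle"},
    # then {"fname": "42.txt", ...}; group_by_keys later merges them into one sample
    return list(tar_file_iterator_with_meta(buf, ["caption"], meta_stream=meta))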

def tar_file_expander_with_meta(data, meta_names, handler=reraise_exception):
    """Expand a stream of open tar files into a stream of tar file contents.

    This returns an iterator over dicts with "fname" and "data" fields.
    """
    for source in data:
        url = source["url"]
        try:
            assert isinstance(source, dict)
            assert "stream" in source
            for sample in tar_file_iterator_with_meta(source["stream"], meta_names, meta_stream=source["meta_stream"]):
                assert isinstance(sample, dict) and "data" in sample and "fname" in sample
                sample["__url__"] = url
                yield sample
        except Exception as exn:
            exn.args = exn.args + (source.get("stream"), source.get("url"))
            if handler(exn):
                continue
            else:
                break

def url_opener(
    data,
    handler=reraise_exception,
    **kw,
):
    """Open URLs and yield a stream of url+stream pairs.

    This replaces webdataset's url_opener so that a `meta_stream` attached by the
    gopen handler (see gopen_boto3 below) is carried along with the sample.

    Args:
        data: iterator over dict(url=...)
        handler: exception handler.
        kw: keyword arguments for gopen.gopen.

    Yields:
        a stream of url+stream pairs.
    """
    for sample in data:
        assert isinstance(sample, dict), sample
        assert "url" in sample
        url = sample["url"]
        try:
            stream = gopen(url, **kw)
            if hasattr(stream, "meta_stream"):
                meta_stream = stream.meta_stream
                del stream.meta_stream
            else:
                meta_stream = None
            sample.update(stream=stream, meta_stream=meta_stream)
            yield sample
        except Exception as exn:
            exn.args = exn.args + (url,)
            if handler(exn):
                continue
            else:
                break

def tarfile_samples_with_meta(src, meta_names, handler=reraise_exception):
    streams = url_opener(src, handler=handler)
    files = tar_file_expander_with_meta(streams, meta_names, handler)
    samples = group_by_keys(files, handler=handler)
    return samples

class MetaDistributedWebDataset(DataPipeline):
    """WebDataset with meta-information files.

    Extra format:
        In the webdataset (tar), each sample contains a '.id' member;
        next to each tar file there is a '.meta.jsonl' file with the same stem.
        The '.meta.jsonl' file contains one JSON object per line, each with a
        'key' field that matches a sample's '.id'.
    """

    def __init__(
        self, path, process_fn, seed, *, meta_names=[], nshards=sys.maxsize, shuffle_buffer=1000, include_dirs=None
    ):
        # os.environ['WDS_SHOW_SEED'] = '1'
        import torch

        if torch.distributed.get_rank() == 0:
            if include_dirs is not None:  # e.g. "/webdatasets/A,/webdatasets/C"
                other_paths = []
                include_dirs = include_dirs.split(",")
                for include_dir in include_dirs:
                    if "*" in include_dir:
                        # "dir*3" repeats every shard under dir three times (sampling weight)
                        include_dir, n = include_dir.split("*")
                        n = int(n)
                    else:
                        n = 1
                    for cur_dir, dirs, files in os.walk(include_dir):
                        for f in files:
                            if f.endswith("tar") and os.path.getsize(os.path.join(cur_dir, f)) > 0:
                                # other_paths.append(os.path.join(cur_dir,f))
                                other_paths.extend([os.path.join(cur_dir, f)] * n)
                # print(f'Adding dataset paths {",".join(other_paths)}')
                from braceexpand import braceexpand

                if len(path) > 0:  # not ""
                    path = list(braceexpand(path)) + other_paths
                else:
                    path = other_paths
            path = [path]
        else:
            path = [None]
        # rank 0 expands the shard list; all ranks receive the same copy
        torch.distributed.broadcast_object_list(path, src=0)
        path = path[0]

        tarfile_samples = partial(tarfile_samples_with_meta, meta_names=meta_names)
        tarfile_to_samples = pipelinefilter(tarfile_samples)

        # if model parallel, shuffle_buffer should be 1 to disable shuffling
        try:
            from sat.mpu import get_model_parallel_world_size

            if get_model_parallel_world_size() > 1:
                shuffle_buffer = 1
        except Exception:
            pass
        super().__init__(
            ConfiguredResampledShards(path, seed, nshards=nshards),
            tarfile_to_samples(),
            wds.shuffle(shuffle_buffer),
            process_fn,
        )
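
# Hedged usage sketch (not part of the original module): building the dataset so
# each sample carries a "caption" field pulled from its shard's .meta.jsonl. The
# paths and meta_names are illustrative assumptions; torch.distributed must
# already be initialized, because __init__ broadcasts the shard list from rank 0.
def _demo_meta_pipeline(process_fn):
    return MetaDistributedWebDataset(
        "shards/data-{000000..000099}.tar",
        process_fn,
        seed=42,
        meta_names=["caption"],
        include_dirs="/webdatasets/A,/webdatasets/B*2",  # shards under B get 2x sampling weight
    )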

# rclone support
from webdataset.gopen import Pipe


def gopen_rclone(url, mode="rb", bufsize=1024 * 1024 * 32):
    """Open a URL with `rclone`.

    :param url: rclone url, e.g. rclone://data:bucket1/foo.tar, where 'data' is a
        configured rclone remote.
    :param mode: file mode
    :param bufsize: buffer size
    """
    url = url.replace("rclone://", "")
    if mode[0] == "r":
        cmd = f"rclone cat '{url}'"
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141, 23],
        )  # skipcq: BAN-B604
    elif mode[0] == "w":
        cmd = f"rclone rcat '{url}'"  # rcat copies stdin to the remote path
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141, 26],
        )  # skipcq: BAN-B604
    else:
        raise ValueError(f"{mode}: unknown mode")
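
# Hedged usage sketch (not part of the original module): once "rclone" is
# registered in gopen_schemes below, rclone:// URLs work anywhere webdataset
# opens a URL. "data" is a hypothetical configured rclone remote.
def _demo_gopen_rclone():
    stream = gopen("rclone://data:bucket1/foo.tar", "rb")
    head = stream.read(512)  # first bytes of the tar
    stream.close()
    return head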

def gopen_boto3(url, mode="rb", bufsize=8192 * 2):
    """Open a URL with the boto3 API.

    :param url: boto3 url, e.g. boto3://bucket1/foo.tar. Credentials and endpoint
        are read from the S3_ENDPOINT_URL, S3_ACCESS_KEY_ID and
        S3_SECRET_ACCESS_KEY environment variables.
    :param mode: file mode
    :param bufsize: buffer size
    """
    import boto3

    # boto3.set_stream_logger('botocore', level='DEBUG')
    if url.startswith("boto3://"):
        url = url.replace("boto3://", "")
        need_meta = False
    else:
        url = url.replace("metaboto3://", "")
        need_meta = True
    endpoint_url = os.environ.get("S3_ENDPOINT_URL", None)
    access_key = os.environ.get("S3_ACCESS_KEY_ID", None)
    secret_key = os.environ.get("S3_SECRET_ACCESS_KEY", None)
    if mode[0] == "r":
        s3_client = boto3.client(
            "s3", endpoint_url=endpoint_url, aws_access_key_id=access_key, aws_secret_access_key=secret_key
        )
        bucket, key = url.split("/", 1)
        if need_meta:
            # download the .meta.jsonl that sits next to the tar object
            meta_file_key = key.split(".")[0] + ".meta.jsonl"
            meta_stream = io.BytesIO()
            s3_client.download_fileobj(bucket, meta_file_key, meta_stream)
            meta_stream.seek(0)
            meta_stream.name = meta_file_key
        else:
            meta_stream = None
        # data tar stream
        response = s3_client.get_object(Bucket=bucket, Key=key)  # Range optional
        response["Body"].name = key  # actually not used
        response["Body"].meta_stream = meta_stream
        return response["Body"]
    else:
        raise ValueError(f"{mode}: unknown mode")
gopen_schemes["rclone"] = gopen_rclone | |
gopen_schemes["boto3"] = gopen_boto3 | |
gopen_schemes["metaboto3"] = gopen_boto3 | |