# CogVideo: sat/sgm/webds.py
import sys
import io
import os
import re
import json
import tarfile
from functools import partial
import webdataset as wds
from webdataset import ResampledShards, DataPipeline, tarfile_to_samples
from webdataset.filters import pipelinefilter
from webdataset.tariterators import url_opener, group_by_keys
from webdataset.handlers import reraise_exception
from webdataset.gopen import gopen_schemes, gopen


def pytorch_worker_info(group=None):  # sourcery skip: use-contextlib-suppress
    """Return node and worker info for PyTorch and some distributed environments."""
    rank = 0
    world_size = 1
    worker = 0
    num_workers = 1
    try:
        import torch.distributed

        if torch.distributed.is_available() and torch.distributed.is_initialized():
            group = group or torch.distributed.group.WORLD
            rank = torch.distributed.get_rank(group=group)
            world_size = torch.distributed.get_world_size(group=group)
    except ModuleNotFoundError:
        pass
    try:
        import torch.utils.data

        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            worker = worker_info.id
            num_workers = worker_info.num_workers
    except ModuleNotFoundError:
        pass
    return rank, world_size, worker, num_workers


def pytorch_worker_seed(group=None):
    """Compute a distinct, deterministic RNG seed for each worker and node."""
    rank, world_size, worker, num_workers = pytorch_worker_info(group=group)
    return rank * 1000 + worker


def worker_seed_sat(group=None, seed=0):
    return pytorch_worker_seed(group=group) + seed * 23
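
# For illustration (hypothetical values): rank 1, dataloader worker 3 gives
# pytorch_worker_seed() == 1 * 1000 + 3 == 1003, and worker_seed_sat(seed=7)
# shifts every worker by 7 * 23, so different dataset configurations draw
# different shard orders while staying deterministic per (rank, worker).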


class ConfiguredResampledShards(ResampledShards):
    def __init__(self, urls, seed, nshards=sys.maxsize, deterministic=True):
        from sat.helpers import print_rank0

        try:
            from megatron.core.parallel_state import get_data_parallel_group

            group = get_data_parallel_group()
            print_rank0("Using megatron data parallel group.")
        except Exception:
            from sat.mpu import get_data_parallel_group

            try:
                group = get_data_parallel_group()
                print_rank0("Using sat data parallel group.")
            except AssertionError:
                group = None
                print_rank0("No data parallel group is specified!")
        worker_seed_sat_this = partial(worker_seed_sat, group=group, seed=seed)
        super().__init__(urls, nshards, worker_seed_sat_this, deterministic)


class SimpleDistributedWebDataset(DataPipeline):
    def __init__(self, path, process_fn, seed, *, shuffle_buffer=1000):
        # Set shuffle_buffer = 1 to disable shuffling; with model parallelism enabled
        # the shuffle order would otherwise diverge across model-parallel ranks.
        try:
            from sat.mpu import get_model_parallel_world_size

            if get_model_parallel_world_size() > 1:
                shuffle_buffer = 1
        except Exception:
            pass
        super().__init__(
            ConfiguredResampledShards(path, seed),  # many shards are recommended, otherwise sampling may be uneven
            tarfile_to_samples(),
            wds.shuffle(shuffle_buffer),
            process_fn,
        )
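
# A minimal usage sketch (the shard pattern and process_fn below are hypothetical;
# as a DataPipeline stage, process_fn receives the sample iterator and yields samples):
#
#   def my_process_fn(samples):
#       for sample in samples:
#           yield sample  # decode / filter / augment here
#
#   dataset = SimpleDistributedWebDataset(
#       "/data/shards-{000000..000099}.tar", my_process_fn, seed=42
#   )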


def tar_file_iterator_with_meta(
    fileobj, meta_names, skip_meta=r"__[^/]*__($|/)", suffix=None, handler=reraise_exception, meta_stream=None
):
    """Iterate over a tar stream, yielding dict(fname=..., data=...) for each member.

    :param fileobj: byte stream suitable for tarfile
    :param meta_names: keys to look up for each sample in the meta file
    :param skip_meta: regexp for member names that are skipped entirely (Default value = r"__[^/]*__($|/)")
    :param suffix: optional string appended to the content of ".txt" members
    :param handler: exception handler
    :param meta_stream: optional open stream of the ".meta.jsonl" file; if None, the
        meta file is looked up next to the tar file on disk
    """
    stream = tarfile.open(fileobj=fileobj, mode="r|*")
    data_dir, filename = fileobj.name.rsplit("/", 1)
    meta_data = {}  # {id: {meta_name: meta_value, meta_name2: meta_value2, ...}}

    if meta_stream is None:
        meta_file_name = filename.split(".")[0] + ".meta.jsonl"
        meta_path = os.path.join(data_dir, meta_file_name)
        if os.path.exists(meta_path):
            meta_stream = open(meta_path, "r")
    else:
        meta_file_name = meta_stream.name

    if meta_stream is not None:
        for lineno, line in enumerate(meta_stream):
            try:
                meta_list = [json.loads(line)]
            except Exception:
                from sat.helpers import print_rank0

                print_rank0(f"Error in loading jsonl {meta_file_name}, lineno {lineno}: {line}", level="DEBUG")
                continue
            for item in meta_list:
                if item["key"] not in meta_data:
                    meta_data[item["key"]] = {}
                for meta_name in meta_names:
                    if meta_name in item:
                        meta_data[item["key"]][meta_name] = item[meta_name]
        meta_stream.close()

    try:
        for tarinfo in stream:
            fname = tarinfo.name
            try:
                if not tarinfo.isreg():
                    continue
                if fname is None:
                    continue
                if "/" not in fname and fname.startswith("__") and fname.endswith("__"):
                    # skipping metadata for now
                    continue
                if skip_meta is not None and re.match(skip_meta, fname):
                    continue
                if fname.endswith(".txt") and suffix is not None:
                    data = (stream.extractfile(tarinfo).read().decode() + suffix).encode()
                else:
                    data = stream.extractfile(tarinfo).read()
                result = dict(fname=fname, data=data)
                yield result

                if fname.endswith(".id"):
                    # emit the requested meta entries as extra files next to the ".id" member
                    fid = fname.split(".")[0]
                    if "-$#%@&" in fid:
                        sfid = fid.split("-$#%@&")[0]
                    else:
                        sfid = fid
                    meta_data_fid = meta_data.get(sfid, {})
                    for meta_name in meta_names:
                        meta_fname = fid + "." + meta_name
                        meta = meta_data_fid.get(meta_name, None)
                        yield dict(fname=meta_fname, data=meta)
                stream.members = []  # drop processed members so the streaming tarfile does not keep them in memory
            except Exception as exn:
                if hasattr(exn, "args") and len(exn.args) > 0:
                    exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:]
                if handler(exn):
                    continue
                else:
                    break
    except Exception as exn:
        print(exn)
    del stream
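
# Shape of the companion meta file this iterator reads (values are illustrative):
# for "shard-000000.tar" the sibling "shard-000000.meta.jsonl" holds one JSON object
# per line, keyed by the sample id stored in the tar's "<id>.id" member, e.g.
#   {"key": "sample0001", "caption": "a cat", "width": 512}
# Every name listed in meta_names that is present for a key is re-emitted as an extra
# virtual file "<id>.<meta_name>" right after the corresponding ".id" member.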


def tar_file_expander_with_meta(data, meta_names, handler=reraise_exception):
    """Expand a stream of open tar files into a stream of tar file contents.

    This returns an iterator over dict(fname=..., data=..., __url__=...) samples.
    """
    for source in data:
        url = source["url"]
        try:
            assert isinstance(source, dict)
            assert "stream" in source
            for sample in tar_file_iterator_with_meta(source["stream"], meta_names, meta_stream=source["meta_stream"]):
                assert isinstance(sample, dict) and "data" in sample and "fname" in sample
                sample["__url__"] = url
                yield sample
        except Exception as exn:
            exn.args = exn.args + (source.get("stream"), source.get("url"))
            if handler(exn):
                continue
            else:
                break


def url_opener(
    data,
    handler,
    **kw,
):
    """Open URLs and yield a stream of url+stream pairs.

    This shadows webdataset's url_opener so that a companion meta stream
    (attached by gopen_boto3 below) travels along with the tar stream.

    Args:
        data: iterator over dict(url=...)
        handler: exception handler.
        kw: keyword arguments for gopen.gopen.

    Yields:
        a stream of url+stream pairs.
    """
    for sample in data:
        assert isinstance(sample, dict), sample
        assert "url" in sample
        url = sample["url"]
        try:
            stream = gopen(url, **kw)
            if hasattr(stream, "meta_stream"):
                meta_stream = stream.meta_stream
                del stream.meta_stream
            else:
                meta_stream = None
            sample.update(stream=stream, meta_stream=meta_stream)
            yield sample
        except Exception as exn:
            exn.args = exn.args + (url,)
            if handler(exn):
                continue
            else:
                break


def tarfile_samples_with_meta(src, meta_names, handler=reraise_exception):
    streams = url_opener(src, handler=handler)
    files = tar_file_expander_with_meta(streams, meta_names, handler)
    samples = group_by_keys(files, handler=handler)
    return samples


class MetaDistributedWebDataset(DataPipeline):
    """WebDataset with meta information files.

    Extra format:
        in the webdataset tar, each sample carries an "<id>.id" member;
        next to each tar file there is a ".meta.jsonl" file with the same basename;
        the ".meta.jsonl" file contains one JSON object per line, each with a "key"
        field matching the sample id in the ".id" member.
    """

    def __init__(
        self, path, process_fn, seed, *, meta_names=[], nshards=sys.maxsize, shuffle_buffer=1000, include_dirs=None
    ):
        # os.environ['WDS_SHOW_SEED'] = '1'
        import torch

        if torch.distributed.get_rank() == 0:
            if include_dirs is not None:  # e.g. "/webdatasets/A,/webdatasets/C"
                other_paths = []
                include_dirs = include_dirs.split(",")
                for include_dir in include_dirs:
                    if "*" in include_dir:
                        # "dir*3" repeats every tar found under dir three times
                        include_dir, n = include_dir.split("*")
                        n = int(n)
                    else:
                        n = 1
                    for cur_dir, dirs, files in os.walk(include_dir):
                        for f in files:
                            if f.endswith("tar") and os.path.getsize(os.path.join(cur_dir, f)) > 0:
                                # other_paths.append(os.path.join(cur_dir,f))
                                other_paths.extend([os.path.join(cur_dir, f)] * n)
                # print(f'Adding dataset paths {",".join(other_paths)}')
                from braceexpand import braceexpand

                if len(path) > 0:  # not ""
                    path = list(braceexpand(path)) + other_paths
                else:
                    path = other_paths
            path = [path]
        else:
            path = [None]
        torch.distributed.broadcast_object_list(path, src=0)
        path = path[0]

        tarfile_samples = partial(tarfile_samples_with_meta, meta_names=meta_names)
        tarfile_to_samples = pipelinefilter(tarfile_samples)

        # If model parallel, shuffle_buffer should be 1 so that shuffling is disabled
        # and all model-parallel ranks see the same sample order.
        try:
            from sat.mpu import get_model_parallel_world_size

            if get_model_parallel_world_size() > 1:
                shuffle_buffer = 1
        except Exception:
            pass
        super().__init__(
            ConfiguredResampledShards(path, seed, nshards=nshards),
            tarfile_to_samples(),
            wds.shuffle(shuffle_buffer),
            process_fn,
        )
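
# A minimal usage sketch (paths and meta keys are hypothetical); torch.distributed must
# already be initialized, since __init__ broadcasts the resolved shard list from rank 0:
#
#   dataset = MetaDistributedWebDataset(
#       "/data/shards-{000000..000099}.tar",
#       process_fn=my_process_fn,
#       seed=42,
#       meta_names=["caption"],        # looked up in the per-shard .meta.jsonl files
#       include_dirs="/data/extra*2",  # also walk this dir, repeating each tar twice
#   )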


# rclone support
from webdataset.gopen import Pipe


def gopen_rclone(url, mode="rb", bufsize=1024 * 1024 * 32):
    """Open a URL with `rclone`.

    :param url: rclone url, e.g. rclone://data:bucket1/foo.tar, where "data" is a configured rclone remote
    :param mode: file mode
    :param bufsize: buffer size
    """
    url = url.replace("rclone://", "")
    if mode[0] == "r":
        cmd = f"rclone cat '{url}'"
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141, 23],
        )  # skipcq: BAN-B604
    elif mode[0] == "w":
        cmd = f"rclone rcat '{url}'"  # rcat copies standard input to the remote path
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141, 26],
        )  # skipcq: BAN-B604
    else:
        raise ValueError(f"{mode}: unknown mode")


def gopen_boto3(url, mode="rb", bufsize=8192 * 2):
    """Open a URL with the boto3 API.

    :param url: boto3 url, e.g. boto3://bucket1/foo.tar or metaboto3://bucket1/foo.tar;
        credentials are read from the S3_ENDPOINT_URL, S3_ACCESS_KEY_ID and
        S3_SECRET_ACCESS_KEY environment variables
    :param mode: file mode
    :param bufsize: buffer size
    """
    import boto3

    # boto3.set_stream_logger('botocore', level='DEBUG')
    if url.startswith("boto3://"):
        url = url.replace("boto3://", "")
        need_meta = False
    else:
        url = url.replace("metaboto3://", "")
        need_meta = True

    endpoint_url = os.environ.get("S3_ENDPOINT_URL", None)
    access_key = os.environ.get("S3_ACCESS_KEY_ID", None)
    secret_key = os.environ.get("S3_SECRET_ACCESS_KEY", None)

    if mode[0] == "r":
        s3_client = boto3.client(
            "s3", endpoint_url=endpoint_url, aws_access_key_id=access_key, aws_secret_access_key=secret_key
        )
        bucket, key = url.split("/", 1)

        if need_meta:
            # download the companion meta jsonl into memory
            meta_file_key = key.split(".")[0] + ".meta.jsonl"
            meta_stream = io.BytesIO()
            s3_client.download_fileobj(bucket, meta_file_key, meta_stream)
            meta_stream.seek(0)
            meta_stream.name = meta_file_key
        else:
            meta_stream = None

        # data tar stream
        response = s3_client.get_object(Bucket=bucket, Key=key)  # Range is optional
        response["Body"].name = key  # actually not used
        response["Body"].meta_stream = meta_stream
        return response["Body"]
    else:
        raise ValueError(f"{mode}: unknown mode")


gopen_schemes["rclone"] = gopen_rclone
gopen_schemes["boto3"] = gopen_boto3
gopen_schemes["metaboto3"] = gopen_boto3
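
# With the schemes registered above, shard urls can point straight at remote storage.
# A sketch with placeholder buckets and credentials (values are illustrative):
#
#   os.environ["S3_ENDPOINT_URL"] = "http://localhost:9000"
#   os.environ["S3_ACCESS_KEY_ID"] = "my-key"
#   os.environ["S3_SECRET_ACCESS_KEY"] = "my-secret"
#   dataset = MetaDistributedWebDataset(
#       "metaboto3://bucket1/videos/shards-{000000..000009}.tar", my_process_fn, seed=0,
#       meta_names=["caption"],
#   )
#   # rclone remotes work similarly, e.g. "rclone://data:bucket1/shards-000000.tar"
#   # for a remote named "data".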