import webdataset as wds
import soundfile as sf
import io
import os
import random
import copy
from tqdm import tqdm
import argparse
import traceback
import logging
import json
from laion_clap import tokenize
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tar-path",
        type=str,
        default=None,
        help="path to the directory (or pipe prefix) containing the numbered tars",
    )
    parser.add_argument(
        "--start",
        type=int,
        default=0,
        help="first shard index to check: tar-path/<start>.tar",
    )
    parser.add_argument(
        "--end",
        type=int,
        default=99999,
        help="one past the last shard index to check: tar-path/<end - 1>.tar",
    )
    parser.add_argument(
        "--exclude",
        nargs="+",
        type=int,
        default=None,
        help="shard indices to skip",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--order",
        default=False,
        action="store_true",
        help="check shards in ascending index order instead of shuffling",
    )
    args = parser.parse_args()
    return args
def log_and_continue(exn):
    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
    logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.")
    return True
def preprocess(sample):
    """Decode one webdataset sample: load the audio and tokenize the caption."""
    audio_ext = "flac"
    text_ext = "json"
    # soundfile returns (waveform, sample_rate); the sample rate is unused here.
    audio_data, _ = sf.read(io.BytesIO(sample[audio_ext]))
    json_dict_raw = json.loads(sample[text_ext].decode("utf-8"))
    sample["waveform"] = audio_data
    texts = json_dict_raw["text"]
    # Some entries carry several captions; pick one at random. Check the
    # length before indexing so an empty list cannot raise IndexError.
    if isinstance(texts, list) and len(texts) > 1 and isinstance(texts[0], str):
        texts = random.choice(texts)
    sample["raw_text"] = texts
    sample["text"] = tokenize(texts)
    return sample
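# For reference, a raw webdataset sample arriving at preprocess() is a dict
# keyed by file extension plus metadata; roughly (field values hypothetical):
#   {"__key__": "000123", "__url__": ".../42.tar",
#    "flac": b"<flac bytes>", "json": b'{"text": ...}'}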
if __name__ == "__main__":
    args = parse_args()
    tar_path = args.tar_path
    idx_list = list(range(args.start, args.end))
    if args.exclude is not None:
        idx_list = [i for i in idx_list if i not in set(args.exclude)]
    if not args.order:
        random.shuffle(idx_list)
    # Shards on S3 are read through a pipe; everything else is a local file.
    args.local = "aws" not in tar_path
    if args.local:
        input_shards = [os.path.join(tar_path, str(i) + ".tar") for i in idx_list]
    else:
        # tar_path is assumed to be a webdataset pipe prefix such as
        # "pipe:aws s3 cp s3://bucket/prefix"; the trailing " -" makes
        # `aws s3 cp` stream the tar to stdout.
        input_shards = [os.path.join(tar_path, str(i) + ".tar -") for i in idx_list]
    pipeline = [wds.SimpleShardList(input_shards)]
    pipeline.extend(
        [
            wds.split_by_node,
            wds.split_by_worker,
            # Corrupt tar members are logged and skipped instead of aborting the scan.
            wds.tarfile_to_samples(handler=log_and_continue),
            wds.map(preprocess),
            wds.to_tuple("__url__", "__key__", "waveform"),
            wds.batched(1),
        ]
    )
    dataset = wds.DataPipeline(*pipeline)
    dataloader = wds.WebLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)
    old_k = 0
    old_batch = None
    try:
        # Walk every sample; keep the last good (k, batch) pair so a failure
        # can be localized to the shard read right after it.
        for k, batch in tqdm(enumerate(dataloader)):
            print("k:", k)
            print("batch:", batch)
            old_k = k
            old_batch = copy.deepcopy(batch)
    except Exception:
        with open("check_tar_log.txt", "a") as file:
            traceback.print_exc(file=file)
        print("old_k:", old_k)
        print("old_batch:", old_batch)
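# Example invocation (script name and paths hypothetical; adjust to your layout):
#   python check_tar.py --tar-path /data/audio_shards --start 0 --end 200 \
#       --exclude 13 57 --order
# The __url__ field printed with each batch names the tar currently being
# read, so the first shard that fails to decode can be identified from the
# last printed output and the traceback appended to check_tar_log.txt.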