import webdataset as wds
import soundfile as sf
import io
import os
import random
import copy
from tqdm import tqdm
import shutil
import argparse
import traceback
import logging
import json
from laion_clap import tokenize


def parse_args():
    """Parse command-line arguments for the tar-checking script.

    Returns:
        argparse.Namespace with ``tar_path``, ``start``, ``end``, ``exclude``
        (list of int or None), ``batch_size`` and ``order``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tar-path",
        type=str,
        default=None,
        required=True,  # the script cannot build any shard path without it
        help="Path to the tars",
    )
    parser.add_argument(
        "--start",
        type=int,
        default=0,
        help="start from tar-path + start",
    )
    parser.add_argument(
        "--end",
        type=int,
        default=99999,
        help="end with tar-path + end",
    )
    parser.add_argument(
        "--exclude",
        nargs='+',
        type=int,  # must be int to match the indices produced by range()
        default=None,
        help="exclude tar-path + exclude",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--order",
        default=False,
        action='store_true',
        help="if keep the search order accendingly",
    )
    args = parser.parse_args()
    return args


def log_and_continue(exn):
    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
    logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.")
    return True


def preprocess(
    sample,
):
    """
    Preprocess a single sample for wdsdataloader.

    Decodes the ``flac`` audio bytes and the ``json`` caption bytes of a
    webdataset sample and adds three keys to it:
      - ``waveform``: the audio array returned by ``soundfile.read``
      - ``raw_text``: the caption; when the json ``text`` field is a list of
        more than one string, one entry is picked at random
      - ``text``: ``tokenize(raw_text)``

    Args:
        sample: dict yielded by ``wds.tarfile_to_samples``; assumed to hold
            ``"flac"`` and ``"json"`` byte entries (TODO confirm for all shards).

    Returns:
        The same dict, mutated in place.
    """
    audio_ext = "flac"
    text_ext = "json"
    audio_data, _orig_sr = sf.read(io.BytesIO(sample[audio_ext]))
    json_dict_raw = json.loads(sample[text_ext].decode("utf-8"))
    sample["waveform"] = audio_data

    texts = json_dict_raw["text"]
    # Check the length before indexing so an empty caption list does not
    # raise IndexError; a single-element list is passed through unchanged.
    if isinstance(texts, list) and len(texts) > 1 and isinstance(texts[0], str):
        texts = random.choice(texts)
    sample["raw_text"] = texts
    sample["text"] = tokenize(texts)
    return sample


def _build_input_shards(tar_path, idx_list, local):
    """Return one webdataset shard URL per index, for ``<tar_path>/<idx>.tar``.

    Local shards are plain file paths. Remote shards are streamed through the
    AWS CLI using webdataset's ``pipe:`` URL scheme.
    """
    shards = [os.path.join(tar_path, str(i) + ".tar") for i in idx_list]
    if local:
        return shards
    # webdataset only treats a URL as a shell command when it starts with
    # "pipe:"; a bare "<path>.tar -" would be opened as a literal filename.
    return [f"pipe:aws s3 --cli-read-timeout 0 cp {s} -" for s in shards]


if __name__ == "__main__":
    args = parse_args()
    tar_path = args.tar_path

    # Drop excluded indices; a set ignores indices that fall outside the range
    # instead of raising ValueError as list.remove would.
    excluded = set(args.exclude) if args.exclude is not None else set()
    idx_list = [i for i in range(args.start, args.end) if i not in excluded]
    if not args.order:
        random.shuffle(idx_list)

    # Shards are read from the local filesystem unless the path points at AWS.
    args.local = "aws" not in tar_path
    input_shards = _build_input_shards(tar_path, idx_list, args.local)

    pipeline = [wds.SimpleShardList(input_shards)]
    pipeline.extend(
        [
            wds.split_by_node,
            wds.split_by_worker,
            wds.tarfile_to_samples(handler=log_and_continue),
            # No handler here: a sample that fails to decode raises, which is
            # what lets the except block below report the failing position.
            wds.map(preprocess),
            wds.to_tuple("__url__", "__key__", "waveform"),
            # NOTE(review): batched here and again by WebLoader(batch_size=...)
            # below, so each yielded batch is nested — confirm this is intended.
            wds.batched(1),
        ]
    )
    dataset = wds.DataPipeline(*pipeline)
    dataloader = wds.WebLoader(
        dataset, batch_size=args.batch_size, shuffle=False, num_workers=0
    )

    # Remember the last batch that loaded successfully so a failure can be
    # located relative to it.
    old_k = 0
    old_batch = None
    try:
        for k, batch in tqdm(enumerate(dataloader)):
            print("k:", k)
            print("batch:", batch)
            old_k = k
            old_batch = copy.deepcopy(batch)
    except Exception:
        # Best-effort diagnostics: append the traceback to a log file and print
        # the last good position instead of crashing. KeyboardInterrupt and
        # SystemExit are deliberately not caught.
        with open("check_tar_log.txt", "a") as file:
            traceback.print_exc(file=file)
        print("old_k:", old_k)
        print("old_batch:", old_batch)