"""Convert the Pile training split (saved with ``datasets.save_to_disk``) into
WebDataset tar shards, dropping every sample whose ``pile_set_name`` is listed
in ``EXCLUDED_SUBSETS``.

Each written sample carries the raw text (``txt``) and its metadata (``json``),
keyed by a 1-based running counter of the samples that were kept.
"""
import json
import os

import datasets
import webdataset as wds
from tqdm import tqdm

DATASET_ROOT = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/the_pile/all/train"
OUT_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/the_pile"
SAMPLE_PER_SHARD = 100000
# Number of rows fetched from the on-disk dataset per iteration step.
BATCH_SIZE = 4096
# Pile subsets (values of meta["pile_set_name"]) left out of the output shards.
EXCLUDED_SUBSETS = frozenset({"Github"})


def _is_excluded(meta):
    """Return True if the sample's metadata names an excluded Pile subset."""
    return meta.get("pile_set_name", None) in EXCLUDED_SUBSETS


def main():
    """Load the dataset, filter out excluded subsets, and write tar shards."""
    # exist_ok allows re-running without crashing on an existing directory.
    # NOTE(review): same-named shards from an earlier run get rewritten, but
    # higher-numbered leftovers from a longer earlier run would remain.
    os.makedirs(OUT_DIR, exist_ok=True)

    print("load dataset...")
    pile = datasets.load_from_disk(DATASET_ROOT)
    total_num = pile.num_rows
    print("total num:", total_num)

    num = 0
    shard_pattern = os.path.join(OUT_DIR, "%05d.tar")
    # encoder=False: values are passed as ready-made bytes, so encode manually.
    with tqdm(total=total_num) as pbar, wds.ShardWriter(
        shard_pattern, maxcount=SAMPLE_PER_SHARD, encoder=False
    ) as sink:
        for batch in pile.iter(BATCH_SIZE):
            texts, metas = batch["text"], batch["meta"]
            for text, meta in zip(texts, metas):
                if _is_excluded(meta):
                    continue
                num += 1
                sink.write({
                    '__key__': str(num),
                    'txt': text.encode("utf-8"),
                    'json': json.dumps(meta, indent=4).encode("utf-8"),
                })
            # One progress update per batch instead of per row; every row
            # (kept or skipped) still counts toward the total.
            pbar.update(len(texts))
    print(f"{num} out of {total_num} is written")


if __name__ == "__main__":
    main()