"""Smoke-test script: build the training dataloader and iterate over it once.

The script parses the usual training arguments, optionally downloads the
``sizes.json`` metadata files for remote datasets, creates the model (only to
obtain ``model_cfg``), builds the data pipeline with ``get_data`` and then
walks through every batch of the training dataloader behind a progress bar.
"""
import os
import subprocess

import torch
from tqdm import tqdm

from laion_clap import create_model
from laion_clap.training import parse_args
from laion_clap.training.data import get_data
from laion_clap.training.distributed import is_master, world_info_from_env
from laion_clap.utils import dataset_split


def run_dataloader():
    """Iterate once over the module-level ``dataloader``, discarding batches.

    Relies on the globals ``dataloader`` and ``args`` that are assigned in the
    ``__main__`` section below; the progress-bar total is the number of
    samples divided (floor) by the batch size.
    """
    num_batches = dataloader.num_samples // args.batch_size
    for _batch in tqdm(dataloader, total=num_batches):
        pass


def _fetch_sizes_json(dataset_names):
    """Download ``sizes.json`` for every split of each dataset (best effort).

    Files are copied with the ``aws`` CLI into
    ``./json_files/<dataset>/<split>/sizes.json``. A failed copy is reported
    on stdout but does not abort the script.
    """
    for dataset_name in dataset_names:
        for split in dataset_split[dataset_name]:
            local_dir = f"./json_files/{dataset_name}/{split}"
            # exist_ok avoids the check-then-create race of exists()+makedirs().
            os.makedirs(local_dir, exist_ok=True)
            remote = (
                f"s3://s-laion-audio/webdataset_tar/"
                f"{dataset_name}/{split}/sizes.json"
            )
            local = f"{local_dir}/sizes.json"
            # Argument list (no shell) so names with spaces or shell
            # metacharacters are passed through verbatim.
            try:
                result = subprocess.run(
                    ["aws", "s3", "cp", remote, local], check=False
                )
            except OSError as err:
                # e.g. the aws CLI is not installed; keep going as before.
                print(f"warning: could not run aws s3 cp for {remote}: {err}")
            else:
                if result.returncode != 0:
                    print(
                        f"warning: aws s3 cp {remote} exited with "
                        f"code {result.returncode}"
                    )


if __name__ == '__main__':
    args = parse_args()
    # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule?
    args.amodel = args.amodel.replace("/", "-")
    device = torch.device('cpu')

    # discover initial world args early so we can log properly
    args.distributed = False
    args.local_rank, args.rank, args.world_size = world_info_from_env()

    if args.remotedata and is_master(args):
        _fetch_sizes_json(args.datasetnames)

    model, model_cfg = create_model(
        args.amodel,
        args.tmodel,
        args.pretrained,
        precision=args.precision,
        device=device,
        jit=args.torchscript,
        force_quick_gelu=args.force_quick_gelu,
        openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir),
        skip_params=True,
        pretrained_audio=args.pretrained_audio,
        pretrained_text=args.pretrained_text,
        enable_fusion=args.enable_fusion,
        fusion_type=args.fusion_type,
    )

    data = get_data(args, model_cfg)

    # `dataloader` is read by run_dataloader(); `sampler` is unpacked for
    # parity with the training entry point but is not used in this script.
    dataloader, sampler = data["train"].dataloader, data["train"].sampler

    print('dataset size:', dataloader.num_samples)
    print('batch size:', args.batch_size)
    print('num batches:', dataloader.num_samples // args.batch_size)

    run_dataloader()