#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Rebuild the ``subsets`` list of ``<root_dir>/config.toml`` from disk.

Expected layout::

    <root_dir>/config.toml
    <root_dir>/<dataset>/<subset>/...   # image, caption and .json files

Every non-hidden ``<subset>`` directory becomes one entry of
``config["datasets"][0]["subsets"]``. A subset name of the form
``<N>_<anything>`` sets ``num_repeats`` to ``N``; otherwise it is 1.

Environment variables:

* ``DEBUG`` -- print the resulting TOML instead of writing it back.
* ``DELETE_ORPHANS`` -- orphan deletion is not implemented; if any orphan
  (caption without an image) is found, abort with ``NotImplementedError``
  before any file or the config is modified.
"""
import os
import toml
import sys
from pathlib import Path
from collections import defaultdict, Counter
from pprint import pprint

# Extensions (compared lower-cased) that identify each kind of file.
CAPTION_EXTS = frozenset({".txt", ".tags", ".caption"})
IMAGE_EXTS = frozenset({".jpg", ".jpeg", ".png", ".jxl"})
# Sidecar files grouped with an item that are neither caption nor image.
SIDECAR_EXTS = frozenset({".json"})
KNOWN_EXTS = CAPTION_EXTS | IMAGE_EXTS | SIDECAR_EXTS


def _visible_subdirs(path):
    """Return the subdirectories of *path* not starting with '.', sorted.

    Sorting makes the generated config deterministic; raw ``iterdir``
    order depends on the filesystem.
    """
    return sorted(
        child for child in path.iterdir()
        if child.is_dir() and not child.name.startswith(".")
    )


def _parse_num_repeats(subset_name):
    """Return the integer before the first '_' in *subset_name*, else 1."""
    try:
        return int(subset_name.partition("_")[0])
    except ValueError:
        return 1


def _group_files(subset_path):
    """Group the recognised files in *subset_path* by item stem.

    The stem is the file name up to its first '.', so ``a.jpg``, ``a.txt``
    and ``a.jpg.json`` all belong to item ``a``.

    Returns:
        dict mapping stem -> set of lower-cased extensions present.
    """
    groups = defaultdict(set)
    for file in subset_path.iterdir():
        # Lower-case so e.g. "IMG_01.JPG" is still recognised as an image
        # (otherwise its caption would be misclassified as an orphan).
        ext = file.suffix.lower()
        if ext not in KNOWN_EXTS:
            continue
        stem = file.stem.partition(".")[0]
        if stem == "sample-prompts":
            continue
        groups[stem].add(ext)
    return groups


def _tally_subset(subset_path, subset_stats):
    """Classify every item in *subset_path* and count it in *subset_stats*.

    Counters: ``captioned`` (image + caption), ``no_caption`` (image only),
    ``orphans`` (caption only). Items with neither add one count per
    extension present (e.g. ``".json"``).

    Raises:
        NotImplementedError: an orphan is found while ``DELETE_ORPHANS`` is
            set. Raised before anything is deleted or written.
    """
    for stem, exts in _group_files(subset_path).items():
        has_caption = not exts.isdisjoint(CAPTION_EXTS)
        has_image = not exts.isdisjoint(IMAGE_EXTS)
        if has_caption and has_image:
            subset_stats["captioned"] += 1
        elif has_image:
            subset_stats["no_caption"] += 1
        elif has_caption:
            subset_stats["orphans"] += 1
            if "DELETE_ORPHANS" in os.environ:
                # Deletion is deliberately not implemented: an image in a
                # format missing from IMAGE_EXTS would make its caption look
                # orphaned. Refuse before touching any file.
                raise NotImplementedError(
                    f"UNFINISHED DO NOT USE (orphan {stem!r} in {subset_path})"
                )
        else:
            for ext in exts:
                subset_stats[ext] += 1


def update_config(root_dir):
    """Rewrite ``config["datasets"][0]["subsets"]`` from *root_dir*'s tree.

    Args:
        root_dir: directory containing ``config.toml`` and the dataset dirs.

    Returns:
        defaultdict(Counter) mapping each subset ``Path`` to its file
        statistics (see ``_tally_subset``).

    Raises:
        NotImplementedError: see ``_tally_subset`` (``DELETE_ORPHANS``).
    """
    root_dir = Path(root_dir).resolve()
    config_path = root_dir / "config.toml"
    config = toml.load(config_path)

    stats = defaultdict(Counter)
    new_subsets = []
    for dataset_path in _visible_subdirs(root_dir):
        for subset_path in _visible_subdirs(dataset_path):
            new_subsets.append({
                "image_dir": str(subset_path),
                "num_repeats": _parse_num_repeats(subset_path.name),
            })
            # stats[subset_path] also registers subsets with no files.
            _tally_subset(subset_path, stats[subset_path])

    config["datasets"][0]["subsets"] = new_subsets
    if "DEBUG" in os.environ:
        print(toml.dumps(config))
    else:
        # toml.load reads UTF-8; write it back the same way on every OS.
        with open(config_path, "w", encoding="utf-8") as f:
            toml.dump(config, f)
    return stats


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print(
            "Usage: [DEBUG=1] [DELETE_ORPHANS=1] python script.py <root_dir>",
            file=sys.stderr,
        )
        sys.exit(1)

    stats = update_config(sys.argv[1])

    # Print statistics for each subset, in path order.
    for subset in sorted(stats):
        print(subset, dict(stats[subset]))