|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""dataset_stats.py""" |
|
import os |
|
import json |
|
import glob |
|
import numpy as np |
|
from typing import Optional, List |
|
|
|
# File name of the dataset stats JSON file. `generate_dataset_stats` joins it
# under `<data_home>/yourmt3_indexes/` to locate the stats file.
STAT_FILE_NAME = "dataset_stats.json" |
|
|
|
|
|
def generate_dataset_stats(data_home: os.PathLike, dataset_name: Optional[str] = None) -> None: |
|
"""Generate dataset stats for a given dataset. |
|
|
|
Args: |
|
data_home: Path to the data directory. |
|
dataset_name: Name of the dataset to (re)generate stats for. If None, generate MISSING stats for all |
|
datasets. |
|
""" |
|
stat_file = os.path.join(data_home, 'yourmt3_indexes', STAT_FILE_NAME) |
|
if os.path.exists(stat_file): |
|
print(f"Loading existing dataset stats file: {stat_file}") |
|
with open(stat_file, 'r') as f: |
|
stats = json.load(f).items() |
|
else: |
|
print(f"Creating new dataset stats file: {stat_file}") |
|
stats = {} |
|
|
|
|
|
indexes = glob.glob(os.path.join(data_home, 'yourmt3_indexes', '*_file_list.json')) |
|
for index_file in indexes: |
|
dataset_name = os.path.basename(index_file).split('_')[0] |
|
split_name = os.path.basename(index_file).split('_')[1] |
|
|