|
import os |
|
import json |
|
import random |
|
import logging |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
def prepare_data(data_original, save_json_train, save_json_valid, save_json_test, split_ratio=(80, 10, 10), seed=12):
    """Split a folder-per-label ``.wav`` dataset into train/valid/test JSON manifests.

    ``data_original`` is scanned one level deep: every sub-directory name is
    used as a label and every regular file ending in ``.wav`` inside it becomes
    one ``(path, label)`` sample. The samples are shuffled with ``seed`` and
    written out through ``create_json``.

    Args:
        data_original: Root directory containing one sub-directory per label.
        save_json_train: Output path of the training manifest.
        save_json_valid: Output path of the validation manifest.
        save_json_test: Output path of the test manifest.
        split_ratio: Percentages for the splits. Only the first two entries
            (train, valid) are read; the test split receives every sample
            that is left over, so rounding never drops a sample.
        seed: Seed passed to ``random.seed`` before shuffling.

    Returns:
        None. If all three output files already exist, nothing is rewritten.
    """
    # Kept as a global seed (and kept ahead of the early return) so callers
    # that relied on this function seeding the ``random`` module still see
    # the same behaviour.
    random.seed(seed)

    outputs = (save_json_train, save_json_valid, save_json_test)
    if all(os.path.exists(path) for path in outputs):
        logger.info("Preparation completed in previous run, skipping.")
        return

    wav_list = []
    # os.listdir returns entries in arbitrary order; sorting makes the seeded
    # shuffle below produce the same splits on every machine/filesystem.
    for label in sorted(os.listdir(data_original)):
        label_dir = os.path.join(data_original, label)
        if not os.path.isdir(label_dir):
            continue
        for audio_file in sorted(os.listdir(label_dir)):
            if not audio_file.endswith('.wav'):
                continue
            wav_file = os.path.join(label_dir, audio_file)
            if os.path.isfile(wav_file):
                wav_list.append((wav_file, label))
            else:
                logger.warning("Skipping invalid audio file: %s", wav_file)

    if not wav_list:
        logger.warning("No .wav files found under %s", data_original)

    random.shuffle(wav_list)
    n_total = len(wav_list)
    n_train = n_total * split_ratio[0] // 100
    n_valid = n_total * split_ratio[1] // 100
    valid_end = n_train + n_valid

    train_set = wav_list[:n_train]
    valid_set = wav_list[n_train:valid_end]
    # Everything that remains goes to the test split.
    test_set = wav_list[valid_end:]

    create_json(train_set, save_json_train)
    create_json(valid_set, save_json_valid)
    create_json(test_set, save_json_test)

    logger.info(
        "Created %s, %s, and %s",
        save_json_train,
        save_json_valid,
        save_json_test,
    )
|
|
|
def create_json(data, json_file):
    """Serialize ``(wav, label)`` pairs into a JSON manifest file.

    Each pair is stored under its position in ``data`` (as a string key) as an
    object with the keys ``'wav'`` and ``'label'``.

    Args:
        data: Iterable of two-item ``(wav, label)`` entries.
        json_file: Path of the file to write; it is opened in ``'w'`` mode, so
            any existing content is replaced.
    """
    manifest = {}
    for position, entry in enumerate(data):
        wav_path, class_name = entry
        manifest[str(position)] = {'wav': wav_path, 'label': class_name}

    with open(json_file, 'w') as handle:
        json.dump(manifest, handle)
|
|