PoTaTo721's picture
update to 1.2
69e8a46
raw
history blame
2.93 kB
import os
from glob import glob
from pathlib import Path
from typing import Union
from loguru import logger
from natsort import natsorted
# Filename suffixes (lowercase, leading dot included) that are treated as audio.
AUDIO_EXTENSIONS = {
    ".mp3", ".wav", ".flac", ".ogg", ".m4a",
    ".wma", ".aac", ".aiff", ".aif", ".aifc",
}
def list_files(
path: Union[Path, str],
extensions: set[str] = None,
recursive: bool = False,
sort: bool = True,
) -> list[Path]:
"""List files in a directory.
Args:
path (Path): Path to the directory.
extensions (set, optional): Extensions to filter. Defaults to None.
recursive (bool, optional): Whether to search recursively. Defaults to False.
sort (bool, optional): Whether to sort the files. Defaults to True.
Returns:
list: List of files.
"""
if isinstance(path, str):
path = Path(path)
if not path.exists():
raise FileNotFoundError(f"Directory {path} does not exist.")
files = [file for ext in extensions for file in path.rglob(f"*{ext}")]
if sort:
files = natsorted(files)
return files
def get_latest_checkpoint(path: Path | str) -> Path | None:
# Find the latest checkpoint
ckpt_dir = Path(path)
if ckpt_dir.exists() is False:
return None
ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime)
if len(ckpts) == 0:
return None
return ckpts[-1]
def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
"""
Load a Bert-VITS2 style filelist.
"""
files = set()
results = []
count_duplicated, count_not_found = 0, 0
LANGUAGE_TO_LANGUAGES = {
"zh": ["zh", "en"],
"jp": ["jp", "en"],
"en": ["en"],
}
with open(path, "r", encoding="utf-8") as f:
for line in f.readlines():
splits = line.strip().split("|", maxsplit=3)
if len(splits) != 4:
logger.warning(f"Invalid line: {line}")
continue
filename, speaker, language, text = splits
file = Path(filename)
language = language.strip().lower()
if language == "ja":
language = "jp"
assert language in ["zh", "jp", "en"], f"Invalid language {language}"
languages = LANGUAGE_TO_LANGUAGES[language]
if file in files:
logger.warning(f"Duplicated file: {file}")
count_duplicated += 1
continue
if not file.exists():
logger.warning(f"File not found: {file}")
count_not_found += 1
continue
results.append((file, speaker, languages, text))
if count_duplicated > 0:
logger.warning(f"Total duplicated files: {count_duplicated}")
if count_not_found > 0:
logger.warning(f"Total files not found: {count_not_found}")
return results