fish-speech-1 / tools /llama /build_dataset.py
PoTaTo721's picture
update to 1.2
69e8a46
raw
history blame
4.91 kB
import itertools
import os
import re
from collections import defaultdict
from functools import partial
from multiprocessing import Pool
from pathlib import Path
import click
import numpy as np
from loguru import logger
from tqdm import tqdm
from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
from fish_speech.utils.file import load_filelist
# To avoid CPU overload: cap the MKL and OpenMP thread pools at one thread
# each (main() additionally runs a multiprocessing Pool of workers).
# NOTE(review): numpy is imported above this point; backends that read these
# variables only when they are first loaded may ignore them here — consider
# setting them before `import numpy`.
os.environ.update({"MKL_NUM_THREADS": "1", "OMP_NUM_THREADS": "1"})
def task_generator_folder(root: Path, text_extension: "str | tuple[str, ...]"):
    """Yield one task per directory under ``root`` that contains ``*.npy`` files.

    Every ``*.npy`` file found recursively under ``root`` is paired with the
    transcript(s) stored beside it: the same path with its suffix replaced by
    each entry of ``text_extension``, read as UTF-8. Files whose transcripts
    cannot be read are logged and skipped.

    Args:
        root: Directory to scan recursively.
        text_extension: A single suffix such as ``".txt"``, or any iterable of
            suffixes (``main`` passes a tuple because its click option is
            declared with ``multiple=True``). One text is read per suffix, in
            the given order.

    Yields:
        ``(speaker, subset, "folder")`` tuples, one per parent directory, where
        ``speaker`` is that directory's name and ``subset`` is a list of
        ``(npy_path, texts)`` pairs in sorted path order.
    """
    # Normalize once: both call styles share one code path, and a one-shot
    # iterable is not exhausted after the first file.
    if isinstance(text_extension, str):
        extensions = [text_extension]
    else:
        extensions = list(text_extension)

    files = sorted(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}"))

    grouped_files = defaultdict(list)
    for file in tqdm(files, desc=f"Grouping {root}"):
        try:
            texts = [
                file.with_suffix(ext).read_text(encoding="utf-8")
                for ext in extensions
            ]
        except Exception as e:
            # Best-effort: one unreadable transcript should not abort the scan.
            logger.error(f"Failed to read text {file}: {e}")
            continue
        # Group by parent directory; the directory name serves as the speaker.
        grouped_files[str(file.parent)].append((file.parent.name, file, texts))

    logger.info(
        f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..."
    )

    for group in grouped_files.values():
        subset = [(f, t) for _, f, t in group]
        yield group[0][0], subset, "folder"
def task_generator_filelist(filelist):
    """Yield one task per speaker found in ``filelist``.

    Each entry produced by ``load_filelist`` is unpacked as
    ``(filename, speaker, _, text)``; the third field is ignored. Entries are
    bucketed by speaker, and speakers are emitted in first-seen order.

    Yields:
        ``(speaker, [(Path(filename), [text]), ...], "filelist")`` tuples.
    """
    by_speaker = {}
    for filename, speaker, _, text in load_filelist(filelist):
        by_speaker.setdefault(speaker, []).append((Path(filename), [text]))

    logger.info(f"Found {len(by_speaker)} groups in {filelist}")

    yield from ((spk, items, "filelist") for spk, items in by_speaker.items())
# Precompiled once per process; run_task is invoked for every group by the pool.
_CURLY_TAG_RE = re.compile(r"\{.*?\}")
_ANGLE_TAG_RE = re.compile(r"<.*?>")
_WHITESPACE_RE = re.compile(r"\s+")


def _clean_text(text):
    """Replace ``{...}`` and ``<...>`` spans with a space, then collapse whitespace runs."""
    text = _CURLY_TAG_RE.sub(" ", text)
    text = _ANGLE_TAG_RE.sub(" ", text)
    return _WHITESPACE_RE.sub(" ", text)


def run_task(task):
    """Serialize one group of samples into a packed ``TextData`` message.

    Args:
        task: ``(name, subset, source)`` as yielded by ``task_generator_folder``
            or ``task_generator_filelist``; ``subset`` is a list of
            ``(path, texts)`` pairs.

    Returns:
        The result of ``pack_pb_stream`` on a ``TextData`` holding one
        ``Sentence`` per sample whose ``.npy`` file exists and loads. Samples
        that are missing or fail to load are logged and skipped, so the
        message may contain no sentences at all.
    """
    name, subset, source = task

    sentences = []
    for file, texts in subset:
        # The token file is the sample path with its suffix replaced by .npy.
        np_file = file.with_suffix(".npy")
        if not np_file.exists():
            logger.warning(f"Can't find {np_file}")
            continue

        new_texts = [_clean_text(text) for text in texts]

        try:
            semantics = np.load(np_file)
        except Exception as e:
            # Report the file actually loaded; for filelist inputs `file` is
            # not the .npy path.
            logger.error(f"Failed to parse {np_file}: {e}")
            continue

        if isinstance(semantics, np.ndarray):
            semantics = semantics.tolist()

        # One Semantics entry per top-level element of the loaded array.
        # NOTE(review): this presumes a 2-D array (rows of token ids); a 1-D
        # array would pass scalars to Semantics(values=...) — confirm.
        sentences.append(
            Sentence(
                texts=new_texts,
                semantics=[Semantics(values=s) for s in semantics],
            )
        )

    return pack_pb_stream(
        TextData(
            source=source,
            name=name,
            sentences=sentences,
        )
    )
@click.command()
@click.option(
    "--input",
    type=click.Path(path_type=Path),
    required=True,
    help="A folder containing the dataset or a filelist",
    multiple=True,
)
@click.option(
    "--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft"
)
@click.option("--num-workers", type=int, default=16)
@click.option("--text-extension", type=str, default=[".txt"], multiple=True)
@click.option(
    "--shard-size", type=int, default=10, help="The maximum size of each shard in mb"
)
def main(input, output, num_workers, text_extension, shard_size):
    # Build packed protobuf shards from one or more dataset sources.
    #
    # Each --input is either a directory (handled by task_generator_folder) or
    # a filelist (handled by task_generator_filelist). Groups are serialized in
    # parallel by run_task and appended to <output>/<index:08d>.protos files.
    # A shard is closed once its size exceeds shard_size MiB, so a shard can
    # overshoot the limit by up to one serialized group.
    # (Comments instead of a docstring so click's --help text is unchanged.)
    task_generators = []
    for f in input:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not f.exists():
            raise click.BadParameter(f"{f} not found", param_hint="--input")
        if f.is_dir():
            task_generators.append(task_generator_folder(f, text_extension))
        else:
            task_generators.append(task_generator_filelist(f))
    tasks = itertools.chain(*task_generators)

    output.mkdir(parents=True, exist_ok=True)
    max_shard_bytes = shard_size * 1024 * 1024

    dataset_fp = None
    tar_idx = 0
    written_size = 0
    try:
        with Pool(num_workers) as p:
            for result in tqdm(p.imap_unordered(run_task, tasks)):
                if dataset_fp is None:
                    dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb")
                dataset_fp.write(result)
                written_size += len(result)
                if written_size > max_shard_bytes:
                    # tar_idx is zero-based: closing it completes tar_idx + 1 shards.
                    logger.info(f"Finished writing {tar_idx + 1} shards to {output}")
                    dataset_fp.close()
                    dataset_fp = None
                    written_size = 0
                    tar_idx += 1

        # Flush the trailing, partially filled shard.
        if dataset_fp is not None:
            dataset_fp.close()
            dataset_fp = None
            logger.info(f"Finished writing {tar_idx + 1} shards to {output}")
    finally:
        # Do not leak the open shard if a worker or a write raised.
        if dataset_fp is not None:
            dataset_fp.close()
# Script entry point: main is a click command, so calling it with no
# arguments makes click parse the command line and invoke it.
if __name__ == "__main__":
    main()