Spaces:
Running
on
A10G
Running
on
A10G
File size: 4,760 Bytes
0a3525d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
import itertools
import os
import re
from collections import defaultdict
from functools import partial
from multiprocessing import Pool
from pathlib import Path
import click
import numpy as np
from loguru import logger
from tqdm import tqdm
from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
from fish_speech.utils.file import load_filelist
# To avoid CPU overload: cap the MKL and OpenMP thread counts at one each,
# since this script already fans work out over a multiprocessing pool.
# NOTE(review): this runs after numpy has been imported above - confirm the
# numeric backends still honor the variables when set at this point.
os.environ.update(MKL_NUM_THREADS="1", OMP_NUM_THREADS="1")
def task_generator_folder(root: Path, text_extension: "str | tuple[str, ...]"):
    """Yield one packing task per directory under *root* that holds ``.npy`` files.

    Every ``.npy`` file found recursively under *root* is paired with its
    transcript(s): the same path with the suffix replaced by each requested
    text extension. Files are grouped by the string form of their parent
    directory.

    Args:
        root: Directory to scan recursively.
        text_extension: A single suffix such as ``".txt"``, or a sequence of
            suffixes (the CLI passes a tuple). One transcript is read per
            suffix, in the given order.

    Yields:
        ``(group_name, [(npy_path, texts), ...], "folder")`` tuples, where
        ``group_name`` is the parent directory path as a string.
    """
    # Normalise once so the loop below does not branch on the argument type.
    if isinstance(text_extension, str):
        extensions = [text_extension]
    else:
        extensions = list(text_extension)

    files = sorted(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}"))

    grouped_files = defaultdict(list)
    for file in tqdm(files, desc=f"Grouping {root}"):
        try:
            # Decode explicitly as UTF-8: the platform default encoding varies
            # between machines and would make transcripts load differently.
            texts = [
                file.with_suffix(ext).read_text(encoding="utf-8")
                for ext in extensions
            ]
        except (OSError, ValueError) as e:
            # Best effort: a missing/unreadable transcript, an undecodable
            # file (UnicodeDecodeError is a ValueError) or an invalid suffix
            # skips this sample instead of aborting the whole scan.
            logger.error(f"Failed to read text {file}: {e}")
            continue

        grouped_files[str(file.parent)].append((file, texts))

    logger.info(
        f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..."
    )

    for name, subset in grouped_files.items():
        yield name, subset, "folder"
def task_generator_filelist(filelist):
    """Yield one packing task per speaker listed in *filelist*.

    Each entry returned by ``load_filelist`` is unpacked as
    ``(filename, speaker, <ignored>, text)``; entries are bucketed by speaker
    in first-seen order.

    Yields:
        ``(speaker, [(Path(filename), [text]), ...], "filelist")`` tuples.
    """
    by_speaker = {}
    for filename, speaker, _, text in load_filelist(filelist):
        entry = (Path(filename), [text])
        by_speaker.setdefault(speaker, []).append(entry)

    logger.info(f"Found {len(by_speaker)} groups in {filelist}")

    for speaker, entries in by_speaker.items():
        yield speaker, entries, "filelist"
def run_task(task):
    """Turn one ``(name, subset, source)`` task into a packed protobuf stream.

    For every ``(path, texts)`` pair in ``subset`` the matching ``.npy`` file
    is loaded and combined with the cleaned texts into a ``Sentence``. Pairs
    whose ``.npy`` file is missing or cannot be loaded are logged and skipped.

    Returns:
        The result of ``pack_pb_stream`` applied to a ``TextData`` holding all
        sentences of the task.
    """
    name, subset, source = task

    collected = []
    for path, raw_texts in subset:
        np_file = path.with_suffix(".npy")
        if not np_file.exists():
            logger.warning(f"Can't find {np_file}")
            continue

        # Simple cleaning: replace { xxx } and < xxx > with a space, then
        # collapse every run of whitespace into a single space.
        cleaned_texts = [
            re.sub(r"\s+", " ", re.sub(r"<.*?>", " ", re.sub(r"\{.*?\}", " ", raw)))
            for raw in raw_texts
        ]

        try:
            semantics = np.load(np_file)
        except Exception as e:
            logger.error(f"Failed to parse {path}: {e}")
            continue

        if isinstance(semantics, np.ndarray):
            semantics = semantics.tolist()

        semantic_rows = [Semantics(values=row) for row in semantics]
        collected.append(Sentence(texts=cleaned_texts, semantics=semantic_rows))

    # Pack the sentences
    text_data = TextData(source=source, name=name, sentences=collected)
    return pack_pb_stream(text_data)
@click.command()
@click.option(
    "--input",
    type=click.Path(path_type=Path),
    required=True,
    help="A folder containing the dataset or a filelist",
    multiple=True,
)
@click.option(
    "--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft"
)
@click.option("--num-workers", type=int, default=16)
@click.option("--text-extension", type=str, default=[".txt"], multiple=True)
@click.option(
    "--shard-size", type=int, default=10, help="The maximum size of each shard in mb"
)
def main(input, output, num_workers, text_extension, shard_size):
    """Pack the given dataset folders / filelists into sharded .protos files."""
    # Build one lazy task generator per input. Directories are scanned for
    # .npy files; anything else is treated as a filelist.
    generators = []
    for f in input:
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if not f.exists():
            raise FileNotFoundError(f"{f} not found")

        if f.is_dir():
            generators.append(task_generator_folder(f, text_extension))
        else:
            generators.append(task_generator_filelist(f))

    tasks = itertools.chain.from_iterable(generators)
    output.mkdir(parents=True, exist_ok=True)

    max_shard_bytes = shard_size * 1024 * 1024
    dataset_fp = None  # currently open shard, or None when between shards
    shards_written = 0  # completed shards; doubles as the next shard's index
    written_size = 0  # bytes written to the currently open shard

    try:
        with Pool(num_workers) as p:
            for result in tqdm(p.imap_unordered(run_task, tasks)):
                # Open shards lazily so no empty trailing file is created.
                if dataset_fp is None:
                    dataset_fp = open(
                        Path(output) / f"{shards_written:08d}.protos", "wb"
                    )

                dataset_fp.write(result)
                written_size += len(result)

                # Rotate once the shard exceeds the size limit. Count the
                # shard before logging so the message reports how many shards
                # exist, not the index of the last one.
                if written_size > max_shard_bytes:
                    dataset_fp.close()
                    dataset_fp = None
                    written_size = 0
                    shards_written += 1
                    logger.info(f"Finished writing {shards_written} shards to {output}")
    finally:
        # Close a partially filled shard even if a worker or a write failed,
        # and only count it when it actually exists.
        if dataset_fp is not None:
            dataset_fp.close()
            shards_written += 1

    logger.info(f"Finished writing {shards_written} shards to {output}")
# Script entry point: invoke the click command only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
|