File size: 4,910 Bytes
69e8a46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import itertools
import os
import re
from collections import defaultdict
from functools import partial
from multiprocessing import Pool
from pathlib import Path

import click
import numpy as np
from loguru import logger
from tqdm import tqdm

from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
from fish_speech.utils.file import load_filelist

# Limit MKL / OpenMP to one thread per process so that the worker pool in
# ``main`` does not oversubscribe the CPU ("to avoid CPU overload").
# NOTE(review): these are set after ``numpy`` has already been imported above;
# libraries that read them only at load time may ignore them — confirm.
os.environ.update(MKL_NUM_THREADS="1", OMP_NUM_THREADS="1")


def task_generator_folder(root: Path, text_extension: str):
    """Yield one task per directory of ``*.npy`` files found under ``root``.

    Every ``.npy`` file is paired with the text file(s) that share its stem.
    ``text_extension`` may be a single suffix string or an iterable of
    suffixes; one text is read (UTF-8) per suffix. Files whose text cannot be
    read are logged and skipped.

    Yields ``(speaker, [(npy_path, texts), ...], "folder")`` tuples, where
    ``speaker`` is the name of the directory that holds the group's files.
    """
    # A bare string means "exactly one suffix"; anything else is iterated as-is.
    suffixes = [text_extension] if isinstance(text_extension, str) else text_extension

    npy_paths = sorted(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}"))

    # Group by parent directory (as a string key); each entry keeps the
    # directory name so the speaker can be recovered per group.
    groups = defaultdict(list)
    for npy_path in tqdm(npy_paths, desc=f"Grouping {root}"):
        try:
            texts = [
                npy_path.with_suffix(suffix).read_text(encoding="utf-8")
                for suffix in suffixes
            ]
        except Exception as e:
            logger.error(f"Failed to read text {npy_path}: {e}")
            continue

        folder = npy_path.parent
        groups[str(folder)].append((folder.name, npy_path, texts))

    logger.info(
        f"Found {len(groups)} groups in {root}, {list(groups)[:5]}..."
    )

    for entries in groups.values():
        speaker = entries[0][0]
        pairs = [(path, texts) for _, path, texts in entries]
        yield speaker, pairs, "folder"


def task_generator_filelist(filelist):
    """Yield one task per speaker listed in ``filelist``.

    Rows from ``load_filelist`` are unpacked as ``(filename, speaker, _, text)``
    and grouped by speaker; each entry becomes ``(Path(filename), [text])``.

    Yields ``(speaker, entries, "filelist")`` tuples.
    """
    by_speaker = defaultdict(list)
    for path_str, speaker, _, text in load_filelist(filelist):
        by_speaker[speaker].append((Path(path_str), [text]))

    logger.info(f"Found {len(by_speaker)} groups in {filelist}")
    yield from (
        (spk, entries, "filelist") for spk, entries in by_speaker.items()
    )


def run_task(task):
    """Convert one group of ``(file, texts)`` pairs into a packed ``TextData`` record.

    ``task`` is a ``(name, subset, source)`` tuple as yielded by the task
    generators in this module. For each entry of ``subset`` the sibling
    ``.npy`` file is loaded as the semantic tokens and every text is lightly
    cleaned. Entries whose ``.npy`` file is missing or fails to load are
    logged and skipped.

    Returns the serialized record produced by ``pack_pb_stream``.
    """
    name, subset, source = task

    def clean(raw):
        # Blank out "{ ... }" and "< ... >" spans, then collapse whitespace runs.
        without_braces = re.sub(r"\{.*?\}", " ", raw)
        without_tags = re.sub(r"<.*?>", " ", without_braces)
        return re.sub(r"\s+", " ", without_tags)

    sentences = []
    for file, texts in subset:
        np_file = file.with_suffix(".npy")
        if not np_file.exists():
            logger.warning(f"Can't find {np_file}")
            continue

        cleaned = [clean(raw) for raw in texts]

        try:
            semantics = np.load(np_file)
        except Exception as e:
            logger.error(f"Failed to parse {file}: {e}")
            continue

        # One ``Semantics`` message per top-level row of the loaded data.
        rows = semantics.tolist() if isinstance(semantics, np.ndarray) else semantics
        sentences.append(
            Sentence(texts=cleaned, semantics=[Semantics(values=row) for row in rows])
        )

    return pack_pb_stream(TextData(source=source, name=name, sentences=sentences))


@click.command()
@click.option(
    "--input",
    type=click.Path(path_type=Path),
    required=True,
    help="A folder containing the dataset or a filelist",
    multiple=True,
)
@click.option(
    "--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft"
)
@click.option("--num-workers", type=int, default=16)
@click.option("--text-extension", type=str, default=(".txt",), multiple=True)
@click.option(
    "--shard-size", type=int, default=10, help="The maximum size of each shard in mb"
)
def main(input, output, num_workers, text_extension, shard_size):
    """Pack all inputs into size-capped ``{index:08d}.protos`` shards under ``output``.

    Each ``--input`` is either a directory (scanned by ``task_generator_folder``)
    or a filelist (read by ``task_generator_filelist``). Tasks are converted in
    parallel by ``run_task`` and appended to the current shard; once a shard
    exceeds ``shard_size`` MiB it is closed and a new one is started.

    Raises:
        click.BadParameter: if an ``--input`` path does not exist.
    """
    generators = []
    for f in input:
        # Raise instead of ``assert``: assertions are stripped under ``python -O``.
        if not f.exists():
            raise click.BadParameter(f"{f} not found", param_hint="--input")

        if f.is_dir():
            generators.append(task_generator_folder(f, text_extension))
        else:
            generators.append(task_generator_filelist(f))

    tasks = itertools.chain(*generators)
    output.mkdir(parents=True, exist_ok=True)

    shard_limit = shard_size * 1024 * 1024
    dataset_fp = None
    # Number of shards fully written; also the index of the next shard file.
    shards_written = 0
    written_size = 0

    try:
        with Pool(num_workers) as p:
            for result in tqdm(p.imap_unordered(run_task, tasks)):
                if dataset_fp is None:
                    dataset_fp = open(output / f"{shards_written:08d}.protos", "wb")

                dataset_fp.write(result)
                written_size += len(result)

                if written_size > shard_limit:
                    dataset_fp.close()
                    dataset_fp = None
                    written_size = 0
                    shards_written += 1
                    logger.info(
                        f"Finished writing {shards_written} shards to {output}"
                    )
    finally:
        # Close (and count) the last, partially filled shard even if a worker
        # raised, so the file handle is never leaked.
        if dataset_fp is not None:
            dataset_fp.close()
            shards_written += 1

    logger.info(f"Finished writing {shards_written} shards to {output}")


if __name__ == "__main__":
    # Run the click command (which parses sys.argv) only when this file is
    # executed as a script, not when the module is imported.
    main()