zhzluke96
update
d2b7e94
raw
history blame
17.6 kB
import abc
import functools
import io
import json
import logging
import os
import tarfile
import typing
import torch.utils.data
import torchaudio
import transformers
import vocos
from torchvision.datasets.utils import download_url
from modules.ChatTTS.ChatTTS.utils.infer_utils import (
apply_character_map,
count_invalid_characters,
)
class LazyDataType(typing.TypedDict):
    """Metadata describing one utterance before any tensor has been computed."""

    # Location of the audio: a file-system path, or the member name inside
    # the tar archive when the dataset reads from a tar file.
    filepath: str
    # Speaker label of the utterance.
    speaker: str
    # Language tag; `AudioFolder.get_lazy_data` stores it lower-cased.
    lang: str
    # Transcript of the utterance.
    text: str
class DataType(LazyDataType):
    """
    A processed sample: the `LazyDataType` metadata plus its tensors.

    The shapes noted below describe a collated batch (samples stacked along
    a leading batch dimension). A single item produced by
    `AudioFolder.__getitem__` has the same layout without that leading
    `batch_size` dimension.
    """

    # Token ids of the prompt text: (batch_size, text_len)
    text_input_ids: torch.Tensor
    # 1 for real tokens, 0 for padding: (batch_size, text_len)
    text_attention_mask: torch.Tensor
    # Mel spectrogram frames: (batch_size, audio_len*2, 100)
    audio_mel_specs: torch.Tensor
    # One mask entry per two mel frames: (batch_size, audio_len)
    audio_attention_mask: torch.Tensor
class XzListTarKwargsType(typing.TypedDict):
    """
    Typed view of the keyword arguments used to build a dataset.

    The keys mirror the keyword parameters of `AudioFolder.__init__`
    (presumably forwarded through `XzListTar(**kwargs)` — confirm against
    the call sites).
    """

    # Tokenizer used to turn the prompt text into ids.
    tokenizer: typing.Optional[transformers.PreTrainedTokenizer]
    # Vocos model whose feature extractor produces the mel spectrograms.
    vocos_model: typing.Optional[vocos.Vocos]
    # Device on which the produced tensors are placed.
    device: typing.Union[str, torch.device, None]
    # Restrict the dataset to these speakers; None keeps every speaker.
    speakers: typing.Optional[typing.Iterable[str]]
    # Target sample rate; was spelled `typing.Union[int]`, which is just `int`.
    sample_rate: int
    # Fallback speaker for items that do not name one.
    default_speaker: typing.Optional[str]
    # Fallback language for items that do not name one.
    default_lang: typing.Optional[str]
    # Load the whole tar archive into memory before reading from it.
    tar_in_memory: typing.Optional[bool]
    # Preprocess every sample eagerly at construction time.
    process_ahead: typing.Optional[bool]
class AudioFolder(torch.utils.data.Dataset, abc.ABC):
    """
    Base dataset pairing transcripts with mel spectrograms.

    Subclasses implement `get_raw_data` (parse an index into item dicts) and
    `save_config` (write such an index back to disk). Audio is read from the
    file system, or from a tar archive when `tar_path` is given. Token ids
    and mel spectrograms are computed on first access and cached per index,
    or computed eagerly for every item when `process_ahead` is true.
    """

    def __init__(
        self,
        root: str | io.BytesIO,
        tokenizer: transformers.PreTrainedTokenizer | None = None,
        vocos_model: vocos.Vocos | None = None,
        device: str | torch.device | None = None,
        speakers: typing.Iterable[str] | None = None,
        sample_rate: int = 24_000,
        default_speaker: str | None = None,
        default_lang: str | None = None,
        tar_path: str | None = None,
        tar_in_memory: bool = False,
        process_ahead: bool = False,
    ) -> None:
        """
        Args:
            root: Index to load (path or file-like object); handed to
                `get_raw_data`.
            tokenizer: Tokenizer applied to the prompt text.
            vocos_model: Vocos model whose `feature_extractor` produces the
                mel spectrograms.
            device: Device for the returned tensors; defaults to the device
                of the Vocos parameters.
            speakers: If given, only items of these speakers are kept.
            sample_rate: Audio is resampled to this rate when it differs.
            default_speaker: Speaker used for items without a "speaker" key.
            default_lang: Language used for items without a language.
            tar_path: Optional tar archive that contains the audio files.
            tar_in_memory: Read the whole archive into memory first.
            process_ahead: Preprocess all items now instead of lazily.
        """
        self.root = root
        self.sample_rate = sample_rate
        self.default_speaker = default_speaker
        self.default_lang = default_lang
        self.logger = logging.getLogger(__name__)
        self.normalizer = {}
        self.tokenizer = tokenizer
        self.vocos = vocos_model
        self.vocos_device = (
            None if self.vocos is None else next(self.vocos.parameters()).device
        )
        self.device = device or self.vocos_device

        # The archive is expected to be created / unpacked like this:
        # tar -cvf ../Xz.tar *
        # tar -xf Xz.tar -C ./Xz
        self.tar_path = tar_path
        self.tar_file = None
        self.tar_io = None
        if tar_path is not None:
            if tar_in_memory:
                with open(tar_path, "rb") as f:
                    self.tar_io = io.BytesIO(f.read())
                self.tar_file = tarfile.open(fileobj=self.tar_io)
            else:
                self.tar_file = tarfile.open(tar_path)

        self.lazy_data, self.speakers = self.get_lazy_data(root, speakers)

        # Per-index caches filled by `__getitem__` (or below when eager).
        self.text_input_ids: dict[int, torch.Tensor] = {}
        self.audio_mel_specs: dict[int, torch.Tensor] = {}
        if process_ahead:
            try:
                for n, item in enumerate(self.lazy_data):
                    self.audio_mel_specs[n] = self.preprocess_audio(item["filepath"])
                    self.text_input_ids[n] = self.preprocess_text(
                        item["text"], item["lang"]
                    )
            finally:
                # Everything is cached (or construction failed): either way
                # the archive handle must not be leaked.
                self._close_tar()

    def _close_tar(self) -> None:
        """Close the tar archive and its in-memory buffer, if any."""
        if self.tar_file is not None:
            self.tar_file.close()
        if self.tar_io is not None:
            self.tar_io.close()

    @abc.abstractmethod
    def get_raw_data(self, root: str | io.BytesIO) -> list[dict[str, str]]: ...

    @staticmethod
    @abc.abstractmethod
    def save_config(
        save_path: str, lazy_data: list[LazyDataType], rel_path: str = "./"
    ) -> None: ...

    def __len__(self):
        """Return the number of items kept by `get_lazy_data`."""
        return len(self.lazy_data)

    def __getitem__(self, n: int) -> DataType:
        """Return item `n`, preprocessing and caching it on first access."""
        lazy_data = self.lazy_data[n]
        if n in self.audio_mel_specs:
            audio_mel_specs = self.audio_mel_specs[n]
            text_input_ids = self.text_input_ids[n]
        else:
            audio_mel_specs = self.preprocess_audio(lazy_data["filepath"])
            text_input_ids = self.preprocess_text(lazy_data["text"], lazy_data["lang"])
            self.audio_mel_specs[n] = audio_mel_specs
            self.text_input_ids[n] = text_input_ids
            # Once every item is cached the archive is never read again.
            if len(self.audio_mel_specs) == len(self.lazy_data):
                self._close_tar()

        text_attention_mask = torch.ones(
            len(text_input_ids), device=text_input_ids.device
        )
        # One mask entry per two mel frames, rounded up.
        audio_attention_mask = torch.ones(
            (len(audio_mel_specs) + 1) // 2,
            device=audio_mel_specs.device,
        )
        return {
            "filepath": lazy_data["filepath"],
            "speaker": lazy_data["speaker"],
            "lang": lazy_data["lang"],
            "text": lazy_data["text"],
            "text_input_ids": text_input_ids,
            "text_attention_mask": text_attention_mask,
            "audio_mel_specs": audio_mel_specs,
            "audio_attention_mask": audio_attention_mask,
        }

    def get_lazy_data(
        self,
        root: str | io.BytesIO,
        speakers: typing.Iterable[str] | None = None,
    ) -> tuple[list[LazyDataType], set[str]]:
        """
        Load the index and normalize it into `LazyDataType` items.

        Args:
            root: Index to load through `get_raw_data`.
            speakers: If given, items of other speakers are dropped.

        Returns:
            The kept items and the speaker set: the requested speakers when
            `speakers` is given, otherwise every speaker found in the index.
        """
        restrict = speakers is not None
        # Materialize once. Membership is tested against this set rather than
        # the raw iterable, which may be a one-shot iterator that `set()`
        # has already exhausted.
        new_speakers = set(speakers) if restrict else set()

        raw_data = self.get_raw_data(root)
        folder_path = os.path.dirname(root) if isinstance(root, str) else ""
        # Paths are resolved against the index folder only when reading from
        # the file system; tar member names are used as they are.
        join_folder = self.tar_file is None and isinstance(root, str)

        lazy_data = []
        for item in raw_data:
            speaker = item.get("speaker", self.default_speaker)
            lang = item.get("lang")
            if lang is None:
                lang = self.default_lang

            if restrict:
                if speaker not in new_speakers:
                    continue
            else:
                new_speakers.add(speaker)

            if join_folder:
                filepath = os.path.join(folder_path, item["filepath"])
            else:
                filepath = item["filepath"]

            lazy_data.append(
                {
                    "filepath": filepath,
                    "speaker": speaker,
                    # Stays None when neither the item nor `default_lang`
                    # provides a language, instead of crashing on `.lower()`.
                    "lang": lang.lower() if lang is not None else None,
                    "text": item["text"],
                }
            )
        return lazy_data, new_speakers

    def preprocess_text(
        self,
        text: str,
        lang: str,
    ) -> torch.Tensor:
        """
        Tokenize `text` wrapped in the TTS prompt markers.

        `lang` is accepted for interface symmetry but is not used here.
        Returns the token ids with the batch dimension squeezed away.
        """
        invalid_characters = count_invalid_characters(text)
        if len(invalid_characters):
            text = apply_character_map(text)

        # NOTE: the text-refinement step of the inference pipeline is
        # deliberately not applied here (it was disabled in this code).
        # `[empty_spk]` could be used in place of `[spk_emb]`.
        text = f"[Stts][spk_emb]{text}[Ptts]"
        text_token = self.tokenizer(
            text, return_tensors="pt", add_special_tokens=False
        ).to(device=self.device)
        return text_token["input_ids"].squeeze(0)

    def preprocess_audio(self, filepath: str) -> torch.Tensor:
        """
        Load one audio file and turn it into a mel spectrogram.

        Args:
            filepath: File-system path, or tar member name when the dataset
                reads from a tar archive.

        Returns:
            Mel spectrogram on `self.device`, time axis first:
            (audio_len*2, 100).

        Raises:
            FileNotFoundError: The tar member exists but is not a regular
                file, so there is nothing to read.
        """
        if self.tar_file is not None:
            file = self.tar_file.extractfile(filepath)
            if file is None:
                # extractfile returns None for members that are not files.
                raise FileNotFoundError(
                    f"{filepath} is not a regular file in {self.tar_path}"
                )
            waveform, sample_rate = torchaudio.load(file)
        else:
            waveform, sample_rate = torchaudio.load(filepath)

        waveform = waveform.to(device=self.vocos_device)
        if sample_rate != self.sample_rate:
            waveform = torchaudio.functional.resample(
                waveform,
                orig_freq=sample_rate,
                new_freq=self.sample_rate,
            )
        mel_spec: torch.Tensor = self.vocos.feature_extractor(waveform)
        return (
            mel_spec.to(device=self.device).squeeze(0).transpose(0, 1)
        )  # (audio_len*2, 100)
class JsonFolder(AudioFolder):
    """
    In json file, each item is formatted as following example:
    `{"filepath": "path/to/file.wav", "speaker": "John", "lang": "ZH", "text": "Hello"}`.
    filepath is relative to the dirname of root json file.
    """

    def get_raw_data(self, root: str | io.BytesIO) -> list[dict[str, str]]:
        """
        Parse the JSON index.

        Args:
            root: Path of the json file, or an already opened file-like
                object (for example `io.BytesIO` or a tar member).

        Returns:
            The decoded JSON content (a list of item dicts).
        """
        if hasattr(root, "read"):
            # json.load only needs `.read()` and accepts UTF-8 encoded bytes.
            return json.load(root)
        with open(root, "r", encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def save_config(
        save_path: str, lazy_data: list[LazyDataType], rel_path: str = "./"
    ) -> None:
        """
        Write `lazy_data` as a JSON index to `save_path`.

        File paths are stored relative to `rel_path`; the given items are
        not modified.
        """
        save_data = [
            {**item, "filepath": os.path.relpath(item["filepath"], rel_path)}
            for item in lazy_data
        ]
        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(save_data, f, ensure_ascii=False, indent=4)


class ListFolder(AudioFolder):
    """
    In list file, each row is formatted as `filepath|speaker|lang|text` with `|` as separator.
    `path/to/file.wav|John|ZH|Hello`.
    filepath is relative to the dirname of root list file.
    """

    def get_raw_data(self, root: str | io.BytesIO) -> list[dict[str, str]]:
        """
        Parse the list index.

        Args:
            root: Path of the list file, or an already opened file-like
                object (for example `io.BytesIO` or a tar member) whose
                content is UTF-8 text.

        Returns:
            One dict per non-empty row with the keys "text", "filepath",
            "speaker" and "lang".

        Raises:
            ValueError: A row does not have the four `|`-separated fields.
        """
        if hasattr(root, "read"):
            content = root.read()
            if isinstance(content, bytes):
                content = content.decode("utf-8")
            lines = content.split("\n")
        else:
            with open(root, "r", encoding="utf-8") as f:
                lines = f.readlines()

        raw_data = []
        for lineno, line in enumerate(lines, start=1):
            # strip() also removes the trailing newline / carriage return.
            line = line.strip()
            if len(line) == 0:
                continue
            # maxsplit=3 keeps any further `|` as part of the text.
            fields = line.split(sep="|", maxsplit=3)
            if len(fields) != 4:
                raise ValueError(
                    f"line {lineno}: expected `filepath|speaker|lang|text`, got {line!r}"
                )
            filepath, speaker, lang, text = fields
            raw_data.append(
                {
                    "text": text,
                    "filepath": filepath,
                    "speaker": speaker,
                    "lang": lang,
                }
            )
        return raw_data

    @staticmethod
    def save_config(
        save_path: str, lazy_data: list[LazyDataType], rel_path: str = "./"
    ) -> None:
        """
        Write `lazy_data` as a `filepath|speaker|lang|text` list to `save_path`.

        File paths are stored relative to `rel_path`; the given items are
        not modified.
        """
        with open(save_path, "w", encoding="utf-8") as f:
            for item in lazy_data:
                filepath = os.path.relpath(item["filepath"], rel_path)
                f.write(
                    f"{filepath}|{item['speaker']}|{item['lang']}|{item['text']}\n"
                )


class XzListTar(ListFolder):
    """
    List dataset whose audio files live inside a tar archive.

    `root` is a `.list` index (or a folder that holds `all.list`). When that
    index does not exist yet it is built from every `.list` / `.json` member
    of the archive given by `tar_path`.
    """

    def __init__(
        self,
        *args,
        root: str | io.BytesIO,
        tar_path: str | None = None,
        **kwargs,
    ):
        """
        Args:
            root: `.list` file, folder that contains (or will contain)
                `all.list`, or an in-memory index.
            tar_path: Tar archive with the audio files; required when `root`
                is an in-memory index or when the index must be built.
            **kwargs: Forwarded to `AudioFolder.__init__`. An optional
                `langs` entry selects the languages kept when the index is
                built and is not forwarded.

        Raises:
            ValueError: `root` is in-memory but no `tar_path` was given.
            FileExistsError: `root` names an existing file that is not a
                `.list` file.
            FileNotFoundError: The index is missing and there is no archive
                to build it from.
        """
        # `langs` only steers the one-off index build; the base constructor
        # does not accept it, so it must be removed from kwargs.
        langs = kwargs.pop("langs", ["zh", "en"])

        if isinstance(root, io.BytesIO):
            if tar_path is None:
                raise ValueError("tar_path is required when root is an io.BytesIO")
        else:
            # make sure root is a list file
            if not root.endswith(".list"):  # folder case
                if os.path.isfile(root):
                    raise FileExistsError(f"{root} is a file!")
                elif not os.path.exists(root):
                    os.makedirs(root)
                root = os.path.join(root, "all.list")
            if not os.path.isfile(root):
                if tar_path is None:
                    raise FileNotFoundError(
                        f"{root} does not exist and no tar_path was given to build it from"
                    )
                # prepare all.list. This runs before the base constructor,
                # so the archive is passed explicitly: self.tar_file does
                # not exist yet.
                self.concat_dataset(
                    save_folder=os.path.dirname(root),
                    langs=langs,
                    tar_path=tar_path,
                )
        super().__init__(root, *args, tar_path=tar_path, **kwargs)

    def concat_dataset(
        self,
        save_folder: str | None = None,
        langs: typing.Iterable[str] | None = ("zh", "en"),
        tar_path: str | None = None,
    ) -> None:
        """
        Merge every `.list` / `.json` member of the archive into one index.

        Writes `all.list` and `all.json` into `save_folder`.

        Args:
            save_folder: Output folder; defaults to the folder of
                `self.root`. It is created when missing.
            langs: Lower-case language tags to keep; None keeps everything.
            tar_path: Archive to read; defaults to `self.tar_path`.

        Raises:
            FileExistsError: `save_folder` exists as a file.
            ValueError: No archive path is available.
        """
        if save_folder is None:
            save_folder = os.path.dirname(self.root)
        # dirname() of a bare file name is "", which makedirs rejects.
        save_folder = save_folder or "."
        if os.path.isfile(save_folder):
            raise FileExistsError(f"{save_folder} already exists as a file!")
        os.makedirs(save_folder, exist_ok=True)

        if tar_path is None:
            tar_path = getattr(self, "tar_path", None)
        if tar_path is None:
            raise ValueError("concat_dataset needs a tar archive (tar_path)")

        lazy_data = []
        # A private handle keeps this method independent of the state of
        # self.tar_file, which may be unset (during __init__) or closed.
        with tarfile.open(tar_path) as tar:
            for member in tar.getmembers():
                if not member.isfile():
                    continue
                if member.name.endswith(".list"):
                    folder_cls = ListFolder
                elif member.name.endswith(".json"):
                    folder_cls = JsonFolder
                else:
                    continue
                print(member.name)
                root_io = tar.extractfile(member)
                lazy_data += folder_cls(root_io).lazy_data

        if langs is not None:
            wanted = set(langs)
            lazy_data = [item for item in lazy_data if item["lang"] in wanted]

        ListFolder.save_config(os.path.join(save_folder, "all.list"), lazy_data)
        JsonFolder.save_config(os.path.join(save_folder, "all.json"), lazy_data)
        print(f"all.list and all.json are saved to {save_folder}")
class XzListFolder(ListFolder):
    """
    [Xz乔希](https://space.bilibili.com/5859321)

    Only the basename of each filepath in the list file is used; any folder
    part written in the list is ignored. Audio files are expected at
    `[list basename]/[file basename]` next to the list file.

    Example tree structure:

    [folder]
    ├── speaker_A
    │   ├── 1.wav
    │   └── 2.wav
    ├── speaker_A.list
    ├── speaker_B
    │   ├── 1.wav
    │   └── 2.wav
    └── speaker_B.list
    """

    def get_raw_data(self, root: str | io.BytesIO) -> list[dict[str, str]]:
        """Parse the list file and rewrite each filepath to `<list name>/<file name>`."""
        entries = super().get_raw_data(root)
        for entry in entries:
            # Folder named after the list file, without its ".list" suffix.
            list_stem = os.path.basename(root).removesuffix(".list")
            audio_name = os.path.basename(entry["filepath"])
            entry["filepath"] = os.path.join(list_stem, audio_name)
        return entries
class AudioCollator:
    """
    Collate dataset samples into one padded batch.

    Text tensors are padded on the left, audio tensors on the right, and the
    attention masks are padded with 0 on the matching side.
    """

    def __init__(self, text_pad: int = 0, audio_pad: int = 0):
        """
        Args:
            text_pad: Value used to pad `text_input_ids`.
            audio_pad: Value used to pad `audio_mel_specs`.
        """
        self.text_pad = text_pad
        self.audio_pad = audio_pad

    def __call__(self, batch: "list[DataType | None]") -> dict:
        """
        Pad and stack the samples of `batch`.

        Args:
            batch: Samples as returned by the dataset; `None` entries are
                skipped.

        Returns:
            Dict with the keys of a sample: the metadata fields become lists
            and the tensor fields are stacked along a new first dimension.

        Raises:
            ValueError: `batch` holds no usable sample.
        """
        samples = [x for x in batch if x is not None]
        if not samples:
            # Without this guard, max() below fails with an opaque message.
            raise ValueError("AudioCollator received a batch without any sample")

        audio_maxlen = max(len(item["audio_attention_mask"]) for item in samples)
        text_maxlen = max(len(item["text_attention_mask"]) for item in samples)

        filepath = []
        speaker = []
        lang = []
        text = []
        text_input_ids = []
        text_attention_mask = []
        audio_mel_specs = []
        audio_attention_mask = []
        pad = torch.nn.functional.pad
        for x in samples:
            filepath.append(x["filepath"])
            speaker.append(x["speaker"])
            lang.append(x["lang"])
            text.append(x["text"])
            # (n, 0) pads in front of the sequence: text is left-padded.
            text_input_ids.append(
                pad(
                    x["text_input_ids"],
                    (text_maxlen - len(x["text_input_ids"]), 0),
                    value=self.text_pad,
                )
            )
            text_attention_mask.append(
                pad(
                    x["text_attention_mask"],
                    (text_maxlen - len(x["text_attention_mask"]), 0),
                    value=0,
                )
            )
            # (0, 0, 0, n) leaves the last dimension alone and appends n
            # frames; the mask covers two mel frames per entry, hence the
            # factor 2.
            audio_mel_specs.append(
                pad(
                    x["audio_mel_specs"],
                    (0, 0, 0, audio_maxlen * 2 - len(x["audio_mel_specs"])),
                    value=self.audio_pad,
                )
            )
            audio_attention_mask.append(
                pad(
                    x["audio_attention_mask"],
                    (0, audio_maxlen - len(x["audio_attention_mask"])),
                    value=0,
                )
            )
        return {
            "filepath": filepath,
            "speaker": speaker,
            "lang": lang,
            "text": text,
            "text_input_ids": torch.stack(text_input_ids),
            "text_attention_mask": torch.stack(text_attention_mask),
            "audio_mel_specs": torch.stack(audio_mel_specs),
            "audio_attention_mask": torch.stack(audio_attention_mask),
        }
def formalize_xz_list(src_folder: str):
    """
    Rewrite every `.list` file below `src_folder` in place.

    Each file is loaded with `XzListFolder` and saved again through
    `XzListFolder.save_config`, with file paths made relative to
    `src_folder`.
    """
    for dirpath, _dirnames, filenames in os.walk(src_folder):
        list_paths = (
            os.path.join(dirpath, name)
            for name in filenames
            if name.endswith(".list")
        )
        for list_path in list_paths:
            print(list_path)
            entries = XzListFolder(list_path).lazy_data
            XzListFolder.save_config(list_path, entries, rel_path=src_folder)
def concat_dataset(
    src_folder: str,
    save_folder: str | None = None,
    langs: typing.Iterable[str] | None = ("zh", "en"),
) -> None:
    """
    Merge every `.list` / `.json` index below `src_folder` into one dataset.

    Writes `all.list` and `all.json` into `save_folder`, with file paths
    relative to that folder.

    Args:
        src_folder: Folder that is walked recursively for index files.
        save_folder: Output folder; defaults to `src_folder`. It is created
            when missing.
        langs: Lower-case language tags to keep; None keeps everything.

    Raises:
        FileExistsError: `save_folder` exists as a file.
    """
    if save_folder is None:
        save_folder = src_folder
    if os.path.isfile(save_folder):
        raise FileExistsError(f"{save_folder} already exists as a file!")
    os.makedirs(save_folder, exist_ok=True)

    list_out = os.path.join(save_folder, "all.list")
    json_out = os.path.join(save_folder, "all.json")
    # Outputs of an earlier run must never be read back as input, also when
    # save_folder is a sub-folder of src_folder; otherwise every rerun would
    # duplicate the samples.
    own_outputs = {os.path.realpath(list_out), os.path.realpath(json_out)}

    lazy_data = []
    same_folder = os.path.samefile(src_folder, save_folder)
    for root, _, files in os.walk(src_folder):
        for file in files:
            filepath = os.path.join(root, file)
            if same_folder and file in ("all.list", "all.json"):
                continue
            if os.path.realpath(filepath) in own_outputs:
                continue
            if file.endswith(".list"):
                print(filepath)
                lazy_data += ListFolder(filepath).lazy_data
            elif file.endswith(".json"):
                print(filepath)
                lazy_data += JsonFolder(filepath).lazy_data

    if langs is not None:
        wanted = set(langs)
        lazy_data = [item for item in lazy_data if item["lang"] in wanted]

    ListFolder.save_config(list_out, lazy_data, rel_path=save_folder)
    JsonFolder.save_config(json_out, lazy_data, rel_path=save_folder)
    print(f"all.list and all.json are saved to {save_folder}")