import random
from dataclasses import dataclass
from pathlib import Path
from random import Random
from typing import Optional, Union

import numpy as np
import torch
import torch.nn.functional as F
from lightning import LightningDataModule
from torch.distributed import get_rank, get_world_size, is_initialized
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
from transformers import AutoTokenizer

from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
from fish_speech.datasets.protos.text_data_pb2 import SampledData
from fish_speech.datasets.protos.text_data_stream import read_pb_stream
from fish_speech.text.clean import clean_text
from fish_speech.utils import RankedLogger
from fish_speech.utils.braceexpand import braceexpand

log = RankedLogger(__name__, rank_zero_only=True)


def split_by_rank_worker(files):
    # We need the total number of shards (DDP ranks x dataloader workers)
    # to split the file list without overlap
    total_devices = 1
    if is_initialized():
        total_devices = get_world_size()

    worker_info = get_worker_info()
    if worker_info is not None:
        total_devices *= worker_info.num_workers

    if len(files) < total_devices:
        # Repeat the files so that every shard gets at least one file
        files = files * (total_devices // len(files) + 1)

    # Split by DDP rank
    if is_initialized():
        files = files[get_rank() :: get_world_size()]

    # Split by dataloader worker
    if worker_info is not None:
        files = files[worker_info.id :: worker_info.num_workers]

    return files
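

# Worked example (hypothetical numbers): with 2 DDP ranks and 2 dataloader
# workers, total_devices = 4. Three files ["a", "b", "c"] are first tiled to
# ["a", "b", "c", "a", "b", "c"]; rank 0 keeps ["a", "c", "b"], its worker 0
# then keeps ["a", "b"] and worker 1 keeps ["c"]. Every rank/worker pair gets
# a non-empty slice; files repeat across shards when there are fewer files
# than shards.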


class AutoTextSemanticInstructionDataset(IterableDataset):
    """
    Auto-augmented text/semantic dataset, grouped by speaker.

    1. Randomly concatenates multiple sentences from the same speaker into a
       longer sequence.
    2. Automatically normalizes the text.

    For interactive mode, each sentence becomes its own chat turn, and several
    short sequences are packed together (role names are followed by a newline
    in the actual template):
    <|im_start|>user text<|im_end|><|im_start|>speaker semantic<|im_end|> ...

    For non-interactive mode, all texts are merged into a single user turn
    (one long sequence):
    <|im_start|>user text ... text<|im_end|><|im_start|>speaker semantic<|im_end|>
    """

    def __init__(
        self,
        proto_files: list[str],
        seed: int = 42,
        interactive_prob: float = 0.5,
        max_length: int = 1024,
        tokenizer: AutoTokenizer = None,
        use_speaker: bool | float = True,
        causal: bool = True,
        num_codebooks: Optional[int] = None,
        skip_text_prob: float = 0.0,
    ):
""" |
|
Args: |
|
proto_files: proto buf files if using local data |
|
seed: random seed |
|
interactive_prob: probability to use interactive mode |
|
max_length: max length of the text |
|
tokenizer: tokenizer |
|
use_speaker: include speaker information in the prompt |
|
causal: use causal sampling when using local data, disable will lead to random sampling |
|
num_codebooks: number of codebooks, if None, it will be automatically detected |
|
skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode |
|
""" |
|
|
|
super().__init__() |
|
|
|
assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]" |
|
|
|
self.seed = seed |
|
self.max_length = max_length |
|
self.tokenizer = tokenizer |
|
self.interactive_prob = interactive_prob |
|
self.use_speaker = use_speaker |
|
self.proto_files = proto_files |
|
self.causal = causal |
|
self.num_codebooks = num_codebooks |
|
self.skip_text_prob = skip_text_prob |
|
|
|
self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>") |
|
self.groups = None |
|
|
|

    def init_mock_data_server(self):
        if self.groups is not None:
            return

        # Expand the proto files (brace patterns, plain files, and directories)
        expanded_proto_files = []
        for filename in self.proto_files:
            for i in braceexpand(filename):
                i = Path(i)
                if i.is_file():
                    expanded_proto_files.append(i)
                elif i.is_dir():
                    expanded_proto_files.extend(i.rglob("*.proto"))
                    expanded_proto_files.extend(i.rglob("*.protos"))
                else:
                    raise ValueError(f"{i} is not a file or directory")

        expanded_proto_files = sorted(expanded_proto_files)
        Random(self.seed).shuffle(expanded_proto_files)

        self.groups = []
        shard_proto_files = split_by_rank_worker(expanded_proto_files)
        log.info(
            f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
        )

        count = 0
        for filename in shard_proto_files:
            with open(filename, "rb") as f:
                for text_data in read_pb_stream(f):
                    self.groups.append(text_data)
                    count += 1

        log.info(f"Read total {count} groups of data")

        # Shuffle the groups and weight each one by its number of sentences
        Random(self.seed).shuffle(self.groups)
        self.group_weights = [len(i.sentences) for i in self.groups]
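
    # Example (hypothetical paths): proto_files=["data/protos/train-{000..003}.protos"]
    # brace-expands to four shard files; a directory is also accepted and is
    # searched recursively for *.proto / *.protos files.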

    def __iter__(self):
        # Infinite stream; note that augment() returns None for an empty group,
        # so consumers should be prepared to skip such items
        while True:
            yield self.augment()

    def tokenize_sentence(self, sentence: str):
        sentence = clean_text(sentence)
        tokens = self.tokenizer.encode(
            f"{sentence}",
            max_length=10**6,
            add_special_tokens=False,
            truncation=False,
        )
        return sentence, len(tokens)

    def sample_data(self):
        if self.groups is None:
            self.init_mock_data_server()

        # Assume each sentence costs at least ~20 tokens
        num_samples = self.max_length // 20

        # Pick a group, weighted by its number of sentences
        group = random.choices(self.groups, weights=self.group_weights, k=1)[0]

        if self.causal:
            # Sample a contiguous window of sentences, in order
            if num_samples >= len(group.sentences):
                samples = group.sentences
            else:
                begin = random.randint(0, len(group.sentences) - num_samples)
                samples = group.sentences[begin : begin + num_samples]
        else:
            samples = random.choices(
                group.sentences, k=min(num_samples, len(group.sentences))
            )

        return SampledData(
            source=group.source,
            name=group.name,
            samples=samples,
        )
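
    # Worked example: with max_length=1024, num_samples = 1024 // 20 = 51, so
    # up to 51 consecutive sentences are drawn from one speaker group per call.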

    def augment(self):
        final_text, final_semantic = [], []
        response = self.sample_data()
        if len(response.samples) == 0:
            # Empty group
            return None

        samples = list(response.samples)
        idx = 0
        use_interactive = random.random() < self.interactive_prob

        if not use_interactive:
            # Draw a random length budget from a truncated normal
            a = torch.tensor([0], dtype=torch.float32)
            torch.nn.init.trunc_normal_(
                a,
                mean=self.max_length // 2,
                std=self.max_length // 4,
                a=10,
                b=self.max_length,
            )
            remaining_tokens = a.long().item() - 4
        else:
            remaining_tokens = self.max_length
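
        # (The non-interactive budget above is roughly N(max_length/2,
        # max_length/4) truncated to [10, max_length]; the -4 appears to
        # reserve room for special tokens.)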

        # use_speaker may be a probability rather than a flag
        if isinstance(self.use_speaker, float):
            use_speaker = random.random() < self.use_speaker
        else:
            use_speaker = self.use_speaker

        all_tokens, all_labels = [], []
        while remaining_tokens > 0 and len(samples) > 0:
            sentence = samples.pop(0)

            text = random.choice(sentence.texts)
            text, length = self.tokenize_sentence(text)
            remaining_tokens -= length + len(sentence.semantics[0].values)

            if not use_interactive:
                final_text.append(text)
                final_semantic.append(sentence.semantics)
            else:
                # Interactive mode: pack each sentence as its own chat turn,
                # optionally dropping the text (audio-only conditioning)
                tokens, labels = self.pack_sentences(
                    sentences=[text],
                    semantics=[sentence.semantics],
                    speaker=response.name if use_speaker else None,
                    skip_text=random.random() < self.skip_text_prob,
                )

                all_tokens.append(tokens)
                all_labels.append(labels)

            idx += 1

        if not use_interactive:
            # Non-interactive mode: pack all sentences into one long turn
            tokens, labels = self.pack_sentences(
                final_text,
                semantics=final_semantic,
                speaker=response.name if use_speaker else None,
            )
            all_tokens.append(tokens)
            all_labels.append(labels)

        tokens = torch.cat(all_tokens, dim=1)
        labels = torch.cat(all_labels, dim=1)

        # Verify that the lengths match
        assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"

        data = {"tokens": tokens, "labels": labels}

        return data

    def pack_sentences(
        self,
        sentences: list[str],
        semantics: list,
        speaker: Optional[str] = None,
        skip_text: bool = False,
    ):
        if speaker is None:
            speaker = "assistant"

        cated_sentences = " ".join(sentences)
        if skip_text:
            cated_sentences = "<|skip_text|>"

        final_text = "<|im_start|>user\n" + cated_sentences + "<|im_end|>"
        final_text = final_text + f"<|im_start|>{speaker}\n"

        encoded = self.tokenizer.encode(
            final_text,
            add_special_tokens=False,
            truncation=False,
            max_length=10**6,
        )
        semantic_length = sum(len(i[0].values) for i in semantics)
        prompt_length = len(encoded)
        num_codebooks = (
            len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
        )

        # Text row: prompt tokens, then one <|semantic|> placeholder per
        # semantic frame, then <|im_end|>
        tokens = (
            encoded
            + [self.semantic_token_id] * semantic_length
            + self.tokenizer.convert_tokens_to_ids(["<|im_end|>"])
        )
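
        # Layout sketch (hypothetical lengths: prompt_length=4, semantic_length=3,
        # num_codebooks=2), before the one-step shift at the end of this method:
        #
        #   row 0 (text):   t1  t2  t3  t4  <sem> <sem> <sem> <|im_end|>
        #   row 1 (book 0): PAD PAD PAD PAD c1+1  c2+1  c3+1  PAD
        #   row 2 (book 1): PAD PAD PAD PAD d1+1  d2+1  d3+1  PAD
        #
        # After tokens[:, :-1] / labels[:, 1:], column i of `tokens` is trained
        # to predict column i of `labels`.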

        # Codebook rows: pad over the prompt, then the codes offset by +1 so
        # they do not collide with the pad id
        codes = [[CODEBOOK_PAD_TOKEN_ID] * prompt_length for _ in range(num_codebooks)]
        for segment in semantics:
            for book_idx, book in zip(range(num_codebooks), segment):
                for j in book.values:
                    codes[book_idx].append(int(j) + 1)

        # One trailing pad per codebook row, aligned with <|im_end|>
        for book in codes:
            book.append(CODEBOOK_PAD_TOKEN_ID)

        tokens = [tokens] + codes

        tokens = torch.tensor(tokens, dtype=torch.long)
        labels = tokens.clone()

        if skip_text:
            # Text was dropped: the sentence is used for conditioning only,
            # so no position contributes to the loss
            labels.fill_(-100)
            return tokens, labels

        # Mask out the prompt on the codebook rows so that only semantic tokens
        # are predicted there; the text row stays unmasked, which keeps the
        # language-modeling loss on the prompt
        labels[1:, :prompt_length] = -100

        # Shift by one step for next-token prediction
        tokens = tokens[:, :-1]
        labels = labels[:, 1:]

        # Verify that the padding is correct and the last label is the pad token
        assert (tokens[1:, :prompt_length] == CODEBOOK_PAD_TOKEN_ID).all()
        assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()

        return tokens, labels


@dataclass
class TextDataCollator:
    tokenizer: AutoTokenizer
    max_length: int = 1024

    def __call__(self, examples):
        # If the examples carry negative pairs, unpack them into a doubled
        # batch: all positives first, then all negatives
        if len(examples) > 0 and "negative_tokens" in examples[0]:
            positive_examples = []
            negative_examples = []

            for i in examples:
                positive_examples.append(
                    {
                        "tokens": i["tokens"],
                        "labels": i["labels"],
                    }
                )
                negative_examples.append(
                    {
                        "tokens": i["negative_tokens"],
                        "labels": i["negative_labels"],
                    }
                )

            examples = positive_examples + negative_examples

        return self.batchify(examples)
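
    # E.g. a batch of B examples that each carry negative_tokens/negative_labels
    # is collated into 2B rows (B positives followed by B negatives), padded to
    # one shared length by batchify() below.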

    def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
        tokens, attention_masks, labels = [], [], []

        # Calculate the max length in the batch, capped at max_length
        max_tokens_length = 0
        for example in examples:
            max_tokens_length = max(max_tokens_length, example[tokens_key].size(1))
        max_tokens_length = min(max_tokens_length, self.max_length)

        for example in examples:
            _tokens = example[tokens_key][:, :max_tokens_length]
            _labels = example[labels_key][:, :max_tokens_length]
            # Padding-mask convention: True marks padded positions, False marks
            # real tokens
            _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool)
            tokens_length = _tokens.size(1)
            _attention_mask[:tokens_length] = False

            assert tokens_length == _labels.size(
                1
            ), f"{tokens_length} != {_labels.size(1)}"

            if tokens_length < max_tokens_length:
                _tokens = F.pad(
                    _tokens,
                    (0, max_tokens_length - tokens_length),
                    value=self.tokenizer.eos_token_id,
                )
                # Codebook rows are padded with the codebook pad id, not eos
                _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
                _labels = F.pad(
                    _labels, (0, max_tokens_length - _labels.size(1)), value=-100
                )

            tokens.append(_tokens)
            attention_masks.append(_attention_mask)
            labels.append(_labels)

        tokens = torch.stack(tokens, dim=0)
        attention_masks = torch.stack(attention_masks, dim=0)
        labels = torch.stack(labels, dim=0)

        return {
            "inputs": tokens,
            "attention_masks": attention_masks,
            "labels": labels,
        }
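
    # Shape sketch: B examples of shape (1 + num_codebooks, T_i) are stacked
    # into "inputs" of shape (B, 1 + num_codebooks, max_T), with
    # "attention_masks" of shape (B, max_T) and "labels" matching "inputs".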


class InterleaveDataset(IterableDataset):
    def __init__(
        self,
        datasets: list[IterableDataset],
        probabilities: list[float],
        seed: int = 42,
    ):
        super().__init__()

        self.datasets = datasets
        self.probabilities = probabilities
        self.seed = seed

    def __iter__(self):
        rng = np.random.default_rng(self.seed)
        dataset_iterators = [iter(dataset) for dataset in self.datasets]

        while True:
            # Randomly pick one dataset (probabilities must sum to 1)
            dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
            dataset_iterator = dataset_iterators[dataset_idx]

            try:
                yield next(dataset_iterator)
            except StopIteration:
                # Exhausted: restart the iterator and keep going
                dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
                yield next(dataset_iterators[dataset_idx])
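

# Usage sketch (hypothetical datasets ds_a / ds_b): draw from ds_a 80% of the
# time and from ds_b 20% of the time, restarting whichever iterator runs out:
#
#   interleaved = InterleaveDataset([ds_a, ds_b], probabilities=[0.8, 0.2])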


class SemanticDataModule(LightningDataModule):
    def __init__(
        self,
        train_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
        val_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
        batch_size: int = 32,
        tokenizer: AutoTokenizer = None,
        max_length: int = 1024,
        num_workers: int = 4,
    ):
        super().__init__()

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.num_workers = num_workers

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
            num_workers=self.num_workers,
            persistent_workers=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
            num_workers=self.num_workers,
            persistent_workers=True,
        )


if __name__ == "__main__":
    ds = AutoTextSemanticInstructionDataset(
        ["data/protos"],
        tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"),
        use_speaker=False,
        interactive_prob=1.0,
        skip_text_prob=0.5,
    )

    for i in ds:
        print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False))
        break