|
|
|
|
|
|
|
|
|
|
|
import torch |
|
from torch.nn.utils.rnn import pad_sequence |
|
|
|
|
|
class VALLECollator:
    """Collate a list of VALL-E samples into padded batch tensors.

    Each sample is expected to be a dict with keys:
        "speech": 1-D float sequence (list or torch.Tensor) of audio features
        "phone":  1-D int sequence (list or torch.Tensor) of phone ids
    `None` samples (e.g. from failed dataset reads) are filtered out.
    """

    def __init__(self, cfg=None):
        # cfg is stored but not consulted here; presumably kept for interface
        # parity with other collators in the project — confirm before removing.
        self.cfg = cfg

    @staticmethod
    def _to_tensor(data, dtype=torch.float32):
        """Return *data* as a detached tensor.

        NOTE: `dtype` is applied only when converting non-tensor input; an
        input that is already a tensor keeps its existing dtype unchanged.
        Hoisted out of __call__ so it is not re-created on every batch.
        """
        if isinstance(data, torch.Tensor):
            # Detach so collation never drags autograd history into the batch.
            return data.detach()
        return torch.tensor(data, dtype=dtype)

    def __call__(self, batch):
        """Returns: dict('speech', 'speech_len', 'phone_ids', 'phone_lens')

        speech: [B, T]
        speech_len: [B]
        phone_ids: [B, T]
        phone_lens: [B]
        """
        # NOTE(review): asserts are stripped under `python -O`; kept as-is to
        # preserve the AssertionError contract, but consider raising ValueError.
        assert len(batch) != 0, "batch is empty before None checking"
        batch = [b for b in batch if b is not None]
        assert len(batch) != 0, "batch is empty after None checking"

        packed_batch_features = {}

        # Speech: record true lengths, then right-pad to the longest sequence.
        speeches = [self._to_tensor(b["speech"]) for b in batch]
        packed_batch_features["speech_len"] = torch.tensor(
            [len(s) for s in speeches], dtype=torch.long
        )
        packed_batch_features["speech"] = pad_sequence(
            speeches, batch_first=True, padding_value=0
        )

        # Phone ids: same padding scheme, integer dtype for embedding lookup.
        phones = [self._to_tensor(b["phone"], dtype=torch.long) for b in batch]
        packed_batch_features["phone_lens"] = torch.tensor(
            [len(phone) for phone in phones], dtype=torch.long
        )
        packed_batch_features["phone_ids"] = pad_sequence(
            phones, batch_first=True, padding_value=0
        )

        return packed_batch_features
|
|