import torch
import torchaudio
import numpy as np
import time

from .valle_ar_trainer import ValleARTrainer, make_pad_mask
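
# Trainer for the non-autoregressive (NAR) stage of VALL-E. Per training step
# the model learns to predict the codes of one randomly chosen residual
# quantization layer given the phone sequence and all lower layers; at test
# time it fills in the upper layers on top of the first-layer (AR-stage) codes.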


class ValleNARTrainer(ValleARTrainer):
    def __init__(self, args=None, cfg=None):
        super().__init__(args, cfg)
        print("simple NAR")
        # Running top-k accuracies, tracked separately for each of the
        # seven NAR quantization layers
        self.top1_accuracies = {i: [] for i in range(1, 8)}
        self.top5_accuracies = {i: [] for i in range(1, 8)}
        self.top10_accuracies = {i: [] for i in range(1, 8)}

    def _build_model(self):
        from .valle_nar import ValleNAR

        return ValleNAR(**self.cfg.model)

    def _train_step(self, batch):
        """Run one NAR training step.

        Args:
            batch: dict('speech', 'speech_len', 'phone_ids', 'phone_lens')
                speech: [B, T]
                speech_len: [B]
                phone_ids: [B, T]
                phone_lens: [B]

        Returns:
            The scalar training loss.
        """
        device = self.accelerator.device
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.to(device)
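        # For reference, a minimal illustrative batch (shapes as documented
        # above; the sizes here are placeholders, not real data):
        #   batch = {
        #       "speech": torch.randn(2, 48000),             # 16 kHz waveforms
        #       "speech_len": torch.tensor([48000, 40000]),  # samples per item
        #       "phone_ids": torch.zeros(2, 64).long(),      # phone token ids
        #       "phone_lens": torch.tensor([64, 50]),        # phones per item
        #   }
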
        with torch.no_grad():
            if self.cfg.use_speechtokenizer:
                # SpeechTokenizer returns discrete codes of shape [n_q, B, T]
                vq_id = self.codec_encoder.encode(batch["speech"].unsqueeze(1))
            else:
                # EnCodec returns a list of (codes, scale) frames; concatenate
                # the codes along time and reorder to [n_q, B, T]
                vq_id = self.codec_encoder.encode(batch["speech"].unsqueeze(1))
                vq_id = torch.cat([encoded[0] for encoded in vq_id], dim=-1).transpose(
                    0, 1
                )

            # Replace the waveform with its codec tokens; sample lengths become
            # frame counts via the codec hop size (320 samples per frame at 16 kHz)
            batch["speech"] = vq_id
            batch["speech_len"] = batch["speech_len"] // 320
            assert batch["speech_len"].max() <= batch["speech"].shape[-1]

        # Binary masks over phones and codec frames: 1 = valid, 0 = padding
        phone_mask = 1 - make_pad_mask(
            batch["phone_lens"], max_len=batch["phone_ids"].size(1), left_pad=False
        ).to(torch.long)
        speech_mask = 1 - make_pad_mask(
            batch["speech_len"], max_len=batch["speech"].size(-1)
        ).to(torch.long)

        # Re-seed numpy per step with a per-process offset so that numpy-based
        # sampling (e.g. the choice of target layer) differs across GPUs
        np.random.seed(int(time.time()) - 5 * self.accelerator.process_index)

        if hasattr(self.cfg.train, "dropout"):
            dropout = self.cfg.train.dropout
        else:
            dropout = 0.0

        out = self.model(
            phone_ids=batch["phone_ids"],
            phone_mask=phone_mask,
            target_ids=batch["speech"],
            target_mask=speech_mask,
            dropout=dropout,
        )
        loss = out.loss

        # The model trains one quantization layer per step; log top-k accuracy
        # under that layer's name
        self.accelerator.log(
            {f"Train/NAR L{out.target_quantization_layer} Top1 acc": out.top1_acc},
            step=self.step,
        )
        self.accelerator.log(
            {f"Train/NAR L{out.target_quantization_layer} Top5 acc": out.top5_acc},
            step=self.step,
        )
        self.accelerator.log(
            {f"Train/NAR L{out.target_quantization_layer} Top10 acc": out.top10_acc},
            step=self.step,
        )

        return loss

    def _test_step(self, batch):
        """Run inference on one batch and save audio for listening tests.

        Args:
            batch: dict('speech', 'speech_len', 'phone_ids', 'phone_lens')
                speech: [B, T]
                speech_len: [B]
                phone_ids: [B, T]
                phone_lens: [B]
        """
        device = self.accelerator.device
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.to(device)
        with torch.no_grad():
            if self.cfg.use_speechtokenizer:
                # SpeechTokenizer codes: [n_q, B, T]
                vq_id = self.codec_encoder.encode(batch["speech"].unsqueeze(1))
            else:
                # EnCodec: concatenate chunked codes along time, reorder to
                # [n_q, B, T]
                vq_id = self.codec_encoder.encode(batch["speech"].unsqueeze(1))
                vq_id = torch.cat([encoded[0] for encoded in vq_id], dim=-1).transpose(
                    0, 1
                )

            batch["speech"] = vq_id

            # Decode the ground-truth codes and save them for reference
            if self.cfg.use_speechtokenizer:
                recovered_audio = self.codec_encoder.decode(vq_id)
            else:
                recovered_audio = self.codec_encoder.decode(
                    [(vq_id.transpose(0, 1), None)]
                )
            torchaudio.save("gt.wav", recovered_audio[0].cpu(), 16000)

            self.model.eval()
            # Condition on the first 150 codec frames (~3 s at the 50 Hz frame
            # rate) of the first sample as the acoustic prompt; the remaining
            # first-layer codes stand in for the AR stage's output
            out_vq_ids = self.model.sample_hf(
                phone_ids=batch["phone_ids"][:1],
                prompt_ids=batch["speech"][:, :1, :150],
                first_stage_ids=batch["speech"][0, :1, 150:],
            )

            # Decode the generated codes and save the synthesized audio
            if self.cfg.use_speechtokenizer:
                recovered_audio = self.codec_encoder.decode(out_vq_ids)
            else:
                recovered_audio = self.codec_encoder.decode(
                    [(out_vq_ids.transpose(0, 1)[:1], None)]
                )
            torchaudio.save("a.wav", recovered_audio[0].cpu(), 16000)