|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
import numpy as np |
|
import torch.nn.functional as F |
|
from models.tts.naturalspeech2.wavenet import WaveNet |
|
|
|
|
|
class DiffusionFlow(nn.Module):
    """Flow-matching model over latent sequences.

    Noisy samples lie on the straight line between clean latents and Gaussian
    noise, ``xt = (1 - t) * x0 + t * z``, and ``self.diff_estimator`` (a
    WaveNet built from ``cfg.wavenet``) predicts the flow at ``(xt, t)``.
    Sampling (``reverse_diffusion``) integrates that flow from ``t = 1``
    (noise) back to ``t = 0`` with fixed-size Euler steps.

    Args:
        cfg: config object providing ``wavenet``, ``beta_min``, ``beta_max``,
            ``sigma`` and ``noise_factor``.
    """

    def __init__(self, cfg):
        super().__init__()

        self.diff_estimator = WaveNet(cfg.wavenet)
        # NOTE(review): the four values below are stored but not read by any
        # method of this class; kept so attribute access elsewhere still works.
        self.beta_min = cfg.beta_min
        self.beta_max = cfg.beta_max
        self.sigma = cfg.sigma
        self.noise_factor = cfg.noise_factor

    def forward(self, x, x_mask, cond, spk_query_emb, offset=1e-5):
        """Run the flow estimator at one random time per batch item.

        Args:
            x: clean latents, (B, 128, T).
            x_mask: (B, T) frame mask, passed unchanged to the estimator. Its
                convention ("mask is 0" in the original note) is defined by
                WaveNet; presumably 0 marks padded frames. Confirm there.
            cond: conditioning features, (B, T, 512).
            spk_query_emb: speaker query embeddings, (B, 32, 512).
            offset: sampled times are clamped to [offset, 1 - offset].

        Returns:
            dict with
              ``flow_pred``: estimator output at (xt, t);
              ``noise``: the Gaussian sample z used to build xt;
              ``x0_pred``: ``noise - flow_pred``;
              ``noise_pred``: ``x + flow_pred``.
            The two derived predictions treat the flow as ``z - x0``, the
            time derivative of the interpolation path.
        """
        # One time per batch item, uniform on [0, 1), clamped off the endpoints.
        diffusion_step = torch.rand(
            x.shape[0], dtype=x.dtype, device=x.device, requires_grad=False
        )
        diffusion_step = torch.clamp(diffusion_step, offset, 1.0 - offset)
        xt, z = self.forward_diffusion(x0=x, diffusion_step=diffusion_step)

        flow_pred = self.diff_estimator(
            xt, x_mask, cond, diffusion_step, spk_query_emb
        )

        noise = z
        # Along xt = (1 - t) x0 + t z the flow is z - x0, so each endpoint is
        # the other one combined with the predicted flow.
        x0_pred = noise - flow_pred
        noise_pred = x + flow_pred
        return {
            "x0_pred": x0_pred,
            "noise_pred": noise_pred,
            "noise": noise,
            "flow_pred": flow_pred,
        }

    @torch.no_grad()
    def forward_diffusion(self, x0, diffusion_step):
        """Sample a point on the straight path from ``x0`` to Gaussian noise.

        Args:
            x0: clean latents with the batch on axis 0, e.g. (B, 128, T).
            diffusion_step: per-item time t, shape (B,). A single-element
                tensor is broadcast to the whole batch.

        Returns:
            ``(xt, z)``: ``z ~ N(0, I)`` has the shape of ``x0``, and
            ``xt = (1 - t) * x0 + t * z``.
        """
        # Reshape t to (B, 1, ..., 1) so it broadcasts over every non-batch
        # axis of x0 regardless of rank. Two fixed unsqueezes only line up with
        # 3-D input and silently produce a wrongly-shaped result otherwise.
        time_step = diffusion_step.reshape(-1, *([1] * (x0.dim() - 1)))
        z = torch.randn(x0.shape, dtype=x0.dtype, device=x0.device, requires_grad=False)
        xt = (1 - time_step) * x0 + time_step * z
        return xt, z

    @torch.no_grad()
    def cal_dxt(self, xt, x_mask, cond, spk_query_emb, diffusion_step, h):
        """Return one Euler increment ``h * flow(xt, t)`` without tracking grads.

        ``diffusion_step`` is the time tensor handed to the estimator and ``h``
        is the step size; ``reverse_diffusion`` subtracts the result from ``xt``.
        """
        flow_pred = self.diff_estimator(
            xt, x_mask, cond, diffusion_step, spk_query_emb
        )
        return h * flow_pred

    @torch.no_grad()
    def reverse_diffusion(self, z, x_mask, cond, n_timesteps, spk_query_emb):
        """Integrate the learned flow from noise (t = 1) back to data (t = 0).

        Takes ``n_timesteps`` Euler steps of size ``h = 1 / n_timesteps``. The
        flow for step ``i`` is evaluated at the interval midpoint
        ``t = 1 - (i + 0.5) * h``.

        Args:
            z: starting noise, same shape as the latents.
            x_mask, cond, spk_query_emb: forwarded unchanged to the estimator.
            n_timesteps: number of Euler steps; must be >= 1.

        Returns:
            The integrated sample at t = 0, same shape as ``z``.

        Raises:
            ValueError: if ``n_timesteps`` < 1.
        """
        if n_timesteps < 1:
            raise ValueError(f"n_timesteps must be >= 1, got {n_timesteps}")

        h = 1.0 / n_timesteps
        # Loop-invariant: every step uses the same per-item time layout.
        ones = torch.ones(z.shape[0], dtype=z.dtype, device=z.device)
        xt = z
        for i in range(n_timesteps):
            t = (1.0 - (i + 0.5) * h) * ones
            dxt = self.cal_dxt(xt, x_mask, cond, spk_query_emb, diffusion_step=t, h=h)
            xt = xt - dxt
        return xt
|
|