Spaces:
Sleeping
Sleeping
File size: 7,118 Bytes
2b5b9ef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
sys.path.insert(0, '.') # nopep8
from ldm.modules.losses_audio.vqperceptual import *
from ldm.modules.discriminator.multi_window_disc import Discriminator
class LPAPSWithDiscriminator(nn.Module):
    """VAE training loss: reconstruction NLL + KL + two adversarial terms.

    Compared with contperceptual.py, this adds a multi-window discriminator
    (``Discriminator`` from ``multi_window_disc``) alongside the patch
    discriminator (``NLayerDiscriminator``).

    Generator objective (``optimizer_idx == 0``)::

        rec  = |x - x_hat| + perceptual_weight * LPAPS(x, x_hat)
        nll  = rec / exp(logvar) + logvar
        loss = sum(weights * nll) / B + kl_weight * sum(KL) / B
               + d_weight * disc_factor * g_loss
               + d_weight_multi * disc_factor * g_loss_multi

    Discriminator objective (``optimizer_idx == 1``): the configured GAN loss
    (hinge or vanilla), applied to both discriminators and summed.
    """

    def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
                 disc_loss="hinge", disc_win_num=3, mel_disc_hidden_size=128,
                 mel_freq_length=80):
        """Build the perceptual loss and both discriminators.

        Args:
            disc_start: global step passed as ``threshold`` to ``adopt_weight``
                when computing the adversarial factor.
            logvar_init: initial value of the learned scalar log-variance.
            kl_weight: weight of the KL term.
            pixelloss_weight: stored as ``self.pixel_weight``; not read by
                this class.
            disc_num_layers: ``n_layers`` of the NLayerDiscriminator.
            disc_in_channels: ``input_nc`` of the NLayerDiscriminator.
            disc_factor: base adversarial factor (gated by ``adopt_weight``).
            disc_weight: multiplier applied to the adaptive GAN weights.
            perceptual_weight: weight of the LPAPS term; ``<= 0`` disables it.
            use_actnorm: forwarded to the NLayerDiscriminator.
            disc_conditional: whether the generator step expects ``cond``.
            disc_loss: ``"hinge"`` or ``"vanilla"``.
            disc_win_num: number of time windows taken from ``(32, 64, 128)``
                for the multi-window discriminator (the slice caps it at 3).
            mel_disc_hidden_size: ``hidden_size`` of the multi-window
                discriminator.
            mel_freq_length: ``freq_length`` of the multi-window discriminator
                (presumably the number of mel bins — confirm against the data).
        """
        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        self.kl_weight = kl_weight
        # NOTE(review): kept for config compatibility; not used in this class.
        self.pixel_weight = pixelloss_weight
        self.perceptual_loss = LPAPS().eval()
        self.perceptual_weight = perceptual_weight
        # Learned scalar log-variance of the reconstruction likelihood.
        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
                                                 n_layers=disc_num_layers,
                                                 use_actnorm=use_actnorm,
                                                 ).apply(weights_init)
        self.discriminator_iter_start = disc_start
        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        elif disc_loss == "vanilla":
            self.disc_loss = vanilla_d_loss
        else:
            # Reachable only when asserts are disabled (python -O).
            raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
        print(f"LPAPSWithDiscriminator running with {disc_loss} loss.")
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional
        self.discriminator_multi = Discriminator(
            time_lengths=[32, 64, 128][:disc_win_num],
            freq_length=mel_freq_length, hidden_size=mel_disc_hidden_size,
            kernel=(3, 3), cond_size=0, norm_type="in", reduction="stack")

    @staticmethod
    def _to_multi_window_input(x):
        """Drop dim 1 and swap the last two dims for ``discriminator_multi``.

        NOTE(review): assumes ``x`` is (B, 1, n_mels, T), giving (B, T, n_mels)
        — confirm against the caller.
        """
        return x.squeeze(1).transpose(1, 2)

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        """Scale an adversarial loss to the reconstruction loss.

        Computes ``||d nll / d last_layer|| / (||d g / d last_layer|| + 1e-4)``,
        clamps it to ``[0, 1e4]``, detaches it and multiplies it by
        ``disc_weight``.

        Args:
            nll_loss: scalar reconstruction (NLL) loss.
            g_loss: scalar generator adversarial loss.
            last_layer: tensor both losses are differentiated against. When
                None, ``self.last_layer[0]`` is used; that attribute is not
                defined in this class (presumably supplied by a subclass).

        Returns:
            A detached scalar tensor.

        Raises:
            RuntimeError: from ``torch.autograd.grad`` when a loss has no
                gradient path to ``last_layer`` (e.g. computed under no_grad).
        """
        if last_layer is None:
            last_layer = self.last_layer[0]
        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        return d_weight * self.discriminator_weight

    def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
                global_step, last_layer=None, cond=None, split="train", weights=None):
        """Compute the generator (idx 0) or discriminator (idx 1) loss.

        Args:
            inputs: ground-truth batch (see ``_to_multi_window_input`` for the
                assumed layout).
            reconstructions: decoder output, same shape as ``inputs``.
            posteriors: object exposing ``.kl()`` (presumably the encoder's
                diagonal Gaussian posterior).
            optimizer_idx: 0 for the autoencoder step, 1 for the
                discriminator step.
            global_step: current step, used to gate the adversarial factor.
            last_layer: parameter used for the adaptive GAN weights.
            cond: optional conditioning concatenated on dim 1 for the patch
                discriminator; must be given iff ``disc_conditional``.
            split: prefix of every log key.
            weights: optional multiplier applied to the per-element NLL.

        Returns:
            ``(loss, log)`` where ``log`` maps ``"{split}/..."`` to tensors.

        Raises:
            ValueError: if ``optimizer_idx`` is neither 0 nor 1.
        """
        inputs = inputs.contiguous()
        reconstructions = reconstructions.contiguous()

        rec_loss = torch.abs(inputs - reconstructions)
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(inputs, reconstructions)
            rec_loss = rec_loss + self.perceptual_weight * p_loss

        nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
        weighted_nll_loss = nll_loss if weights is None else weights * nll_loss
        # Sum over all elements, then average over the batch dimension.
        weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
        nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
        kl_loss = posteriors.kl()
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]

        if optimizer_idx == 0:
            # Generator update.
            if cond is None:
                assert not self.disc_conditional
                logits_fake = self.discriminator(reconstructions)
            else:
                assert self.disc_conditional
                logits_fake = self.discriminator(torch.cat((reconstructions, cond), dim=1))
            logits_fake_multi = self.discriminator_multi(
                self._to_multi_window_input(reconstructions))
            g_loss = -torch.mean(logits_fake)
            g_loss_multi = -torch.mean(logits_fake_multi['y'])

            try:
                d_weight = self.calculate_adaptive_weight(
                    nll_loss, g_loss, last_layer=last_layer)
                d_weight_multi = self.calculate_adaptive_weight(
                    nll_loss, g_loss_multi, last_layer=last_layer)
            except RuntimeError:
                # Outside training (e.g. validation without grad) the adaptive
                # weights cannot be computed, so the GAN terms are zeroed. In
                # training this is a real error: re-raise it instead of masking
                # it (an `assert` here would also vanish under `python -O`).
                if self.training:
                    raise
                d_weight = d_weight_multi = torch.tensor(0.0)

            disc_factor = adopt_weight(self.disc_factor, global_step,
                                       threshold=self.discriminator_iter_start)
            loss = (weighted_nll_loss + self.kl_weight * kl_loss
                    + d_weight * disc_factor * g_loss
                    + d_weight_multi * disc_factor * g_loss_multi)

            log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
                   "{}/logvar".format(split): self.logvar.detach(),
                   "{}/kl_loss".format(split): kl_loss.detach().mean(),
                   "{}/nll_loss".format(split): nll_loss.detach().mean(),
                   "{}/rec_loss".format(split): rec_loss.detach().mean(),
                   "{}/d_weight".format(split): d_weight.detach(),
                   "{}/d_weight_multi".format(split): d_weight_multi.detach(),
                   "{}/disc_factor".format(split): torch.tensor(disc_factor),
                   "{}/g_loss".format(split): g_loss.detach().mean(),
                   "{}/g_loss_multi".format(split): g_loss_multi.detach().mean(),
                   }
            return loss, log

        if optimizer_idx == 1:
            # Discriminator update: no gradient flows back into the autoencoder.
            real = inputs.detach()
            fake = reconstructions.detach()
            if cond is None:
                logits_real = self.discriminator(real)
                logits_fake = self.discriminator(fake)
            else:
                logits_real = self.discriminator(torch.cat((real, cond), dim=1))
                logits_fake = self.discriminator(torch.cat((fake, cond), dim=1))
            logits_real_multi = self.discriminator_multi(self._to_multi_window_input(real))
            logits_fake_multi = self.discriminator_multi(self._to_multi_window_input(fake))

            disc_factor = adopt_weight(self.disc_factor, global_step,
                                       threshold=self.discriminator_iter_start)
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
            d_loss_multi = disc_factor * self.disc_loss(logits_real_multi['y'],
                                                        logits_fake_multi['y'])

            log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                   "{}/disc_loss_multi".format(split): d_loss_multi.clone().detach().mean(),
                   "{}/logits_real".format(split): logits_real.detach().mean(),
                   "{}/logits_fake".format(split): logits_fake.detach().mean()
                   }
            return d_loss + d_loss_multi, log

        raise ValueError(f"optimizer_idx must be 0 or 1, got {optimizer_idx!r}")
|