heheyas
init
cfb7702
import torch
import torch.nn as nn
from ....util import default, instantiate_from_config
from ..lpips.loss.lpips import LPIPS
class LatentLPIPS(nn.Module):
def __init__(
self,
decoder_config,
perceptual_weight=1.0,
latent_weight=1.0,
scale_input_to_tgt_size=False,
scale_tgt_to_input_size=False,
perceptual_weight_on_inputs=0.0,
):
super().__init__()
self.scale_input_to_tgt_size = scale_input_to_tgt_size
self.scale_tgt_to_input_size = scale_tgt_to_input_size
self.init_decoder(decoder_config)
self.perceptual_loss = LPIPS().eval()
self.perceptual_weight = perceptual_weight
self.latent_weight = latent_weight
self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
def init_decoder(self, config):
self.decoder = instantiate_from_config(config)
if hasattr(self.decoder, "encoder"):
del self.decoder.encoder
def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
log = dict()
loss = (latent_inputs - latent_predictions) ** 2
log[f"{split}/latent_l2_loss"] = loss.mean().detach()
image_reconstructions = None
if self.perceptual_weight > 0.0:
image_reconstructions = self.decoder.decode(latent_predictions)
image_targets = self.decoder.decode(latent_inputs)
perceptual_loss = self.perceptual_loss(
image_targets.contiguous(), image_reconstructions.contiguous()
)
loss = (
self.latent_weight * loss.mean()
+ self.perceptual_weight * perceptual_loss.mean()
)
log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
if self.perceptual_weight_on_inputs > 0.0:
image_reconstructions = default(
image_reconstructions, self.decoder.decode(latent_predictions)
)
if self.scale_input_to_tgt_size:
image_inputs = torch.nn.functional.interpolate(
image_inputs,
image_reconstructions.shape[2:],
mode="bicubic",
antialias=True,
)
elif self.scale_tgt_to_input_size:
image_reconstructions = torch.nn.functional.interpolate(
image_reconstructions,
image_inputs.shape[2:],
mode="bicubic",
antialias=True,
)
perceptual_loss2 = self.perceptual_loss(
image_inputs.contiguous(), image_reconstructions.contiguous()
)
loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
return loss, log