# wzhouxiff's picture
# init
# 2890711
from typing import Dict, Union
import torch
import torch.nn as nn
from ...util import append_dims, instantiate_from_config
from .denoiser_scaling import DenoiserScaling
from .discretizer import Discretization
class Denoiser(nn.Module):
    """Wrap a network with sigma-dependent input/output scaling.

    The scaling object built from ``scaling_config`` is called with the
    noise level and must return four values, unpacked as
    ``(c_skip, c_out, c_in, c_noise)``.  ``forward`` combines them as
    ``network(input * c_in, c_noise, cond) * c_out + input * c_skip``.

    The two ``possibly_quantize_*`` hooks are identity functions here;
    subclasses can override them to snap values onto a discrete grid.
    """

    def __init__(self, scaling_config: Dict):
        """Instantiate the scaling object described by ``scaling_config``."""
        super().__init__()
        self.scaling: DenoiserScaling = instantiate_from_config(scaling_config)

    def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:
        """Return ``sigma`` unchanged (no quantization in the base class)."""
        return sigma

    def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor:
        """Return ``c_noise`` unchanged (no quantization in the base class)."""
        return c_noise

    def forward(
        self,
        network: nn.Module,
        input: torch.Tensor,
        sigma: torch.Tensor,
        cond: Dict,
        **additional_model_inputs,
    ) -> torch.Tensor:
        """Run ``network`` on the scaled input and blend with a skip term.

        Args:
            network: Callable invoked as
                ``network(scaled_input, c_noise, cond, **additional_model_inputs)``.
            input: Tensor to denoise.
            sigma: Noise level(s) for ``input``.
            cond: Conditioning passed through to ``network`` untouched.
            **additional_model_inputs: Extra keyword arguments forwarded to
                ``network``.

        Returns:
            ``network_output * c_out + input * c_skip``.
        """
        sigma = self.possibly_quantize_sigma(sigma)
        # Remember the shape before dims are appended so c_noise can be
        # handed to the quantization hook in that same shape.
        original_sigma_shape = sigma.shape
        # NOTE(review): append_dims presumably pads sigma with trailing
        # singleton dims up to input.ndim -- confirm against util.append_dims.
        expanded_sigma = append_dims(sigma, input.ndim)

        c_skip, c_out, c_in, c_noise = self.scaling(expanded_sigma)
        c_noise = self.possibly_quantize_c_noise(
            c_noise.reshape(original_sigma_shape)
        )

        network_output = network(
            input * c_in, c_noise, cond, **additional_model_inputs
        )
        return network_output * c_out + input * c_skip
class DiscreteDenoiser(Denoiser):
    """Denoiser whose noise levels are restricted to a fixed table.

    The table is produced once in ``__init__`` by calling the discretization
    object and is stored as the ``sigmas`` buffer.  Incoming sigmas are
    snapped to the closest table entry, and ``c_noise`` is optionally
    replaced by the index of its closest table entry.
    """

    def __init__(
        self,
        scaling_config: Dict,
        num_idx: int,
        discretization_config: Dict,
        do_append_zero: bool = False,
        quantize_c_noise: bool = True,
        flip: bool = True,
    ):
        """Build the scaling object and the sigma table.

        Args:
            scaling_config: Config passed to the base class to build the
                scaling object.
            num_idx: Forwarded to the discretization call as its first
                argument; also stored as ``self.num_idx``.
            discretization_config: Config used to instantiate the
                discretization object.
            do_append_zero: Forwarded to the discretization call.
            quantize_c_noise: If true, ``possibly_quantize_c_noise`` returns
                table indices instead of the raw ``c_noise``.
            flip: Forwarded to the discretization call.
        """
        super().__init__(scaling_config)
        self.discretization: Discretization = instantiate_from_config(
            discretization_config
        )
        # NOTE(review): the result is indexed as ``self.sigmas[:, None]``
        # below, so it is expected to be a 1-D tensor -- confirm against
        # the configured Discretization.
        sigmas = self.discretization(num_idx, do_append_zero=do_append_zero, flip=flip)
        self.register_buffer("sigmas", sigmas)
        self.quantize_c_noise = quantize_c_noise
        self.num_idx = num_idx

    def sigma_to_idx(self, sigma: torch.Tensor) -> torch.Tensor:
        """Return, per element of ``sigma``, the index of the closest table entry.

        The result has the same shape as ``sigma``.
        """
        # Flatten first so the table axis (dim 0 of ``dists``) can never
        # broadcast against sigma's own dimensions; without this, inputs
        # with more than one dimension reduce over the wrong axis.
        dists = sigma.reshape(-1) - self.sigmas[:, None]
        return dists.abs().argmin(dim=0).view(sigma.shape)

    def idx_to_sigma(self, idx: Union[torch.Tensor, int]) -> torch.Tensor:
        """Look up the table sigma(s) at ``idx``."""
        return self.sigmas[idx]

    def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:
        """Snap each element of ``sigma`` to its closest table entry."""
        return self.idx_to_sigma(self.sigma_to_idx(sigma))

    def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor:
        """Return table indices for ``c_noise`` if quantization is enabled.

        When ``self.quantize_c_noise`` is false, ``c_noise`` is returned
        unchanged.
        """
        if self.quantize_c_noise:
            return self.sigma_to_idx(c_noise)
        return c_noise