Spaces:
Running
on
A10G
Running
on
A10G
from typing import Dict, Union | |
import torch | |
import torch.nn as nn | |
from ...util import append_dims, instantiate_from_config | |
from .denoiser_scaling import DenoiserScaling | |
from .discretizer import Discretization | |
class Denoiser(nn.Module): | |
def __init__(self, scaling_config: Dict): | |
super().__init__() | |
self.scaling: DenoiserScaling = instantiate_from_config(scaling_config) | |
def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor: | |
return sigma | |
def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor: | |
return c_noise | |
def forward( | |
self, | |
network: nn.Module, | |
input: torch.Tensor, | |
sigma: torch.Tensor, | |
cond: Dict, | |
**additional_model_inputs, | |
) -> torch.Tensor: | |
sigma = self.possibly_quantize_sigma(sigma) | |
sigma_shape = sigma.shape | |
sigma = append_dims(sigma, input.ndim) | |
c_skip, c_out, c_in, c_noise = self.scaling(sigma) | |
c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape)) | |
return ( | |
network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out | |
+ input * c_skip | |
) | |
class DiscreteDenoiser(Denoiser): | |
def __init__( | |
self, | |
scaling_config: Dict, | |
num_idx: int, | |
discretization_config: Dict, | |
do_append_zero: bool = False, | |
quantize_c_noise: bool = True, | |
flip: bool = True, | |
): | |
super().__init__(scaling_config) | |
self.discretization: Discretization = instantiate_from_config( | |
discretization_config | |
) | |
sigmas = self.discretization(num_idx, do_append_zero=do_append_zero, flip=flip) | |
self.register_buffer("sigmas", sigmas) | |
self.quantize_c_noise = quantize_c_noise | |
self.num_idx = num_idx | |
def sigma_to_idx(self, sigma: torch.Tensor) -> torch.Tensor: | |
dists = sigma - self.sigmas[:, None] | |
return dists.abs().argmin(dim=0).view(sigma.shape) | |
def idx_to_sigma(self, idx: Union[torch.Tensor, int]) -> torch.Tensor: | |
return self.sigmas[idx] | |
def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor: | |
return self.idx_to_sigma(self.sigma_to_idx(sigma)) | |
def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor: | |
if self.quantize_c_noise: | |
return self.sigma_to_idx(c_noise) | |
else: | |
return c_noise | |