# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch

from scepter.modules.model.registry import DIFFUSION_SAMPLERS
from scepter.modules.model.diffusion.samplers import BaseDiffusionSampler
from scepter.modules.model.diffusion.util import _i


# NOTE: this local definition shadows the `_i` imported just above.
def _i(tensor, t, x):
    """Pick ``tensor[t]`` and reshape it so it broadcasts against ``x``.

    Args:
        tensor: Tensor to index along its first dimension.
        t: Index into ``tensor``. When it is a ``torch.Tensor`` it is first
            moved to ``tensor``'s device.
        x: Reference tensor; only its batch size, rank and device are used.

    Returns:
        The selected values viewed as ``(x.size(0), 1, ..., 1)`` (same rank
        as ``x``) and moved to ``x.device``.
    """
    if isinstance(t, torch.Tensor):
        t = t.to(tensor.device)
    broadcast_shape = (x.size(0), *([1] * (x.ndim - 1)))
    return tensor[t].view(broadcast_shape).to(x.device)


@DIFFUSION_SAMPLERS.register_class('ddim')
class DDIMSampler(BaseDiffusionSampler):
    """DDIM-style sampler, registered in ``DIFFUSION_SAMPLERS`` as 'ddim'."""

    def init_params(self):
        """Read sampler options from ``self.cfg`` after the base set-up.

        ``ETA`` (default ``0.``) scales the noise injected in :meth:`step`.
        ``DISCRETIZATION_TYPE`` (default ``'trailing'``) is stored on the
        instance; it is not read in this class (presumably the base class
        consumes it -- TODO confirm).
        """
        super().init_params()
        cfg = self.cfg
        self.eta = cfg.get('ETA', 0.)
        self.discretization_type = cfg.get('DISCRETIZATION_TYPE', 'trailing')

    def preprare_sampler(self,
                         noise,
                         steps=20,
                         scheduler_ins=None,
                         prediction_type='',
                         sigmas=None,
                         betas=None,
                         alphas=None,
                         callback_fn=None,
                         **kwargs):
        """Run the base preparation and attach the ``sigmas_vp`` schedule.

        All arguments are forwarded unchanged to the parent implementation.
        The ``sigmas`` of the returned object are extended with one trailing
        zero and mapped to ``sqrt(sigma^2 / (1 + sigma^2))``; the result is
        registered on the output as the custom field ``'sigmas_vp'``.

        Returns:
            The object produced by the parent ``preprare_sampler``, with the
            extra ``sigmas_vp`` field.
        """
        prepared = super().preprare_sampler(noise, steps, scheduler_ins,
                                            prediction_type, sigmas, betas,
                                            alphas, callback_fn, **kwargs)
        base_sigmas = prepared.sigmas
        # One extra zero entry so that `step` can always read index step + 1.
        padded = torch.cat([base_sigmas, base_sigmas.new_zeros([1])])
        padded_sq = padded**2
        vp = (padded_sq / (1 + padded_sq))**0.5
        # inf / inf evaluates to NaN above; the limit of the mapping is 1.
        vp[padded == float('inf')] = 1.
        prepared.add_custom_field('sigmas_vp', vp)
        return prepared

    def step(self, sampler_output):
        """Advance ``sampler_output`` by one sampling step, in place.

        Reads ``x_t``, ``step``, ``ts``, ``sigmas_vp``, ``alphas_init``,
        ``sigmas_init`` and ``callback_fn`` from ``sampler_output``; writes
        ``x_0``, ``x_t``, ``step`` (incremented by one) and ``msg``.

        Returns:
            The same ``sampler_output`` object.
        """
        idx = sampler_output.step
        current = sampler_output.x_t
        timestep = sampler_output.ts[idx]
        schedule = sampler_output.sigmas_vp.to(current.device)

        # Per-step coefficients shaped to broadcast over a batch of size 1.
        probe = current[:1]
        alpha_init = _i(sampler_output.alphas_init, idx, probe)
        sigma_init = _i(sampler_output.sigmas_init, idx, probe)

        # Presumably the model's clean-sample estimate -- the update below
        # treats it that way (TODO confirm against the callback contract).
        pred = sampler_output.callback_fn(current, timestep, sigma_init,
                                          alpha_init)

        s_cur = schedule[idx]
        s_next = schedule[idx + 1]
        cur_sq = s_cur**2
        next_sq = s_next**2

        # NOTE(review): the usual DDIM sigma is the square root of this
        # bracketed product; here it is used without the root. The same
        # value feeds both the deterministic and the random term below, so
        # the total variance is still next_sq. Confirm this is intended.
        noise_factor = self.eta * (next_sq / cur_sq *
                                   (1 - (1 - cur_sq) / (1 - next_sq)))

        # Direction implied by the current sample and the prediction.
        direction = (current - (1 - cur_sq)**0.5 * pred) / s_cur
        updated = (1 - next_sq)**0.5 * pred + \
            (next_sq - noise_factor**2)**0.5 * direction

        sampler_output.x_0 = updated
        if s_next > 0:
            # In-place on purpose: `x_0` above refers to this same tensor,
            # so it receives the noise as well.
            updated += noise_factor * torch.randn_like(updated)
        sampler_output.x_t = updated
        sampler_output.step += 1
        sampler_output.msg = f'step {idx}'
        return sampler_output