File size: 1,042 Bytes
2a00960
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch

from scepter.modules.model.registry import NOISE_SCHEDULERS
from scepter.modules.model.diffusion.schedules import BaseNoiseScheduler


@NOISE_SCHEDULERS.register_class()
class LinearScheduler(BaseNoiseScheduler):
    """Noise scheduler whose per-step betas grow linearly between two bounds.

    Config keys read by this class:
        BETA_MIN: first beta of the schedule (default 0.00085).
        BETA_MAX: last beta of the schedule (default 0.012).

    The schedule length is taken from ``self.num_timesteps``, which this class
    does not set itself (presumably provided by ``BaseNoiseScheduler`` --
    confirm against the base class).
    """
    para_dict = {}

    def init_params(self):
        """Run the base-class initialisation, then read the beta bounds."""
        super().init_params()
        self.beta_min, self.beta_max = (
            self.cfg.get('BETA_MIN', 0.00085),
            self.cfg.get('BETA_MAX', 0.012),
        )

    def betas_to_sigmas(self, betas):
        """Convert per-step betas into cumulative noise levels.

        Computes ``sigma_t = sqrt(1 - prod_{s <= t} (1 - beta_s))``, with the
        running product taken along dim 0 of ``betas``.
        """
        alpha_bar = torch.cumprod(1 - betas, dim=0)
        return (1 - alpha_bar).sqrt()

    def get_schedule(self):
        """Build the linear schedule and store it on the instance.

        Populates, all as float32 tensors of length ``num_timesteps``:
            _betas:     betas evenly spaced from beta_min to beta_max.
            _sigmas:    cumulative noise levels from ``betas_to_sigmas``.
            _alphas:    sqrt(1 - sigma^2), the complementary signal scale.
            _timesteps: the step indices 0, 1, ..., N-1.
        """
        linear_betas = torch.linspace(
            self.beta_min,
            self.beta_max,
            self.num_timesteps,
            dtype=torch.float32,
        )
        noise_levels = self.betas_to_sigmas(linear_betas)
        self._betas = linear_betas
        self._sigmas = noise_levels
        # Chosen so that alpha^2 + sigma^2 == 1 (up to float rounding).
        self._alphas = (1 - noise_levels.pow(2)).sqrt()
        self._timesteps = torch.arange(noise_levels.shape[0],
                                       dtype=torch.float32)