# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os
from collections import OrderedDict

import torch
from tqdm import trange

from scepter.modules.model.registry import (DIFFUSION_SAMPLERS, DIFFUSIONS,
NOISE_SCHEDULERS)
from scepter.modules.utils.config import Config, dict_to_yaml
from scepter.modules.utils.distribute import we
from scepter.modules.utils.file_system import FS


@DIFFUSIONS.register_class()
class ACEDiffusion(object):
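    """Diffusion wrapper for ACE: training loss and iterative sampling.

    Builds a noise scheduler (used by `loss` to corrupt inputs) and a
    sampler scheduler (used by `sample` to denoise), supporting the
    'eps', 'x0', and 'v' prediction types.
    """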
para_dict = {
'NOISE_SCHEDULER': {},
'SAMPLER_SCHEDULER': {},
'MIN_SNR_GAMMA': {
'value': None,
'description': 'The minimum SNR gamma value for the loss function.'
},
'PREDICTION_TYPE': {
'value': 'eps',
'description':
'The type of prediction to use for the loss function.'
}
}

    def __init__(self, cfg, logger=None):
super(ACEDiffusion, self).__init__()
self.logger = logger
self.cfg = cfg
self.init_params()

    def init_params(self):
self.min_snr_gamma = self.cfg.get('MIN_SNR_GAMMA', None)
self.prediction_type = self.cfg.get('PREDICTION_TYPE', 'eps')
self.noise_scheduler = NOISE_SCHEDULERS.build(self.cfg.NOISE_SCHEDULER,
logger=self.logger)
self.sampler_scheduler = NOISE_SCHEDULERS.build(self.cfg.get(
'SAMPLER_SCHEDULER', self.cfg.NOISE_SCHEDULER),
logger=self.logger)
self.num_timesteps = self.noise_scheduler.num_timesteps
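        # on rank 0, save plots of the noise/sampler schedules under WORK_DIR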
if self.cfg.have('WORK_DIR') and we.rank == 0:
schedule_visualization = os.path.join(self.cfg.WORK_DIR,
'noise_schedule.png')
with FS.put_to(schedule_visualization) as local_path:
self.noise_scheduler.plot_noise_sampling_map(local_path)
schedule_visualization = os.path.join(self.cfg.WORK_DIR,
'sampler_schedule.png')
with FS.put_to(schedule_visualization) as local_path:
self.sampler_scheduler.plot_noise_sampling_map(local_path)

    def sample(self,
noise,
model,
               model_kwargs=None,
steps=20,
sampler=None,
use_dynamic_cfg=False,
guide_scale=None,
guide_rescale=None,
show_progress=False,
return_intermediate=None,
intermediate_callback=None,
**kwargs):
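        """Draw a sample by iteratively denoising `noise` with `sampler`.

        With guidance enabled (`guide_scale` not None and != 1.0),
        `model_kwargs` must be a pair [cond_kwargs, uncond_kwargs];
        otherwise it is a single kwargs dict. Set `return_intermediate`
        to 'x0' or 'xt' to also collect per-step estimates.
        """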
        model_kwargs = {} if model_kwargs is None else model_kwargs
        assert isinstance(steps, (int, torch.LongTensor))
        assert return_intermediate in (None, 'x0', 'xt')
        assert isinstance(sampler, (str, dict, Config))
        intermediates = []
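        # callback_fn turns the raw model output at (x_t, t) into an x_0
        # estimate, applying classifier-free guidance and optional rescaling
        # along the way; the sampler invokes it once per step.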
def callback_fn(x_t, t, sigma=None, alpha=None):
timestamp = t
t = t.repeat(len(x_t)).round().long().to(x_t.device)
sigma = sigma.repeat(len(x_t), *([1] * (len(sigma.shape) - 1)))
alpha = alpha.repeat(len(x_t), *([1] * (len(alpha.shape) - 1)))
if guide_scale is None or guide_scale == 1.0:
out = model(x=x_t, t=t, **model_kwargs)
else:
if use_dynamic_cfg:
guidance_scale = 1 + guide_scale * (
(1 - math.cos(math.pi * (
(steps - timestamp.item()) / steps)**5.0)) / 2)
else:
guidance_scale = guide_scale
y_out = model(x=x_t, t=t, **model_kwargs[0])
u_out = model(x=x_t, t=t, **model_kwargs[1])
out = u_out + guidance_scale * (y_out - u_out)
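                # guidance rescale (cf. arXiv:2305.08891): pull the guided
                # output's per-sample std back toward the conditional
                # branch's; guide_rescale interpolates between no rescaling
                # (0.0) and matching the conditional std exactly (1.0)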
if guide_rescale is not None and guide_rescale > 0.0:
ratio = (
y_out.flatten(1).std(dim=1) /
(out.flatten(1).std(dim=1) + 1e-12)).view((-1, ) + (1, ) *
(y_out.ndim - 1))
out *= guide_rescale * ratio + (1 - guide_rescale) * 1.0
if self.prediction_type == 'x0':
x0 = out
elif self.prediction_type == 'eps':
x0 = (x_t - sigma * out) / alpha
elif self.prediction_type == 'v':
x0 = alpha * x_t - sigma * out
else:
raise NotImplementedError(
f'prediction_type {self.prediction_type} not implemented')
return x0
        sampler_ins = self.get_sampler(sampler)
        if sampler_ins is None:
            raise ValueError(f'Unknown sampler: {sampler}')
        # this is ignored for schnell
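        # ('preprare_sampler' (sic) matches the method name exposed by the
        # registered sampler classes)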
sampler_output = sampler_ins.preprare_sampler(
noise,
steps=steps,
prediction_type=self.prediction_type,
scheduler_ins=self.sampler_scheduler,
callback_fn=callback_fn)
        pbar = trange(steps, disable=not show_progress)
        for _ in pbar:
            pbar.set_description(sampler_output.msg)
            sampler_output = sampler_ins.step(sampler_output)
            if return_intermediate == 'x0':
                intermediates.append(sampler_output.x_0)
            elif return_intermediate == 'xt':
                intermediates.append(sampler_output.x_t)
            if intermediate_callback is not None and intermediates:
                intermediate_callback(intermediates[-1])
        if return_intermediate is not None:
            return sampler_output.x_0, intermediates
        return sampler_output.x_0

    def loss(self,
x_0,
model,
             model_kwargs=None,
reduction='mean',
noise=None,
**kwargs):
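        """Diffusion training loss for `x_0`.

        Corrupts `x_0` via the noise scheduler, predicts with `model`, and
        regresses onto the target implied by `prediction_type`; optionally
        applies min-SNR-gamma loss weighting.
        """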
        model_kwargs = {} if model_kwargs is None else model_kwargs
        # use noise scheduler to add noise
        if noise is None:
            noise = torch.randn_like(x_0)
schedule_output = self.noise_scheduler.add_noise(x_0, noise, **kwargs)
        x_t, t = schedule_output.x_t, schedule_output.t
        sigma, alpha = schedule_output.sigma, schedule_output.alpha
out = model(x=x_t, t=t, **model_kwargs)
        # mse loss against the target implied by prediction_type:
        # eps -> the added noise, x0 -> the clean sample,
        # v -> velocity, alpha * noise - sigma * x_0
target = {
'eps': noise,
'x0': x_0,
'v': alpha * noise - sigma * x_0
}[self.prediction_type]
loss = (out - target).pow(2)
if reduction == 'mean':
loss = loss.flatten(1).mean(dim=1)
        if self.min_snr_gamma is not None:
            # min-SNR-gamma weighting (arXiv:2303.09556) with
            # SNR(t) = alpha_t^2 / sigma_t^2, where alpha/sigma are the
            # coefficients used by add_noise; the min(SNR, gamma)/SNR
            # weight below corresponds to eps-prediction
            alphas = self.noise_scheduler.alphas.pow(2).to(x_0.device)[t]
            sigmas = self.noise_scheduler.sigmas.pow(2).to(x_0.device)[t]
            snrs = (alphas / sigmas).clamp(min=1e-20)
            min_snrs = snrs.clamp(max=self.min_snr_gamma)
            weights = min_snrs / snrs
else:
weights = 1
loss = loss * weights
return loss

    def get_sampler(self, sampler):
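        """Build a sampler instance from a registered name, dict, or Config."""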
if isinstance(sampler, str):
            if sampler not in DIFFUSION_SAMPLERS.class_map:
                msg = (f'{sampler} not in the defined samplers list '
                       f'{DIFFUSION_SAMPLERS.class_map.keys()}')
                if self.logger is not None:
                    self.logger.info(msg)
                else:
                    print(msg)
                return None
sampler_cfg = Config(cfg_dict={'NAME': sampler}, load=False)
sampler_ins = DIFFUSION_SAMPLERS.build(sampler_cfg,
logger=self.logger)
elif isinstance(sampler, (Config, dict, OrderedDict)):
if isinstance(sampler, (dict, OrderedDict)):
sampler = Config(
cfg_dict={k.upper(): v
for k, v in dict(sampler).items()},
load=False)
sampler_ins = DIFFUSION_SAMPLERS.build(sampler, logger=self.logger)
else:
            raise NotImplementedError(
                f'Unsupported sampler type: {type(sampler)}')
return sampler_ins

    def __repr__(self) -> str:
return f'{self.__class__.__name__}' + ' ' + super().__repr__()

    @staticmethod
def get_config_template():
return dict_to_yaml('DIFFUSIONS',
__class__.__name__,
ACEDiffusion.para_dict,
                            set_name=True)
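

# Minimal usage sketch (illustrative only): the scheduler and sampler names
# below are assumptions -- substitute whatever is registered in
# NOISE_SCHEDULERS and DIFFUSION_SAMPLERS in your scepter build.
#
#   cfg = Config(cfg_dict={
#       'NOISE_SCHEDULER': {'NAME': 'ScaledLinearScheduler'},  # assumed name
#       'PREDICTION_TYPE': 'eps',
#   }, load=False)
#   diffusion = ACEDiffusion(cfg)
#   x_0 = diffusion.sample(noise=torch.randn(1, 4, 64, 64), model=model,
#                          sampler='ddim', steps=20)  # 'ddim' is assumed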