@@ -1,3 +1,12 @@
+# THIS IS A FORK
+
+Forked from https://github.com/crowsonkb/k-diffusion
+
+Changes:
+
+1. Add the DPM++ 2M sampling fix by @hallatore: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/8457
+2. Add the MPS fix for macOS by @brkirch: https://github.com/brkirch/k-diffusion
+
 # k-diffusion
 
 An implementation of [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364) (Karras et al., 2022) for PyTorch. The patching method in [Improving Diffusion Model Efficiency Through Patching](https://arxiv.org/abs/2207.04316) is implemented as well.
@@ -79,7 +79,9 @@ class DiscreteSchedule(nn.Module):
 
     def t_to_sigma(self, t):
         t = t.float()
-        low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
+        low_idx = t.floor().long()
+        high_idx = t.ceil().long()
+        w = t - low_idx if t.device.type == 'mps' else t.frac()
         log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
         return log_sigma.exp()
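On MPS, `t.frac()` can return incorrect values, so the fix above computes the fractional part as `t - low_idx` instead. A minimal standalone sketch of the same interpolation, assuming a hypothetical ten-entry `log_sigmas` table in place of the one a `DiscreteSchedule` holds:

```python
import torch

# Hypothetical stand-in for the log-sigma table a DiscreteSchedule would hold.
log_sigmas = torch.linspace(-3.0, 2.0, 10)

def t_to_sigma(t):
    t = t.float()
    low_idx = t.floor().long()
    high_idx = t.ceil().long()
    # On MPS, compute the fractional part manually instead of using t.frac().
    w = t - low_idx if t.device.type == 'mps' else t.frac()
    # Linear interpolation in log-sigma space between the two nearest entries.
    log_sigma = (1 - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx]
    return log_sigma.exp()

print(t_to_sigma(torch.tensor([0.0, 2.5, 9.0])))
```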
@@ -16,7 +16,7 @@ def append_zero(x):
 
 def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
     """Constructs the noise schedule of Karras et al. (2022)."""
-    ramp = torch.linspace(0, 1, n)
+    ramp = torch.linspace(0, 1, n, device=device)
     min_inv_rho = sigma_min ** (1 / rho)
     max_inv_rho = sigma_max ** (1 / rho)
     sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
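For reference, the schedule can be reproduced standalone by inlining the function body; `n`, `sigma_min`, and `sigma_max` below are example values only:

```python
import torch

# Inlined body of get_sigmas_karras with example values. The fix above makes
# the device argument apply to the ramp as well, not just the final tensor.
n, sigma_min, sigma_max, rho = 10, 0.1, 10.0, 7.0
ramp = torch.linspace(0, 1, n)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
print(sigmas)  # 10 values descending from sigma_max to sigma_min
```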
@@ -400,7 +400,13 @@ class DPMSolver(nn.Module):
 
         for i in range(len(orders)):
            eps_cache = {}
-            t, t_next = ts[i], ts[i + 1]
+
+            # macOS (MPS) fix: copy the step values before use
+            if torch.backends.mps.is_available() and torch.backends.mps.is_built():
+                t, t_next = ts[i].detach().clone(), ts[i + 1].detach().clone()
+            else:
+                t, t_next = ts[i], ts[i + 1]
+
             if eta:
                 sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta)
                 t_next_ = torch.minimum(t_end, self.t(sd))
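All of the MPS fixes in this fork use the same gate: `torch.backends.mps.is_built()` reports compile-time support, while `torch.backends.mps.is_available()` reports that the backend is actually usable at runtime. A minimal sketch of the check on its own:

```python
import torch

# True only when PyTorch was compiled with MPS support and the
# backend is usable on the current machine.
use_mps_workarounds = torch.backends.mps.is_available() and torch.backends.mps.is_built()
print(use_mps_workarounds)
```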
@@ -512,7 +518,12 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
     noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
     s_in = x.new_ones([x.shape[0]])
     sigma_fn = lambda t: t.neg().exp()
-    t_fn = lambda sigma: sigma.log().neg()
+
+    # macOS (MPS) fix: copy sigma before taking the log
+    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
+        t_fn = lambda sigma: sigma.detach().clone().log().neg()
+    else:
+        t_fn = lambda sigma: sigma.log().neg()
 
     for i in trange(len(sigmas) - 1, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
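`sigma_fn` and `t_fn` are inverses of each other (sigma = exp(-t), t = -log(sigma)); the MPS branch only copies the tensor before taking the log, leaving the math unchanged. A quick round-trip check as a sketch:

```python
import torch

# sigma = exp(-t) and t = -log(sigma) invert each other.
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()

sigma = torch.tensor(0.5)
assert torch.allclose(sigma_fn(t_fn(sigma)), sigma)

# The MPS variant copies first but computes the same value.
t_fn_mps = lambda sigma: sigma.detach().clone().log().neg()
assert torch.equal(t_fn_mps(sigma), t_fn(sigma))
```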
@@ -547,7 +558,12 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
     extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
     sigma_fn = lambda t: t.neg().exp()
-    t_fn = lambda sigma: sigma.log().neg()
+
+    # macOS (MPS) fix: copy sigma before taking the log
+    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
+        t_fn = lambda sigma: sigma.detach().clone().log().neg()
+    else:
+        t_fn = lambda sigma: sigma.log().neg()
 
     for i in trange(len(sigmas) - 1, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
@@ -587,7 +603,13 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
     extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
     sigma_fn = lambda t: t.neg().exp()
-    t_fn = lambda sigma: sigma.log().neg()
+
+    # macOS (MPS) fix: copy sigma before taking the log
+    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
+        t_fn = lambda sigma: sigma.detach().clone().log().neg()
+    else:
+        t_fn = lambda sigma: sigma.log().neg()
+
     old_denoised = None
 
     for i in trange(len(sigmas) - 1, disable=disable):
@@ -596,12 +618,22 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
             callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
         t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
         h = t_next - t
+
+        t_min = min(sigma_fn(t_next), sigma_fn(t))
+        t_max = max(sigma_fn(t_next), sigma_fn(t))
+
         if old_denoised is None or sigmas[i + 1] == 0:
-            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
+            x = (t_min / t_max) * x - (-h).expm1() * denoised
         else:
             h_last = t - t_fn(sigmas[i - 1])
-            r = h_last / h
+
+            h_min = min(h_last, h)
+            h_max = max(h_last, h)
+            r = h_max / h_min
+
+            h_d = (h_max + h_min) / 2
             denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
-            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
+            x = (t_min / t_max) * x - (-h_d).expm1() * denoised_d
+
         old_denoised = denoised
     return x
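The @hallatore change sorts the current and previous step sizes before forming the ratio `r`, and averages them into `h_d` for the `expm1` term. A toy numeric sketch with made-up step sizes, just to show what the new quantities evaluate to:

```python
import torch

# Hypothetical step sizes; only the arithmetic of the fix is illustrated.
h_last, h = torch.tensor(0.3), torch.tensor(0.5)

h_min = min(h_last, h)        # 0.3
h_max = max(h_last, h)        # 0.5
r = h_max / h_min             # ~1.667, always >= 1 after sorting
h_d = (h_max + h_min) / 2     # 0.4, averaged step fed into expm1

print(r, (-h_d).expm1())      # tensor(1.6667) tensor(-0.3297)
```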
@@ -42,7 +42,10 @@ def append_dims(x, target_dims):
     dims_to_append = target_dims - x.ndim
     if dims_to_append < 0:
         raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
-    return x[(...,) + (None,) * dims_to_append]
+    expanded = x[(...,) + (None,) * dims_to_append]
+    # MPS will get inf values if it tries to index into the new axes, but detaching fixes this.
+    # https://github.com/pytorch/pytorch/issues/84364
+    return expanded.detach().clone() if expanded.device.type == 'mps' else expanded
 
 
 def n_params(module):
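A usage sketch of the patched helper, with example shapes: it broadcasts a per-sample sigma vector against an image batch by appending trailing singleton dimensions.

```python
import torch

def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
    expanded = x[(...,) + (None,) * dims_to_append]
    # MPS can produce inf values when indexing into the new axes; copying avoids it.
    return expanded.detach().clone() if expanded.device.type == 'mps' else expanded

sigma = torch.tensor([1.0, 2.0])         # shape [2], one sigma per sample
x = torch.randn(2, 3, 8, 8)              # an example image batch
print(append_dims(sigma, x.ndim).shape)  # torch.Size([2, 1, 1, 1])
```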