webui-colab-deps / k_diffusion_dpmpp.diff
netrunner-exe's picture
Upload k_diffusion_dpmpp.diff
e47d403
diff --git a/README.md b/README.md
index 4f7c92f..e386624 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,12 @@
+# THIS IS A FORK
+
+Forked from https://github.com/crowsonkb/k-diffusion
+
+Changes:
+
+1. Add DPM++ 2M sampling fix by @hallatore https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/8457
+2. Add MPS fix for MacOS by @brkirch https://github.com/brkirch/k-diffusion
+
# k-diffusion
An implementation of [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364) (Karras et al., 2022) for PyTorch. The patching method in [Improving Diffusion Model Efficiency Through Patching](https://arxiv.org/abs/2207.04316) is implemented as well.
diff --git a/k_diffusion/external.py b/k_diffusion/external.py
index 79b51ce..b41d0eb 100644
--- a/k_diffusion/external.py
+++ b/k_diffusion/external.py
@@ -79,7 +79,9 @@ class DiscreteSchedule(nn.Module):
def t_to_sigma(self, t):
t = t.float()
- low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
+ low_idx = t.floor().long()
+ high_idx = t.ceil().long()
+ w = t - low_idx if t.device.type == 'mps' else t.frac()
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp()
diff --git a/k_diffusion/sampling.py b/k_diffusion/sampling.py
index f050f88..9f859d4 100644
--- a/k_diffusion/sampling.py
+++ b/k_diffusion/sampling.py
@@ -16,7 +16,7 @@ def append_zero(x):
def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
"""Constructs the noise schedule of Karras et al. (2022)."""
- ramp = torch.linspace(0, 1, n)
+ ramp = torch.linspace(0, 1, n, device=device)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
@@ -400,7 +400,13 @@ class DPMSolver(nn.Module):
for i in range(len(orders)):
eps_cache = {}
- t, t_next = ts[i], ts[i + 1]
+
+ # MacOS fix
+ if torch.backends.mps.is_available() and torch.backends.mps.is_built():
+ t, t_next = ts[i].detach().clone(), ts[i + 1].detach().clone()
+ else:
+ t, t_next = ts[i], ts[i + 1]
+
if eta:
sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta)
t_next_ = torch.minimum(t_end, self.t(sd))
@@ -512,7 +518,12 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
- t_fn = lambda sigma: sigma.log().neg()
+
+ # MacOS fix
+ if torch.backends.mps.is_available() and torch.backends.mps.is_built():
+ t_fn = lambda sigma: sigma.detach().clone().log().neg()
+ else:
+ t_fn = lambda sigma: sigma.log().neg()
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
@@ -547,7 +558,12 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
- t_fn = lambda sigma: sigma.log().neg()
+
+ # MacOS fix
+ if torch.backends.mps.is_available() and torch.backends.mps.is_built():
+ t_fn = lambda sigma: sigma.detach().clone().log().neg()
+ else:
+ t_fn = lambda sigma: sigma.log().neg()
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
@@ -587,7 +603,13 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
- t_fn = lambda sigma: sigma.log().neg()
+
+ # MacOS fix
+ if torch.backends.mps.is_available() and torch.backends.mps.is_built():
+ t_fn = lambda sigma: sigma.detach().clone().log().neg()
+ else:
+ t_fn = lambda sigma: sigma.log().neg()
+
old_denoised = None
for i in trange(len(sigmas) - 1, disable=disable):
@@ -596,12 +618,22 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
h = t_next - t
+
+ t_min = min(sigma_fn(t_next), sigma_fn(t))
+ t_max = max(sigma_fn(t_next), sigma_fn(t))
+
if old_denoised is None or sigmas[i + 1] == 0:
- x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
+ x = (t_min / t_max) * x - (-h).expm1() * denoised
else:
h_last = t - t_fn(sigmas[i - 1])
- r = h_last / h
+
+ h_min = min(h_last, h)
+ h_max = max(h_last, h)
+ r = h_max / h_min
+
+ h_d = (h_max + h_min) / 2
denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
- x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
+ x = (t_min / t_max) * x - (-h_d).expm1() * denoised_d
+
old_denoised = denoised
return x
diff --git a/k_diffusion/utils.py b/k_diffusion/utils.py
index 9afedb9..ce6014b 100644
--- a/k_diffusion/utils.py
+++ b/k_diffusion/utils.py
@@ -42,7 +42,10 @@ def append_dims(x, target_dims):
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
- return x[(...,) + (None,) * dims_to_append]
+ expanded = x[(...,) + (None,) * dims_to_append]
+ # MPS will get inf values if it tries to index into the new axes, but detaching fixes this.
+ # https://github.com/pytorch/pytorch/issues/84364
+ return expanded.detach().clone() if expanded.device.type == 'mps' else expanded
def n_params(module):