NoCrypt commited on
Commit
1a8db9e
·
1 Parent(s): b1bcf10

Upload 0001-karras-v2-experimental.patch

Browse files
patch/0001-karras-v2-experimental.patch ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ From 36078b25801787f0a0f145143637f46d33d8c389 Mon Sep 17 00:00:00 2001
2
+ From: Ashen <git123@gmail.com>
3
+ Date: Fri, 7 Apr 2023 22:04:35 -0700
4
+ Subject: [PATCH] karras v2 experimental
5
+
6
+ ---
7
+ k_diffusion/sampling.py | 36 ++++++++++++++++++++++++++++++++++++
8
+ 1 file changed, 36 insertions(+)
9
+
10
+ diff --git a/k_diffusion/sampling.py b/k_diffusion/sampling.py
11
+ index f050f88..4d5df2a 100644
12
+ --- a/k_diffusion/sampling.py
13
+ +++ b/k_diffusion/sampling.py
14
+ @@ -605,3 +605,39 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
15
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
16
+ old_denoised = denoised
17
+ return x
18
+ +
19
+ +
20
+ +@torch.no_grad()
21
+ +def sample_dpmpp_2m_test(model, x, sigmas, extra_args=None, callback=None, disable=None):
22
+ + """DPM-Solver++(2M)."""
23
+ + extra_args = {} if extra_args is None else extra_args
24
+ + s_in = x.new_ones([x.shape[0]])
25
+ + sigma_fn = lambda t: t.neg().exp()
26
+ + t_fn = lambda sigma: sigma.log().neg()
27
+ + old_denoised = None
28
+ +
29
+ + for i in trange(len(sigmas) - 1, disable=disable):
30
+ + denoised = model(x, sigmas[i] * s_in, **extra_args)
31
+ + if callback is not None:
32
+ + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
33
+ + t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
34
+ + h = t_next - t
35
+ +
36
+ + t_min = min(sigma_fn(t_next), sigma_fn(t))
37
+ + t_max = max(sigma_fn(t_next), sigma_fn(t))
38
+ +
39
+ + if old_denoised is None or sigmas[i + 1] == 0:
40
+ + x = (t_min / t_max) * x - (-h).expm1() * denoised
41
+ + else:
42
+ + h_last = t - t_fn(sigmas[i - 1])
43
+ +
44
+ + h_min = min(h_last, h)
45
+ + h_max = max(h_last, h)
46
+ + r = h_max / h_min
47
+ +
48
+ + h_d = (h_max + h_min) / 2
49
+ + denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
50
+ + x = (t_min / t_max) * x - (-h_d).expm1() * denoised_d
51
+ +
52
+ + old_denoised = denoised
53
+ + return x
54
+
55
+ --
56
+ 2.40.0
57
+