File size: 1,365 Bytes
2f4febc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
class EpsilonTarget:
    """Prediction target in which the model outputs the noise ``epsilon``.

    All conversions are built around the forward relation implemented by
    ``xt``: ``noised = x0 * a + noise * b``.  Inputs may be scalars or
    (presumably) broadcastable arrays/tensors -- TODO confirm with callers.
    ``logSNR`` is part of the shared signature but is not used here.
    """

    def __call__(self, x0, epsilon, logSNR, a, b):
        """Return the regression target for a training sample (the noise)."""
        target = epsilon
        return target

    def x0(self, noised, pred, logSNR, a, b):
        """Recover the clean sample from a noise prediction.

        Solves ``noised = a * x0 + b * pred`` for ``x0``; divides by ``a``.
        """
        residual = noised - pred * b
        return residual / a

    def epsilon(self, noised, pred, logSNR, a, b):
        """Return the noise estimate, which is the prediction itself."""
        return pred

    def noise_givenx0_noised(self, x0, noised, logSNR, a, b):
        """Solve the forward relation for the noise given ``x0`` and ``noised``.

        Divides by ``b``.
        """
        residual = noised - a * x0
        return residual / b

    def xt(self, x0, noise, logSNR, a, b):
        """Apply the forward process: ``x0 * a + noise * b``."""
        signal_part = x0 * a
        noise_part = noise * b
        return signal_part + noise_part

class X0Target:
    """Prediction target in which the model outputs the clean sample ``x0``.

    The conversions invert ``noised = a * x0 + b * epsilon``.  ``logSNR`` is
    accepted for signature compatibility with the other targets and ignored.
    """

    def __call__(self, x0, epsilon, logSNR, a, b):
        """Return the regression target for a training sample (the clean sample)."""
        return x0

    def x0(self, noised, pred, logSNR, a, b):
        """Return the clean-sample estimate, which is the prediction itself."""
        return pred

    def epsilon(self, noised, pred, logSNR, a, b):
        """Return the noise implied by a clean-sample prediction.

        Solves ``noised = a * pred + b * epsilon`` for ``epsilon``; divides by ``b``.
        """
        noise_component = noised - pred * a
        return noise_component / b

class VTarget:
    """Prediction target in which the model outputs ``v = a * epsilon - b * x0``.

    Given ``noised = a * x0 + b * epsilon``, both inverses are exact for any
    (a, b): ``a * noised - b * v = (a**2 + b**2) * x0`` and
    ``b * noised + a * v = (a**2 + b**2) * epsilon``.  ``logSNR`` is unused.
    """

    def __call__(self, x0, epsilon, logSNR, a, b):
        """Return the velocity regression target for a training sample."""
        velocity = a * epsilon - b * x0
        return velocity

    def x0(self, noised, pred, logSNR, a, b):
        """Recover the clean sample from a velocity prediction."""
        norm = a**2 + b**2
        coef_noised = a / norm
        coef_pred = b / norm
        return coef_noised * noised - coef_pred * pred

    def epsilon(self, noised, pred, logSNR, a, b):
        """Recover the noise from a velocity prediction."""
        norm = a**2 + b**2
        coef_noised = b / norm
        coef_pred = a / norm
        return coef_noised * noised + coef_pred * pred

class RectifiedFlowsTarget():
    """Prediction target in which the model outputs the velocity ``epsilon - x0``.

    The conversions invert ``noised = a * x0 + b * epsilon``.  With
    ``v = epsilon - x0``:

        noised - b * v = (a + b) * x0
        noised + a * v = (a + b) * epsilon

    Dividing by ``a + b`` keeps both inverses exact for any (a, b).  When
    ``a + b == 1`` (e.g. ``a = 1 - t, b = t``) this reduces to
    ``noised - pred * b`` and ``noised + pred * a``.  ``a + b`` must be
    non-zero.  ``logSNR`` is accepted for signature compatibility and unused.
    """

    def __call__(self, x0, epsilon, logSNR, a, b):
        """Return the velocity regression target for a training sample."""
        return epsilon - x0

    def x0(self, noised, pred, logSNR, a, b):
        """Recover the clean sample from a velocity prediction."""
        # Normalise by (a + b) instead of assuming it equals 1, mirroring how
        # VTarget normalises by a**2 + b**2.
        return (noised - pred * b) / (a + b)

    def epsilon(self, noised, pred, logSNR, a, b):
        """Recover the noise from a velocity prediction."""
        return (noised + pred * a) / (a + b)