import torch
import numpy as np

def get_alpha(alphas_cumprod, timestep):
    # Look up alpha_bar_t for each timestep; negative timesteps (i.e. "before"
    # the diffusion process starts) map to alpha_bar = 1.
    timestep_lt_zero_mask = torch.lt(timestep, 0).to(alphas_cumprod.dtype)
    normal_alpha = alphas_cumprod[torch.clip(timestep, min=0)]
    one_alpha = torch.ones_like(normal_alpha)
    return normal_alpha * (1 - timestep_lt_zero_mask) + one_alpha * timestep_lt_zero_mask

def psuedo_velocity_wrt_noisy_and_timestep(noisy_images, noisy_images_pre, alphas_cumprod, timestep, timestep_prev):
    # With a_t = sqrt(alpha_bar_t) and b_t = sqrt(1 - alpha_bar_t), v-prediction gives
    #   x_0 = a_t * x_t - b_t * v_t   and   eps = b_t * x_t + a_t * v_t,
    # so the previous noisy sample x_s = a_s * x_0 + b_s * eps can be rewritten as
    #   x_s = (a_s * a_t + b_s * b_t) * x_t + (b_s * a_t - a_s * b_t) * v_t.
    # Solving this for v_t gives the pseudo-velocity computed below.
    alpha_prod_t = get_alpha(alphas_cumprod, timestep).view(-1, 1, 1, 1, 1).detach()
    beta_prod_t = 1 - alpha_prod_t
    alpha_prod_t_prev = get_alpha(alphas_cumprod, timestep_prev).view(-1, 1, 1, 1, 1).detach()
    beta_prod_t_prev = 1 - alpha_prod_t_prev

    a_s = (alpha_prod_t_prev ** 0.5).to(noisy_images.dtype)
    a_t = (alpha_prod_t ** 0.5).to(noisy_images.dtype)
    b_s = (beta_prod_t_prev ** 0.5).to(noisy_images.dtype)
    b_t = (beta_prod_t ** 0.5).to(noisy_images.dtype)

    psuedo_velocity = (noisy_images_pre - (
        a_s * a_t + b_s * b_t
    ) * noisy_images) / (
        b_s * a_t - a_s * b_t
    )

    return psuedo_velocity

def origin_by_velocity_and_sample(velocity, noisy_images, alphas_cumprod, timestep):
    # Standard v-prediction reconstruction of the clean sample: x_0 = a_t * x_t - b_t * v_t.
    alpha_prod_t = get_alpha(alphas_cumprod, timestep).view(-1, 1, 1, 1, 1).detach()
    beta_prod_t = 1 - alpha_prod_t
    a_t = (alpha_prod_t ** 0.5).to(noisy_images.dtype)
    b_t = (beta_prod_t ** 0.5).to(noisy_images.dtype)

    pred_original_sample = a_t * noisy_images - b_t * velocity
    return pred_original_sample