Spaces:
Running
Running
import torch | |
from utils.manifolds import Sphere, geodesic | |
from torch.func import vjp, jvp, vmap, jacrev | |
class DDPMLoss:
    """Noise-prediction loss for a variance-preserving diffusion model.

    For every sample a time ``t ~ U[0, 1)`` is drawn and mapped through
    ``scheduler`` to a signal level ``gamma``. The clean data is corrupted as
    ``y = sqrt(gamma) * x_0 + sqrt(1 - gamma) * n`` with ``n ~ N(0, I)``, and the
    network (called through ``preconditioning``) is trained to predict ``n``.
    """

    def __init__(
        self,
        scheduler,
        cond_drop_rate=0.0,
        conditioning_key="label",
    ):
        """
        Args:
            scheduler: Callable mapping a 1-D tensor of times (one per sample)
                to per-sample ``gamma`` values. ``sqrt(gamma)`` and
                ``sqrt(1 - gamma)`` are taken, so values are presumably in [0, 1].
            cond_drop_rate: Probability of replacing a sample's conditioning
                with zeros.
            conditioning_key: Key of the conditioning entry in ``batch``.
        """
        self.scheduler = scheduler
        self.cond_drop_rate = cond_drop_rate
        self.conditioning_key = conditioning_key

    def __call__(self, preconditioning, network, batch, generator=None):
        """Return the element-wise squared error between predicted and true noise.

        Args:
            preconditioning: Callable ``(network, batch) -> prediction``; its
                output is compared against the sampled noise ``n``.
            network: Passed through unchanged to ``preconditioning``.
            batch: Dict with ``"x_0"`` (first dim = batch) and the conditioning
                key (value may be ``None``). It is updated in place with
                ``"y"``, ``"gamma"`` (shape ``(batch,)``) and, when dropout is
                active, a new conditioning tensor.
            generator: Optional ``torch.Generator`` for all random draws.

        Returns:
            Unreduced ``(prediction - n) ** 2``.
        """
        x_0 = batch["x_0"]
        batch_size = x_0.shape[0]
        device = x_0.device

        # Draw order (t, then n, then the dropout mask) is fixed so that a
        # seeded generator reproduces the same samples.
        t = torch.rand(batch_size, device=device, dtype=x_0.dtype, generator=generator)
        gamma_flat = self.scheduler(t).reshape(batch_size)
        # One gamma per sample, broadcast over every non-batch dimension of x_0.
        gamma = gamma_flat.reshape(batch_size, *([1] * (x_0.dim() - 1)))
        n = torch.randn(x_0.shape, dtype=x_0.dtype, device=device, generator=generator)
        batch["y"] = torch.sqrt(gamma) * x_0 + torch.sqrt(1 - gamma) * n

        self._drop_conditioning(batch, batch_size, device, generator)
        # reshape (rather than repeated squeeze) keeps the batch dim when
        # batch_size == 1.
        batch["gamma"] = gamma_flat
        D_n = preconditioning(network, batch)
        return (D_n - n) ** 2

    def _drop_conditioning(self, batch, batch_size, device, generator):
        """Zero the conditioning of a random subset of samples without mutating the input tensor."""
        conditioning = batch[self.conditioning_key]
        if conditioning is not None and self.cond_drop_rate > 0:
            drop_mask = (
                torch.rand(batch_size, device=device, generator=generator)
                < self.cond_drop_rate
            )
            # Broadcast the per-sample mask over the conditioning's trailing dims.
            drop_mask = drop_mask.reshape(batch_size, *([1] * (conditioning.dim() - 1)))
            # Out-of-place so the caller's tensor is left untouched.
            conditioning = torch.where(
                drop_mask, torch.zeros_like(conditioning), conditioning
            )
            batch[self.conditioning_key] = conditioning.detach()
class FlowMatchingLoss:
    """Conditional flow-matching loss along a straight line from noise to data.

    For every sample a time ``t ~ U[0, 1)`` is drawn and mapped through
    ``scheduler`` to ``gamma``. The interpolant is
    ``y = gamma * x_0 + (1 - gamma) * n`` with ``n ~ N(0, I)``, and the
    regression target is its derivative with respect to ``gamma``,
    ``x_0 - n``.
    """

    def __init__(
        self,
        scheduler,
        cond_drop_rate=0.0,
        conditioning_key="label",
    ):
        """
        Args:
            scheduler: Callable mapping a 1-D tensor of times (one per sample)
                to per-sample interpolation weights ``gamma``.
            cond_drop_rate: Probability of replacing a sample's conditioning
                with zeros.
            conditioning_key: Key of the conditioning entry in ``batch``.
        """
        self.scheduler = scheduler
        self.cond_drop_rate = cond_drop_rate
        self.conditioning_key = conditioning_key

    def __call__(self, preconditioning, network, batch, generator=None):
        """Return the element-wise squared error against the velocity ``x_0 - n``.

        Args:
            preconditioning: Callable ``(network, batch) -> prediction``.
            network: Passed through unchanged to ``preconditioning``.
            batch: Dict with ``"x_0"`` (first dim = batch) and the conditioning
                key (value may be ``None``). It is updated in place with
                ``"y"``, ``"gamma"`` (shape ``(batch,)``) and, when dropout is
                active, a new conditioning tensor.
            generator: Optional ``torch.Generator`` for all random draws.

        Returns:
            Unreduced ``(prediction - (x_0 - n)) ** 2``.
        """
        x_0 = batch["x_0"]
        batch_size = x_0.shape[0]
        device = x_0.device

        # Draw order (t, then n, then the dropout mask) is fixed so that a
        # seeded generator reproduces the same samples.
        t = torch.rand(batch_size, device=device, dtype=x_0.dtype, generator=generator)
        gamma_flat = self.scheduler(t).reshape(batch_size)
        # One gamma per sample, broadcast over every non-batch dimension of x_0.
        gamma = gamma_flat.reshape(batch_size, *([1] * (x_0.dim() - 1)))
        n = torch.randn(x_0.shape, dtype=x_0.dtype, device=device, generator=generator)
        batch["y"] = gamma * x_0 + (1 - gamma) * n

        self._drop_conditioning(batch, batch_size, device, generator)
        # reshape (rather than repeated squeeze) keeps the batch dim when
        # batch_size == 1.
        batch["gamma"] = gamma_flat
        D_n = preconditioning(network, batch)
        return (D_n - (x_0 - n)) ** 2

    def _drop_conditioning(self, batch, batch_size, device, generator):
        """Zero the conditioning of a random subset of samples without mutating the input tensor."""
        conditioning = batch[self.conditioning_key]
        if conditioning is not None and self.cond_drop_rate > 0:
            drop_mask = (
                torch.rand(batch_size, device=device, generator=generator)
                < self.cond_drop_rate
            )
            # Broadcast the per-sample mask over the conditioning's trailing dims.
            drop_mask = drop_mask.reshape(batch_size, *([1] * (conditioning.dim() - 1)))
            # Out-of-place so the caller's tensor is left untouched.
            conditioning = torch.where(
                drop_mask, torch.zeros_like(conditioning), conditioning
            )
            batch[self.conditioning_key] = conditioning.detach()
class RiemannianFlowMatchingLoss:
    """Flow-matching loss along geodesics of a ``Sphere`` manifold.

    Base points are drawn with ``Sphere.random_base``. For each sample the point
    ``y`` on ``geodesic(manifold, base, data)`` at time ``gamma`` and its
    derivative ``u_t`` with respect to that time argument are obtained with
    ``torch.func.jvp``. The prediction is regressed onto ``u_t`` using the
    manifold's inner product.
    """

    def __init__(
        self,
        scheduler,
        cond_drop_rate=0.0,
        conditioning_key="label",
    ):
        """
        Args:
            scheduler: Callable mapping a 1-D tensor of times (one per sample)
                to per-sample path times ``gamma``.
            cond_drop_rate: Probability of replacing a sample's conditioning
                with zeros.
            conditioning_key: Key of the conditioning entry in ``batch``.
        """
        self.scheduler = scheduler
        self.cond_drop_rate = cond_drop_rate
        self.conditioning_key = conditioning_key
        self.manifold = Sphere()
        # Size of the last axis of the points handled by this loss.
        self.manifold_dim = 3

    def __call__(self, preconditioning, network, batch, generator=None):
        """Return the scalar geodesic flow-matching loss.

        Args:
            preconditioning: Callable ``(network, batch) -> prediction``; the
                prediction is compared against the geodesic velocity ``u_t``.
            network: Passed through unchanged to ``preconditioning``.
            batch: Dict with ``"x_0"`` (data points, first dim = batch) and the
                conditioning key (value may be ``None``). It is updated in
                place with ``"y"``, ``"gamma"`` (shape ``(batch,)``) and, when
                dropout is active, a new conditioning tensor.
            generator: Optional ``torch.Generator`` for the time and dropout
                draws.

        Returns:
            ``mean(inner(y, diff, diff)) / manifold_dim`` with
            ``diff = prediction - u_t``.
        """
        x_1 = batch["x_0"]
        batch_size = x_1.shape[0]
        device = x_1.device

        t = torch.rand(batch_size, device=device, dtype=x_1.dtype, generator=generator)
        # (batch, 1): vmap below hands each sample a length-1 time tensor.
        gamma = self.scheduler(t).unsqueeze(-1)
        # NOTE(review): random_base does not receive ``generator``, so base
        # samples are not reproducible from it.
        x_0 = self.manifold.random_base(batch_size, self.manifold_dim).to(x_1)

        def cond_u(x0, x1, t):
            path = geodesic(self.manifold, x0, x1)
            # Unit tangent in t: returns the path point and its t-derivative.
            x_t, u_t = jvp(path, (t,), (torch.ones_like(t),))
            return x_t, u_t

        y, u_t = vmap(cond_u)(x_0, x_1, gamma)
        y = y.reshape(batch_size, self.manifold_dim)
        u_t = u_t.reshape(batch_size, self.manifold_dim)
        batch["y"] = y

        self._drop_conditioning(batch, batch_size, device, generator)
        # reshape (rather than repeated squeeze) keeps the batch dim when
        # batch_size == 1.
        batch["gamma"] = gamma.reshape(batch_size)
        D_n = preconditioning(network, batch)
        diff = D_n - u_t
        return self.manifold.inner(y, diff, diff).mean() / self.manifold_dim

    def _drop_conditioning(self, batch, batch_size, device, generator):
        """Zero the conditioning of a random subset of samples without mutating the input tensor."""
        conditioning = batch[self.conditioning_key]
        if conditioning is not None and self.cond_drop_rate > 0:
            drop_mask = (
                torch.rand(batch_size, device=device, generator=generator)
                < self.cond_drop_rate
            )
            # Broadcast the per-sample mask over the conditioning's trailing dims.
            drop_mask = drop_mask.reshape(batch_size, *([1] * (conditioning.dim() - 1)))
            # Out-of-place so the caller's tensor is left untouched.
            conditioning = torch.where(
                drop_mask, torch.zeros_like(conditioning), conditioning
            )
            batch[self.conditioning_key] = conditioning.detach()
class VonFisherLoss:
    """Negative log-likelihood of a single von Mises-Fisher distribution.

    The log-density used is
    ``log(kappa) - log(4*pi) - log(sinh(kappa)) + kappa * <mu, x>``, i.e. the
    vMF density for unit vectors in R^3. A ``1e-8`` offset guards ``log(kappa)``.
    """

    def __init__(self, dim=3):
        # Stored on the instance; not read by ``__call__``.
        self.dim = dim

    def __call__(self, preconditioning, network, batch, generator=None):
        """Return ``-log p(x_0 | mu, kappa)``, reducing only the last axis of ``mu * x_0``.

        ``preconditioning(network, batch)`` must return ``(mu, kappa)``;
        ``kappa`` presumably carries a trailing singleton axis so it broadcasts
        against the ``keepdim`` dot product -- TODO confirm with the caller.
        ``generator`` is accepted for interface parity and unused.
        """
        target = batch["x_0"]
        mean_dir, concentration = preconditioning(network, batch)
        log_normalizer = (
            torch.log((concentration + 1e-8))
            - torch.log(torch.tensor(4 * torch.pi, dtype=concentration.dtype))
            - log_sinh(concentration)
        )
        alignment = (mean_dir * target).sum(dim=-1, keepdim=True)
        log_likelihood = log_normalizer + concentration * alignment
        return -log_likelihood
class VonFisherMixtureLoss:
    """Negative log-likelihood of a mixture of von Mises-Fisher distributions.

    Each component density is written as
    ``kappa * exp(kappa * (<mu, x> - 1)) / (2*pi*(1 - exp(-2*kappa)))``, which
    equals the vMF density ``kappa / (4*pi*sinh(kappa)) * exp(kappa * <mu, x>)``
    for unit vectors in R^3 (a ``1e-8`` offset guards the denominator).

    The weighted sum over components is evaluated in log space with
    ``logsumexp``. Summing the densities directly underflows to 0 for large
    ``kappa`` when ``x`` is far from every mode, which gives an infinite loss.
    """

    def __init__(self, dim=3):
        # Stored on the instance; not read by ``__call__``.
        self.dim = dim

    def __call__(self, preconditioning, network, batch, generator=None):
        """Return ``-log sum_i w_i * vMF(x_0 | mu_i, kappa_i)`` per sample.

        ``preconditioning(network, batch)`` must return
        ``(mu_mixture, kappa_mixture, weights)`` with the component index on
        axis 1. Presumably the shapes are ``(B, K, 3)``, ``(B, K)`` and
        ``(B, K)`` -- TODO confirm with the caller; the result then has shape
        ``(B, 1)``. ``generator`` is accepted for interface parity and unused.

        Weights and concentrations are clamped to the dtype's smallest normal
        value inside ``log``. An exact zero therefore still contributes
        (essentially) nothing, but yields a finite log instead of NaN gradients.
        """
        x = batch["x_0"]
        mu_mixture, kappa_mixture, weights = preconditioning(network, batch)
        tiny = torch.finfo(kappa_mixture.dtype).tiny

        log_terms = []
        for i in range(mu_mixture.shape[1]):
            mu = mu_mixture[:, i]
            kappa = kappa_mixture[:, i].unsqueeze(1)
            log_weight = torch.log(weights[:, i].unsqueeze(1).clamp_min(tiny))
            # log(kappa / (1e-8 + 2*pi*(1 - exp(-2*kappa)))); expm1 keeps
            # precision for small kappa.
            log_norm = torch.log(kappa.clamp_min(tiny)) - torch.log(
                1e-8 - 2 * torch.pi * torch.expm1(-2 * kappa)
            )
            log_kernel = kappa * ((mu * x).sum(dim=-1, keepdim=True) - 1)
            log_terms.append(log_weight + log_norm + log_kernel)

        return -torch.logsumexp(torch.stack(log_terms), dim=0)
def log_sinh(x):
    """Return ``log(sinh(x))`` without overflowing for large ``x``.

    Uses ``sinh(x) = exp(x) * (1 - exp(-2x)) / 2``, so ``exp(x)`` is never
    formed. ``1 - exp(-2x)`` is computed as ``-expm1(-2x)``, which keeps full
    precision for small ``x``, where the direct subtraction cancels. The
    ``1e-8`` floor keeps the result finite at ``x == 0``. Intended for
    non-negative ``x`` (e.g. concentrations).
    """
    return x + torch.log(1e-8 - torch.expm1(-2 * x) / 2)