# Spaces: Running on Zero
import logging
from dataclasses import dataclass
from functools import partial
from typing import Protocol

import matplotlib.pyplot as plt
import numpy as np
import scipy
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from tqdm import trange

from .wn import WN
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
class VelocityField(Protocol):
    """Structural type of the callable that ``Solver`` integrates.

    Implementations are invoked with keyword-only arguments ``t``, ``ψt`` and
    ``dt`` and return a tensor that the solver scales by ``dt`` and adds to ``ψt``.
    """

    def __call__(self, *, t: Tensor, ψt: Tensor, dt: Tensor) -> Tensor:
        """Evaluate the field at time ``t`` and state ``ψt``; ``dt`` is the current step size."""
        ...
class Solver:
    """Fixed-step ODE solver that integrates a ``VelocityField`` from ``t0`` to ``t1``.

    The step budget is given as a number of function evaluations (NFE); each
    method consumes a fixed number of evaluations per step (``_NFE_PER_STEP``).
    The uniform time grid is warped by ``exponential_decay_mapping`` before
    integration. When a ``TrainLoop`` is running, intermediate states can be
    recorded into a GIF with ``celluloid`` (optional dependency).
    """

    # Velocity-field evaluations consumed by one step of each method.
    _NFE_PER_STEP = {"euler": 1, "midpoint": 2, "rk4": 4}

    def __init__(
        self,
        method="midpoint",
        nfe=32,
        viz_name="solver",
        viz_every=100,
        mel_fn=None,
        time_mapping_divisor=4,
        verbose=False,
    ):
        """
        Args:
            method: integration scheme, one of "euler", "midpoint", "rk4".
            nfe: budget of velocity-field evaluations per solve.
            viz_name: name used to build the visualisation GIF path.
            viz_every: visualise when the train loop's global step is a multiple of this.
            mel_fn: optional callable applied to a 1-D numpy waveform; its result is
                drawn with ``imshow`` when visualising single-channel states.
            time_mapping_divisor: ``n`` bound into ``exponential_decay_mapping``.
            verbose: show a tqdm progress bar while solving.
        """
        self.configurate_(nfe=nfe, method=method)
        self.verbose = verbose
        self.viz_every = viz_every
        self.viz_name = viz_name
        self._camera = None
        self._mel_fn = mel_fn
        self._time_mapping = partial(self.exponential_decay_mapping, n=time_mapping_divisor)

    def configurate_(self, nfe=None, method=None):
        """Update ``nfe`` and/or ``method`` in place; ``None`` keeps the current value.

        If ``nfe`` is smaller than the cost of a single step of ``method``, a cheaper
        method is used instead (with a warning) so that at least one step is taken.
        """
        if nfe is None:
            nfe = self.nfe
        if method is None:
            method = self.method
        per_step = self._NFE_PER_STEP.get(method)
        if per_step is not None and nfe < per_step and method != "euler":
            # e.g. rk4 with nfe=2 would otherwise give nfe // 4 == 0 steps and
            # silently return the initial state.
            fallback = "midpoint" if nfe >= self._NFE_PER_STEP["midpoint"] else "euler"
            logger.warning(f"{nfe} NFE is not supported for {method}, using {fallback} method instead.")
            method = fallback
        self.nfe = nfe
        self.method = method

    @property
    def time_mapping(self):
        """The time-warping callable: ``exponential_decay_mapping`` with ``n`` bound."""
        return self._time_mapping

    @staticmethod
    def exponential_decay_mapping(t, n=4):
        """Warp times with h(t) = (a**t - 1) / (a - 1), so h(0) = 0 and h(1) = 1.

        ``a`` is solved numerically such that h(1/n) = 0.5, i.e. the first 1/n of a
        uniform schedule covers half of the mapped interval. For n = 2 the solution
        is a = 1, where h is 0/0, so n != 2 is required.

        Args:
            t: times to map (float, numpy array, or tensor), applied elementwise.
            n: target step; t = 1/n is mapped to 0.5.
        Returns:
            The mapped times.
        """
        # Imported explicitly so this does not rely on scipy's lazy submodule loading.
        from scipy.optimize import fsolve

        def h(t, a):
            return (a**t - 1) / (a - 1)

        # fsolve returns a length-1 array; index it rather than calling float() on
        # the array, which NumPy deprecates for arrays with ndim > 0.
        a = float(fsolve(lambda a: h(1 / n, a) - 0.5, x0=0)[0])
        return h(t, a=a)

    def _maybe_camera_snap(self, *, ψt, t):
        """If a camera is active, draw ``ψt[0]`` on the current figure and snap a frame.

        Layout depends on the channel dimension ``ψt.shape[1]``: 1 -> waveform
        (every 100th sample) plus an optional ``mel_fn`` image; 2 -> the two
        channels side by side; otherwise -> ``ψt[0]`` shown as one image.
        """
        camera = self._camera
        if camera is None:
            return
        frame = ψt.detach().cpu().numpy()
        if frame.shape[1] == 1:
            # Waveform, b 1 t, plot every 100 samples
            plt.subplot(211)
            plt.plot(frame[0, 0, ::100], color="blue")
            if self._mel_fn is not None:
                plt.subplot(212)
                mel = self._mel_fn(frame[0, 0])
                plt.imshow(mel, origin="lower", interpolation="none")
        elif frame.shape[1] == 2:
            # Complex (presumably real/imaginary channels)
            plt.subplot(121)
            plt.imshow(frame[0, 0], origin="lower", interpolation="none")
            plt.subplot(122)
            plt.imshow(frame[0, 1], origin="lower", interpolation="none")
        else:
            # Spectrogram, b c t
            plt.imshow(frame[0], origin="lower", interpolation="none")
        ax = plt.gca()
        ax.text(0.5, 1.01, f"t={t:.2f}", transform=ax.transAxes, ha="center")
        camera.snap()

    @staticmethod
    def _euler_step(t, ψt, dt, f: VelocityField):
        """One explicit Euler step (1 NFE)."""
        return ψt + dt * f(t=t, ψt=ψt, dt=dt)

    @staticmethod
    def _midpoint_step(t, ψt, dt, f: VelocityField):
        """One explicit midpoint step (2 NFE)."""
        return ψt + dt * f(t=t + dt / 2, ψt=ψt + dt * f(t=t, ψt=ψt, dt=dt) / 2, dt=dt)

    @staticmethod
    def _rk4_step(t, ψt, dt, f: VelocityField):
        """One classical Runge-Kutta 4 step (4 NFE)."""
        k1 = f(t=t, ψt=ψt, dt=dt)
        k2 = f(t=t + dt / 2, ψt=ψt + dt * k1 / 2, dt=dt)
        k3 = f(t=t + dt / 2, ψt=ψt + dt * k2 / 2, dt=dt)
        k4 = f(t=t + dt, ψt=ψt + dt * k3, dt=dt)
        return ψt + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6

    @property
    def _step(self):
        """Step function for ``self.method``; raises ValueError for unknown methods."""
        steps = {
            "euler": self._euler_step,
            "midpoint": self._midpoint_step,
            "rk4": self._rk4_step,
        }
        try:
            return steps[self.method]
        except KeyError:
            raise ValueError(f"Unknown method: {self.method}") from None

    def get_running_train_loop(self):
        """Return the running ``TrainLoop`` if the project's train-loop module imports, else None."""
        try:
            # Lazy import
            from ...utils.train_loop import TrainLoop

            return TrainLoop.get_running_loop()
        except ImportError:
            return None

    @property
    def visualizing(self):
        """Whether this solve should be recorded (needs a running train loop)."""
        loop = self.get_running_train_loop()
        if loop is None:
            return False
        out_path = loop.make_current_step_viz_path(self.viz_name, ".gif")
        return loop.global_step % self.viz_every == 0 and not out_path.exists()

    def _reset_camera(self):
        """Start a ``celluloid`` recording on a new figure; best effort.

        Visualisation is optional, so any failure (e.g. celluloid missing) is logged
        at debug level and the solve continues without snapshots.
        """
        try:
            from celluloid import Camera

            camera = Camera(plt.figure())
        except Exception:
            logger.debug("Solver visualization disabled.", exc_info=True)
            return
        self._camera = camera

    def _maybe_dump_camera(self):
        """Save the recorded frames as a GIF (if a train loop is running) and release the figure."""
        camera = self._camera
        if camera is None:
            return
        loop = self.get_running_train_loop()
        if loop is not None:
            animation = camera.animate()
            out_path = loop.make_current_step_viz_path(self.viz_name, ".gif")
            out_path.parent.mkdir(exist_ok=True, parents=True)
            animation.save(out_path, writer="pillow", fps=4)
        # Always close, so a camera without a loop does not leak its figure.
        plt.close()
        self._camera = None

    @property
    def n_steps(self):
        """Number of integration steps affordable with ``self.nfe`` evaluations."""
        try:
            per_step = self._NFE_PER_STEP[self.method]
        except KeyError:
            raise ValueError(f"Unknown method: {self.method}") from None
        return self.nfe // per_step

    def solve(self, f: VelocityField, ψ0: Tensor, t0=0.0, t1=1.0):
        """Integrate ``f`` from ``ψ0`` at ``t0`` to ``t1`` and return the final state.

        The grid is ``n_steps + 1`` uniform points in [t0, t1] passed through the
        time mapping; each interval is advanced with the configured step method.
        """
        n_steps = self.n_steps
        ts = self._time_mapping(np.linspace(t0, t1, n_steps + 1))

        if self.visualizing:
            self._reset_camera()

        steps = trange(n_steps, desc="CFM inference") if self.verbose else range(n_steps)
        step = self._step

        ψt = ψ0
        for i in steps:
            t = ts[i]
            dt = ts[i + 1] - t
            self._maybe_camera_snap(ψt=ψt, t=t)
            ψt = step(t=t, ψt=ψt, dt=dt, f=f)
        self._maybe_camera_snap(ψt=ψt, t=ts[-1])

        self._maybe_dump_camera()
        return ψt

    def __call__(self, f: VelocityField, ψ0: Tensor, t0=0.0, t1=1.0):
        """Alias for ``solve``."""
        return self.solve(f=f, ψ0=ψ0, t0=t0, t1=t1)
class SinusodialTimeEmbedding(nn.Module):
    """Embed times with sines and cosines at log-spaced frequencies.

    The ``d_embed // 2`` frequencies are 10**p for exponents p evenly spaced in
    [0, 4]; the output is the sine half followed by the cosine half.
    """

    def __init__(self, d_embed):
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if d_embed % 2 != 0:
            raise ValueError(f"d_embed must be even, got {d_embed}")
        self.d_embed = d_embed

    def forward(self, t):
        """
        Args:
            t: tensor of times, any shape (...)
        Returns:
            embedding of shape (..., d_embed)
        """
        t = t.unsqueeze(-1)  # ... 1
        # Shape (d/2,) broadcasts against (..., 1); no manual unsqueeze needed.
        freqs = 10 ** torch.linspace(0, 4, self.d_embed // 2).to(t)
        angles = t * freqs  # ... d/2
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
class CFM(nn.Module): | |
""" | |
This mixin is for general diffusion models. | |
ψ0 stands for the gaussian noise, and ψ1 is the data point. | |
Here we follow the CFM style: | |
The generation process (reverse process) is from t=0 to t=1. | |
The forward process is from t=1 to t=0. | |
""" | |
cond_dim: int | |
output_dim: int | |
time_emb_dim: int = 128 | |
viz_name: str = "cfm" | |
solver_nfe: int = 32 | |
solver_method: str = "midpoint" | |
time_mapping_divisor: int = 4 | |
def __post_init__(self): | |
super().__init__() | |
self.solver = Solver( | |
viz_name=self.viz_name, | |
viz_every=1, | |
nfe=self.solver_nfe, | |
method=self.solver_method, | |
time_mapping_divisor=self.time_mapping_divisor, | |
) | |
self.emb = SinusodialTimeEmbedding(self.time_emb_dim) | |
self.net = WN( | |
input_dim=self.output_dim, | |
output_dim=self.output_dim, | |
local_dim=self.cond_dim, | |
global_dim=self.time_emb_dim, | |
) | |
def _perturb(self, ψ1: Tensor, t: Tensor | None = None): | |
""" | |
Perturb ψ1 to ψt. | |
""" | |
raise NotImplementedError | |
def _sample_ψ0(self, x: Tensor): | |
""" | |
Args: | |
x: (b c t), which implies the shape of ψ0 | |
""" | |
shape = list(x.shape) | |
shape[1] = self.output_dim | |
if self.training: | |
g = None | |
else: | |
g = torch.Generator(device=x.device) | |
g.manual_seed(0) # deterministic sampling during eval | |
ψ0 = torch.randn(shape, device=x.device, dtype=x.dtype, generator=g) | |
return ψ0 | |
def sigma(self): | |
return 1e-4 | |
def _to_ψt(self, *, ψ1: Tensor, ψ0: Tensor, t: Tensor): | |
""" | |
Eq (22) | |
""" | |
while t.dim() < ψ1.dim(): | |
t = t.unsqueeze(-1) | |
μ = t * ψ1 + (1 - t) * ψ0 | |
return μ + torch.randn_like(μ) * self.sigma | |
def _to_u(self, *, ψ1, ψ0: Tensor): | |
""" | |
Eq (21) | |
""" | |
return ψ1 - ψ0 | |
def _to_v(self, *, ψt, x, t: float | Tensor): | |
""" | |
Args: | |
ψt: (b c t) | |
x: (b c t) | |
t: (b) | |
Returns: | |
v: (b c t) | |
""" | |
if isinstance(t, (float, int)): | |
t = torch.full(ψt.shape[:1], t).to(ψt) | |
t = t.clamp(0, 1) # [0, 1) | |
g = self.emb(t) # (b d) | |
v = self.net(ψt, l=x, g=g) | |
return v | |
def compute_losses(self, x, y, ψ0) -> dict: | |
""" | |
Args: | |
x: (b c t) | |
y: (b c t) | |
Returns: | |
losses: dict | |
""" | |
t = torch.rand(len(x), device=x.device, dtype=x.dtype) | |
t = self.solver.time_mapping(t) | |
if ψ0 is None: | |
ψ0 = self._sample_ψ0(x) | |
ψt = self._to_ψt(ψ1=y, t=t, ψ0=ψ0) | |
v = self._to_v(ψt=ψt, t=t, x=x) | |
u = self._to_u(ψ1=y, ψ0=ψ0) | |
losses = dict(l1=F.l1_loss(v, u)) | |
return losses | |
def sample(self, x, ψ0=None, t0=0.0): | |
""" | |
Args: | |
x: (b c t) | |
Returns: | |
y: (b ... t) | |
""" | |
if ψ0 is None: | |
ψ0 = self._sample_ψ0(x) | |
f = lambda t, ψt, dt: self._to_v(ψt=ψt, t=t, x=x) | |
ψ1 = self.solver(f=f, ψ0=ψ0, t0=t0) | |
return ψ1 | |
def forward(self, x: Tensor, y: Tensor | None = None, ψ0: Tensor | None = None, t0=0.0): | |
if y is None: | |
y = self.sample(x, ψ0=ψ0, t0=t0) | |
else: | |
self.losses = self.compute_losses(x, y, ψ0=ψ0) | |
return y | |