from typing import Any
import pytorch_lightning as L
import torch
import torch.nn as nn
from hydra.utils import instantiate
import copy
import pandas as pd
import numpy as np
from tqdm import tqdm
from utils.manifolds import Sphere
from torch.func import jacrev, vjp, vmap
from torchdiffeq import odeint
from geoopt import ProductManifold, Euclidean
from models.samplers.riemannian_flow_sampler import ode_riemannian_flow_sampler


class DiffGeolocalizer(L.LightningModule):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.network = instantiate(cfg.network)
        # self.network = torch.compile(self.network, fullgraph=True)
        self.input_dim = cfg.network.input_dim
        self.train_noise_scheduler = instantiate(cfg.train_noise_scheduler)
        self.inference_noise_scheduler = instantiate(cfg.inference_noise_scheduler)
        self.data_preprocessing = instantiate(cfg.data_preprocessing)
        self.cond_preprocessing = instantiate(cfg.cond_preprocessing)
        self.preconditioning = instantiate(cfg.preconditioning)
        self.ema_network = copy.deepcopy(self.network).requires_grad_(False)
        self.ema_network.eval()
        self.postprocessing = instantiate(cfg.postprocessing)
        self.val_sampler = instantiate(cfg.val_sampler)
        self.test_sampler = instantiate(cfg.test_sampler)
        self.loss = instantiate(cfg.loss)(
            self.train_noise_scheduler,
        )
        self.val_metrics = instantiate(cfg.val_metrics)
        self.test_metrics = instantiate(cfg.test_metrics)
        self.manifold = instantiate(cfg.manifold) if hasattr(cfg, "manifold") else None
        self.interpolant = cfg.interpolant

    def training_step(self, batch, batch_idx):
        with torch.no_grad():
            batch = self.data_preprocessing(batch)
            batch = self.cond_preprocessing(batch)
        batch_size = batch["x_0"].shape[0]
        loss = self.loss(self.preconditioning, self.network, batch).mean()
        self.log(
            "train/loss",
            loss,
            sync_dist=True,
            on_step=True,
            on_epoch=True,
            batch_size=batch_size,
        )
        return loss

    def on_before_optimizer_step(self, optimizer):
        if self.global_step == 0:
            no_grad = []
            for name, param in self.network.named_parameters():
                if param.grad is None:
                    no_grad.append(name)
            if len(no_grad) > 0:
                print("Parameters without grad:")
                print(no_grad)

    def on_validation_start(self):
        self.validation_generator = torch.Generator(device=self.device).manual_seed(
            3407
        )
        self.validation_generator_ema = torch.Generator(device=self.device).manual_seed(
            3407
        )

    def validation_step(self, batch, batch_idx):
        batch = self.data_preprocessing(batch)
        batch = self.cond_preprocessing(batch)
        batch_size = batch["x_0"].shape[0]
        loss = self.loss(
            self.preconditioning,
            self.network,
            batch,
            generator=self.validation_generator,
        ).mean()
        self.log(
            "val/loss",
            loss,
            sync_dist=True,
            on_step=False,
            on_epoch=True,
            batch_size=batch_size,
        )
        if hasattr(self, "ema_model"):
            loss_ema = self.loss(
                self.preconditioning,
                self.ema_network,
                batch,
                generator=self.validation_generator_ema,
            ).mean()
            self.log(
                "val/loss_ema",
                loss_ema,
                sync_dist=True,
                on_step=False,
                on_epoch=True,
                batch_size=batch_size,
            )
        # nll = -self.compute_exact_loglikelihood(batch).mean()
        # self.log(
        #     "val/nll",
        #     nll,
        #     sync_dist=True,
        #     on_step=False,
        #     on_epoch=True,
        #     batch_size=batch_size,
        # )

    # def on_validation_epoch_end(self):
    #     metrics = self.val_metrics.compute()
    #     for metric_name, metric_value in metrics.items():
    #         self.log(
    #             f"val/{metric_name}",
    #             metric_value,
    #             sync_dist=True,
    #             on_step=False,
    #             on_epoch=True,
    #         )

    def on_test_start(self):
        self.test_generator = torch.Generator(device=self.device).manual_seed(3407)

    def test_step_simple(self, batch, batch_idx):
        batch = self.data_preprocessing(batch)
        batch = self.cond_preprocessing(batch)
        batch_size = batch["x_0"].shape[0]
        if isinstance(self.manifold, Sphere):
            x_N = self.manifold.random_base(
                batch_size,
                self.input_dim,
                device=self.device,
            )
            x_N = x_N.reshape(batch_size, self.input_dim)
        else:
            x_N = torch.randn(
                batch_size,
                self.input_dim,
                device=self.device,
                generator=self.test_generator,
            )
        cond = batch[self.cfg.cond_preprocessing.output_key]
        samples = self.sample(
            x_N=x_N,
            cond=cond,
            stage="val",
            generator=self.test_generator,
            cfg=self.cfg.cfg_rate,
        )
        self.test_metrics.update({"gps": samples}, batch)
        if self.cfg.compute_nll:
            nll = -self.compute_exact_loglikelihood(batch, cfg=0).mean()
            self.log(
                "test/NLL",
                nll,
                sync_dist=True,
                on_step=False,
                on_epoch=True,
                batch_size=batch_size,
            )

    def test_best_nll(self, batch, batch_idx):
        batch = self.data_preprocessing(batch)
        batch = self.cond_preprocessing(batch)
        batch_size = batch["x_0"].shape[0]
        num_sample_per_cond = 32
        if isinstance(self.manifold, Sphere):
            x_N = self.manifold.random_base(
                batch_size * num_sample_per_cond,
                self.input_dim,
                device=self.device,
            )
            x_N = x_N.reshape(batch_size * num_sample_per_cond, self.input_dim)
        else:
            x_N = torch.randn(
                batch_size * num_sample_per_cond,
                self.input_dim,
                device=self.device,
                generator=self.test_generator,
            )
        cond = (
            batch[self.cfg.cond_preprocessing.output_key]
            .unsqueeze(1)
            .repeat(1, num_sample_per_cond, 1)
            .view(-1, batch[self.cfg.cond_preprocessing.output_key].shape[-1])
        )
        samples = self.sample_distribution(
            x_N,
            cond,
            sampling_batch_size=32768,
            stage="val",
            generator=self.test_generator,
            cfg=0,
        )
        samples = samples.view(batch_size * num_sample_per_cond, -1)
        batch_swarm = {"gps": samples, "emb": cond}
        nll_batch = -self.compute_exact_loglikelihood(batch_swarm, cfg=0)
        nll_batch = nll_batch.view(batch_size, num_sample_per_cond, -1)
        nll_best = nll_batch[
            torch.arange(batch_size), nll_batch.argmin(dim=1).squeeze(1)
        ]
        self.log(
            "test/best_nll",
            nll_best.mean(),
            sync_dist=True,
            on_step=False,
            on_epoch=True,
        )
        samples = samples.view(batch_size, num_sample_per_cond, -1)[
            torch.arange(batch_size), nll_batch.argmin(dim=1).squeeze(1)
        ]
        self.test_metrics.update({"gps": samples}, batch)

    def test_step(self, batch, batch_idx):
        if self.cfg.compute_swarms:
            self.test_best_nll(batch, batch_idx)
        else:
            self.test_step_simple(batch, batch_idx)

    def on_test_epoch_end(self):
        metrics = self.test_metrics.compute()
        for metric_name, metric_value in metrics.items():
            self.log(
                f"test/{metric_name}",
                metric_value,
                sync_dist=True,
                on_step=False,
                on_epoch=True,
            )

    def configure_optimizers(self):
        if self.cfg.optimizer.exclude_ln_and_biases_from_weight_decay:
            parameters_names_wd = get_parameter_names(self.network, [nn.LayerNorm])
            parameters_names_wd = [
                name for name in parameters_names_wd if "bias" not in name
            ]
            optimizer_grouped_parameters = [
                {
                    "params": [
                        p
                        for n, p in self.network.named_parameters()
                        if n in parameters_names_wd
                    ],
                    "weight_decay": self.cfg.optimizer.optim.weight_decay,
                    "layer_adaptation": True,
                },
                {
                    "params": [
                        p
                        for n, p in self.network.named_parameters()
                        if n not in parameters_names_wd
                    ],
                    "weight_decay": 0.0,
                    "layer_adaptation": False,
                },
            ]
            optimizer = instantiate(
                self.cfg.optimizer.optim, optimizer_grouped_parameters
            )
        else:
            optimizer = instantiate(self.cfg.optimizer.optim, self.network.parameters())
        if "lr_scheduler" in self.cfg:
            scheduler = instantiate(self.cfg.lr_scheduler)(optimizer)
            return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
        else:
            return optimizer

    def lr_scheduler_step(self, scheduler, metric):
        scheduler.step(self.global_step)
    def sample(
        self,
        batch_size=None,
        cond=None,
        x_N=None,
        num_steps=None,
        stage="test",
        cfg=0,
        generator=None,
        return_trajectories=False,
        postprocessing=True,
    ):
        if x_N is None:
            assert batch_size is not None
            if isinstance(self.manifold, Sphere):
                x_N = self.manifold.random_base(
                    batch_size, self.input_dim, device=self.device
                )
                x_N = x_N.reshape(batch_size, self.input_dim)
            else:
                x_N = torch.randn(batch_size, self.input_dim, device=self.device)
        batch = {"y": x_N}
        if stage == "val":
            sampler = self.val_sampler
        elif stage == "test":
            sampler = self.test_sampler
        else:
            raise ValueError(f"Unknown stage {stage}")
        batch[self.cfg.cond_preprocessing.input_key] = cond
        batch = self.cond_preprocessing(batch, device=self.device)
        if num_steps is None:
            output = sampler(
                self.ema_model,
                batch,
                conditioning_keys=self.cfg.cond_preprocessing.output_key,
                scheduler=self.inference_noise_scheduler,
                cfg_rate=cfg,
                generator=generator,
                return_trajectories=return_trajectories,
            )
        else:
            output = sampler(
                self.ema_model,
                batch,
                conditioning_keys=self.cfg.cond_preprocessing.output_key,
                scheduler=self.inference_noise_scheduler,
                num_steps=num_steps,
                cfg_rate=cfg,
                generator=generator,
                return_trajectories=return_trajectories,
            )
        if return_trajectories:
            return (
                self.postprocessing(output[0]) if postprocessing else output[0],
                [
                    self.postprocessing(frame) if postprocessing else frame
                    for frame in output[1]
                ],
            )
        else:
            return self.postprocessing(output) if postprocessing else output

    def sample_distribution(
        self,
        x_N,
        cond,
        sampling_batch_size=2048,
        num_steps=None,
        stage="test",
        cfg=0,
        generator=None,
        return_trajectories=False,
    ):
        if return_trajectories:
            x_0 = []
            trajectories = []
            i = -1
            for i in range(x_N.shape[0] // sampling_batch_size):
                x_N_batch = x_N[i * sampling_batch_size : (i + 1) * sampling_batch_size]
                cond_batch = cond[
                    i * sampling_batch_size : (i + 1) * sampling_batch_size
                ]
                out, traj = self.sample(
                    cond=cond_batch,
                    x_N=x_N_batch,
                    num_steps=num_steps,
                    stage=stage,
                    cfg=cfg,
                    generator=generator,
                    return_trajectories=return_trajectories,
                )
                x_0.append(out)
                trajectories.append(traj)
            if x_N.shape[0] % sampling_batch_size != 0:
                x_N_batch = x_N[(i + 1) * sampling_batch_size :]
                cond_batch = cond[(i + 1) * sampling_batch_size :]
                out, traj = self.sample(
                    cond=cond_batch,
                    x_N=x_N_batch,
                    num_steps=num_steps,
                    stage=stage,
                    cfg=cfg,
                    generator=generator,
                    return_trajectories=return_trajectories,
                )
                x_0.append(out)
                trajectories.append(traj)
            x_0 = torch.cat(x_0, dim=0)
            # Merge the per-chunk trajectories frame by frame along the batch dim.
            trajectories = [torch.cat(frames, dim=0) for frames in zip(*trajectories)]
            return x_0, trajectories
        else:
            x_0 = []
            i = -1
            for i in range(x_N.shape[0] // sampling_batch_size):
                x_N_batch = x_N[i * sampling_batch_size : (i + 1) * sampling_batch_size]
                cond_batch = cond[
                    i * sampling_batch_size : (i + 1) * sampling_batch_size
                ]
                out = self.sample(
                    cond=cond_batch,
                    x_N=x_N_batch,
                    num_steps=num_steps,
                    stage=stage,
                    cfg=cfg,
                    generator=generator,
                    return_trajectories=return_trajectories,
                )
                x_0.append(out)
            if x_N.shape[0] % sampling_batch_size != 0:
                x_N_batch = x_N[(i + 1) * sampling_batch_size :]
                cond_batch = cond[(i + 1) * sampling_batch_size :]
                out = self.sample(
                    cond=cond_batch,
                    x_N=x_N_batch,
                    num_steps=num_steps,
                    stage=stage,
                    cfg=cfg,
                    generator=generator,
                    return_trajectories=return_trajectories,
                )
                x_0.append(out)
            x_0 = torch.cat(x_0, dim=0)
            return x_0

    def model(self, *args, **kwargs):
        return self.preconditioning(self.network, *args, **kwargs)

    def ema_model(self, *args, **kwargs):
        return self.preconditioning(self.ema_network, *args, **kwargs)
    def compute_exact_loglikelihood(
        self,
        batch=None,
        x_1=None,
        cond=None,
        t1=1.0,
        num_steps=1000,
        rademacher=False,
        data_preprocessing=True,
        cfg=0,
    ):
        nfe = [0]
        if batch is None:
            batch = {"x_0": x_1, "emb": cond}
        if data_preprocessing:
            batch = self.data_preprocessing(batch)
            batch = self.cond_preprocessing(batch)
        timesteps = self.inference_noise_scheduler(
            torch.linspace(0, t1, 2).to(batch["x_0"])
        )
        with torch.inference_mode(mode=False):

            def odefunc(t, tensor):
                nfe[0] += 1
                t = t.to(tensor)
                gamma = self.inference_noise_scheduler(t)
                x = tensor[..., : self.input_dim]
                y = batch["emb"]

                def vecfield(x, y):
                    if cfg > 0:
                        batch_vecfield = {
                            "y": x,
                            "emb": y,
                            "gamma": gamma.reshape(-1),
                        }
                        model_output_cond = self.ema_model(batch_vecfield)
                        batch_vecfield_uncond = {
                            "y": x,
                            "emb": torch.zeros_like(y),
                            "gamma": gamma.reshape(-1),
                        }
                        model_output_uncond = self.ema_model(batch_vecfield_uncond)
                        model_output = model_output_cond + cfg * (
                            model_output_cond - model_output_uncond
                        )
                    else:
                        batch_vecfield = {
                            "y": x,
                            "emb": y,
                            "gamma": gamma.reshape(-1),
                        }
                        model_output = self.ema_model(batch_vecfield)
                    if self.interpolant == "flow_matching":
                        d_gamma = self.inference_noise_scheduler.derivative(t).reshape(
                            -1, 1
                        )
                        return d_gamma * model_output
                    elif self.interpolant == "diffusion":
                        alpha_t = self.inference_noise_scheduler.alpha(t).reshape(-1, 1)
                        return (
                            -1 / 2 * (alpha_t * x - torch.abs(alpha_t) * model_output)
                        )
                    else:
                        raise ValueError(f"Unknown interpolant {self.interpolant}")

                if rademacher:
                    v = torch.randint_like(x, 2) * 2 - 1
                else:
                    v = None
                dx, div = output_and_div(vecfield, x, y, v=v)
                div = div.reshape(-1, 1)
                del t, x
                return torch.cat([dx, div], dim=-1)

            x_1 = batch["x_0"]
            # Augmented state [x, accumulated divergence], starting with a zero
            # log-det-Jacobian term.
            state1 = torch.cat([x_1, torch.zeros_like(x_1[..., :1])], dim=-1)
            with torch.no_grad():
                if False and isinstance(self.manifold, Sphere):
                    print("Riemannian flow sampler")
                    product_man = ProductManifold(
                        (self.manifold, self.input_dim), (Euclidean(), 1)
                    )
                    state0 = ode_riemannian_flow_sampler(
                        odefunc,
                        state1,
                        manifold=product_man,
                        scheduler=self.inference_noise_scheduler,
                        num_steps=num_steps,
                    )
                else:
                    print("ODE solver")
                    state0 = odeint(
                        odefunc,
                        state1,
                        t=torch.linspace(0, t1, 2).to(batch["x_0"]),
                        atol=1e-6,
                        rtol=1e-6,
                        method="dopri5",
                        options={"min_step": 1e-5},
                    )[-1]
            x_0, logdetjac = state0[..., : self.input_dim], state0[..., -1]
            if self.manifold is not None:
                x_0 = self.manifold.projx(x_0)
                logp0 = self.manifold.base_logprob(x_0)
            else:
                logp0 = (
                    -1 / 2 * (x_0**2).sum(dim=-1)
                    - self.input_dim
                    * torch.log(torch.tensor(2 * np.pi, device=x_0.device))
                    / 2
                )
            print(f"nfe: {nfe[0]}")
            # Change of variables: data log-likelihood = base log-prob of the
            # mapped point + accumulated divergence, converted to bits per dim.
            logp1 = logp0 + logdetjac
            logp1 = logp1 / (self.input_dim * np.log(2))
            return logp1
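

# ---------------------------------------------------------------------------
# Illustrative sketch only (never called by the modules in this file): the
# same augmented-ODE / change-of-variables bookkeeping that
# `DiffGeolocalizer.compute_exact_loglikelihood` performs, on a toy linear
# vector field dx/dt = a * x whose divergence (a * dim) is known in closed
# form. The helper name and all quantities below are hypothetical.
# ---------------------------------------------------------------------------
def _toy_change_of_variables_check(a: float = 0.3, dim: int = 2) -> None:
    x_data = torch.randn(4, dim)

    def odefunc(t, state):
        x = state[..., :dim]
        dx = a * x
        # For a linear field the divergence is constant, so it is hard-coded
        # here instead of being estimated with autodiff.
        div = torch.full_like(state[..., -1:], a * dim)
        return torch.cat([dx, div], dim=-1)

    # Augmented state [x, accumulated divergence], integrated from the data
    # at t=0 to the standard-normal base at t=1.
    state1 = torch.cat([x_data, torch.zeros_like(x_data[..., :1])], dim=-1)
    state0 = odeint(odefunc, state1, t=torch.linspace(0.0, 1.0, 2))[-1]
    x_base, logdetjac = state0[..., :dim], state0[..., -1]

    # Change of variables: log p_data(x) = log N(x_base; 0, I) + int_0^1 div dt.
    logp_base = -0.5 * (x_base**2).sum(dim=-1) - 0.5 * dim * np.log(2 * np.pi)
    logp_ode = logp_base + logdetjac

    # Analytic density induced at t=0 by this flow: N(0, exp(-2a) I).
    var = float(np.exp(-2 * a))
    logp_exact = -0.5 * (x_data**2).sum(dim=-1) / var - 0.5 * dim * np.log(
        2 * np.pi * var
    )
    print("ODE log-prob     :", logp_ode)
    print("analytic log-prob:", logp_exact)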


def get_parameter_names(model, forbidden_layer_types):
    """
    Returns the names of the model parameters that are not inside a forbidden layer.
    Taken from HuggingFace transformers.
    """
    result = []
    for name, child in model.named_children():
        result += [
            f"{name}.{n}"
            for n in get_parameter_names(child, forbidden_layer_types)
            if not isinstance(child, tuple(forbidden_layer_types))
        ]
    # Add model specific parameters (defined with nn.Parameter) since they are not in any child.
    result += list(model._parameters.keys())
    return result


# for likelihood computation
def div_fn(u):
    """Accepts a function u:R^D -> R^D."""
    J = jacrev(u, argnums=0)
    return lambda x, y: torch.trace(J(x, y).squeeze(0))


def output_and_div(vecfield, x, y, v=None):
    if v is None:
        dx = vecfield(x, y)
        div = vmap(div_fn(vecfield))(x, y)
    else:
        vecfield_x = lambda x: vecfield(x, y)
        dx, vjpfunc = vjp(vecfield_x, x)
        vJ = vjpfunc(v)[0]
        div = torch.sum(vJ * v, dim=-1)
    return dx, div


class VonFisherGeolocalizer(L.LightningModule):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.network = instantiate(cfg.network)
        # self.network = torch.compile(self.network, fullgraph=True)
        self.input_dim = cfg.network.input_dim
        self.data_preprocessing = instantiate(cfg.data_preprocessing)
        self.cond_preprocessing = instantiate(cfg.cond_preprocessing)
        self.preconditioning = instantiate(cfg.preconditioning)
        self.ema_network = copy.deepcopy(self.network).requires_grad_(False)
        self.ema_network.eval()
        self.postprocessing = instantiate(cfg.postprocessing)
        self.val_sampler = instantiate(cfg.val_sampler)
        self.test_sampler = instantiate(cfg.test_sampler)
        self.loss = instantiate(cfg.loss)()
        self.val_metrics = instantiate(cfg.val_metrics)
        self.test_metrics = instantiate(cfg.test_metrics)

    def training_step(self, batch, batch_idx):
        with torch.no_grad():
            batch = self.data_preprocessing(batch)
            batch = self.cond_preprocessing(batch)
        batch_size = batch["x_0"].shape[0]
        loss = self.loss(self.preconditioning, self.network, batch).mean()
        self.log(
            "train/loss",
            loss,
            sync_dist=True,
            on_step=True,
            on_epoch=True,
            batch_size=batch_size,
        )
        return loss

    def on_before_optimizer_step(self, optimizer):
        if self.global_step == 0:
            no_grad = []
            for name, param in self.network.named_parameters():
                if param.grad is None:
                    no_grad.append(name)
            if len(no_grad) > 0:
                print("Parameters without grad:")
                print(no_grad)

    def on_validation_start(self):
        self.validation_generator = torch.Generator(device=self.device).manual_seed(
            3407
        )
        self.validation_generator_ema = torch.Generator(device=self.device).manual_seed(
            3407
        )

    def validation_step(self, batch, batch_idx):
        batch = self.data_preprocessing(batch)
        batch = self.cond_preprocessing(batch)
        batch_size = batch["x_0"].shape[0]
        loss = self.loss(
            self.preconditioning,
            self.network,
            batch,
            generator=self.validation_generator,
        ).mean()
        self.log(
            "val/loss",
            loss,
            sync_dist=True,
            on_step=False,
            on_epoch=True,
            batch_size=batch_size,
        )
        if hasattr(self, "ema_model"):
            loss_ema = self.loss(
                self.preconditioning,
                self.ema_network,
                batch,
                generator=self.validation_generator_ema,
            ).mean()
            self.log(
                "val/loss_ema",
                loss_ema,
                sync_dist=True,
                on_step=False,
                on_epoch=True,
                batch_size=batch_size,
            )

    def on_test_start(self):
        self.test_generator = torch.Generator(device=self.device).manual_seed(3407)

    def test_step(self, batch, batch_idx):
        batch = self.data_preprocessing(batch)
        batch = self.cond_preprocessing(batch)
        batch_size = batch["x_0"].shape[0]
        cond = batch[self.cfg.cond_preprocessing.output_key]
        samples = self.sample(cond=cond, stage="test")
        self.test_metrics.update({"gps": samples}, batch)
        nll = -self.compute_exact_loglikelihood(batch).mean()
        self.log(
            "test/NLL",
            nll,
            sync_dist=True,
            on_step=False,
            on_epoch=True,
            batch_size=batch_size,
        )

    def on_test_epoch_end(self):
        metrics = self.test_metrics.compute()
        for metric_name, metric_value in metrics.items():
            self.log(
                f"test/{metric_name}",
                metric_value,
                sync_dist=True,
                on_step=False,
                on_epoch=True,
            )
    def configure_optimizers(self):
        if self.cfg.optimizer.exclude_ln_and_biases_from_weight_decay:
            parameters_names_wd = get_parameter_names(self.network, [nn.LayerNorm])
            parameters_names_wd = [
                name for name in parameters_names_wd if "bias" not in name
            ]
            optimizer_grouped_parameters = [
                {
                    "params": [
                        p
                        for n, p in self.network.named_parameters()
                        if n in parameters_names_wd
                    ],
                    "weight_decay": self.cfg.optimizer.optim.weight_decay,
                    "layer_adaptation": True,
                },
                {
                    "params": [
                        p
                        for n, p in self.network.named_parameters()
                        if n not in parameters_names_wd
                    ],
                    "weight_decay": 0.0,
                    "layer_adaptation": False,
                },
            ]
            optimizer = instantiate(
                self.cfg.optimizer.optim, optimizer_grouped_parameters
            )
        else:
            optimizer = instantiate(self.cfg.optimizer.optim, self.network.parameters())
        if "lr_scheduler" in self.cfg:
            scheduler = instantiate(self.cfg.lr_scheduler)(optimizer)
            return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
        else:
            return optimizer

    def lr_scheduler_step(self, scheduler, metric):
        scheduler.step(self.global_step)

    def sample(
        self,
        batch_size=None,
        cond=None,
        postprocessing=True,
        stage="val",
    ):
        batch = {}
        if stage == "val":
            sampler = self.val_sampler
        elif stage == "test":
            sampler = self.test_sampler
        else:
            raise ValueError(f"Unknown stage {stage}")
        batch[self.cfg.cond_preprocessing.input_key] = cond
        batch = self.cond_preprocessing(batch, device=self.device)
        output = sampler(
            self.ema_model,
            batch,
        )
        return self.postprocessing(output) if postprocessing else output

    def model(self, *args, **kwargs):
        return self.preconditioning(self.network, *args, **kwargs)

    def ema_model(self, *args, **kwargs):
        return self.preconditioning(self.ema_network, *args, **kwargs)

    def compute_exact_loglikelihood(
        self,
        batch=None,
    ):
        batch = self.data_preprocessing(batch)
        batch = self.cond_preprocessing(batch)
        return -self.loss(self.preconditioning, self.ema_network, batch)


class RandomGeolocalizer(L.LightningModule):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.test_metrics = instantiate(cfg.test_metrics)
        self.data_preprocessing = instantiate(cfg.data_preprocessing)
        self.cond_preprocessing = instantiate(cfg.cond_preprocessing)
        self.postprocessing = instantiate(cfg.postprocessing)

    def test_step(self, batch, batch_idx):
        batch = self.data_preprocessing(batch)
        batch = self.cond_preprocessing(batch)
        batch_size = batch["x_0"].shape[0]
        samples = torch.randn(batch_size, 3, device=self.device)
        samples = samples / samples.norm(dim=-1, keepdim=True)
        samples = self.postprocessing(samples)
        self.test_metrics.update({"gps": samples}, batch)

    def on_test_epoch_end(self):
        metrics = self.test_metrics.compute()
        for metric_name, metric_value in metrics.items():
            self.log(
                f"test/{metric_name}",
                metric_value,
                sync_dist=True,
                on_step=False,
                on_epoch=True,
            )
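

# ---------------------------------------------------------------------------
# Minimal, self-contained sketch of the divergence estimators used by
# `compute_exact_loglikelihood` above. The toy linear vector field, its
# `weight` matrix and the variable names below are hypothetical and exist
# only for illustration: for a linear field f(x) = W x the exact divergence
# is trace(W), so both the exact (jacrev-trace) path and the Hutchinson /
# Rademacher path of `output_and_div` can be sanity-checked against it.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    dim = 3
    weight = torch.randn(dim, dim)

    def toy_vecfield(x, y):
        # `y` mimics the conditioning argument expected by `output_and_div`;
        # it is unused in this toy field.
        return x @ weight.T

    x = torch.randn(4, dim)
    y = torch.zeros(4, 8)

    # Exact divergence: trace of the Jacobian, vmapped over the batch
    # (the `rademacher=False` path in `compute_exact_loglikelihood`).
    dx_exact, div_exact = output_and_div(toy_vecfield, x, y, v=None)

    # Hutchinson estimate with a Rademacher probe (the `rademacher=True` path).
    v = torch.randint_like(x, 2) * 2 - 1
    dx_est, div_est = output_and_div(toy_vecfield, x, y, v=v)

    print("trace(weight)      :", torch.trace(weight).item())
    print("exact divergence   :", div_exact)
    print("hutchinson estimate:", div_est)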