import torch from torch import nn import numpy as np class NormGPS(nn.Module): def __init__(self, input_key="gps", output_key="x_0", normalize=True): super().__init__() self.input_key = input_key self.output_key = output_key self.normalize = normalize if self.normalize: self.register_buffer( "gps_normalize", 1 / torch.Tensor([np.pi * 0.5, np.pi]).unsqueeze(0) ) def forward(self, batch): """Normalize latitude longtitude radians to -1, 1.""" # not used currently x = batch[self.input_key] if self.normalize: x = x * self.gps_normalize batch[self.output_key] = x return batch class GPStoCartesian(nn.Module): def __init__(self, input_key="gps", output_key="x_0"): super().__init__() self.input_key = input_key self.output_key = output_key def forward(self, batch): """Project latitude longtitude radians to 3D coordinates.""" x = batch[self.input_key] lat, lon = x[:, 0], x[:, 1] x = torch.stack([lat.cos() * lon.cos(), lat.cos() * lon.sin(), lat.sin()], dim=-1) batch[self.output_key] = x return batch class PrecomputedPreconditioning: def __init__( self, input_key="emb", output_key="emb", ): self.input_key = input_key self.output_key = output_key def __call__(self, batch, device=None): batch[self.output_key] = batch[self.input_key] return batch