Spaces:
Running
Running
File size: 1,548 Bytes
c4c7cee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import torch
from torch import nn
import numpy as np
class NormGPS(nn.Module):
    """Scale (latitude, longitude) radians into the range [-1, 1].

    Latitude is multiplied by 1/(pi/2) and longitude by 1/pi, so inputs in
    [-pi/2, pi/2] x [-pi, pi] map onto [-1, 1] x [-1, 1]. The reciprocal
    factors live in the ``gps_normalize`` buffer (shape ``(1, 2)``), which is
    only registered when ``normalize`` is True.
    """

    def __init__(self, input_key="gps", output_key="x_0", normalize=True):
        super().__init__()
        self.input_key = input_key
        self.output_key = output_key
        self.normalize = normalize
        if not normalize:
            # Identity mode: no scale buffer is needed.
            return
        half_spans = torch.Tensor([np.pi * 0.5, np.pi]).unsqueeze(0)
        self.register_buffer("gps_normalize", 1 / half_spans)

    def forward(self, batch):
        """Store the (optionally) normalized coordinates under ``output_key``.

        NOTE: the original author marked this transform as not currently used.

        Reads ``batch[self.input_key]``; when ``normalize`` is False that exact
        object is forwarded unchanged. The same ``batch`` object is mutated in
        place and returned.
        """
        coords = batch[self.input_key]
        batch[self.output_key] = (
            coords * self.gps_normalize if self.normalize else coords
        )
        return batch
class GPStoCartesian(nn.Module):
    """Project (latitude, longitude) radians onto 3D unit-sphere coordinates.

    Reads ``batch[input_key]`` and writes
    ``(cos(lat) * cos(lon), cos(lat) * sin(lon), sin(lat))`` to
    ``batch[output_key]``.
    """

    def __init__(self, input_key="gps", output_key="x_0"):
        super().__init__()
        self.input_key = input_key
        self.output_key = output_key

    def forward(self, batch):
        """Project latitude/longitude radians to 3D Cartesian coordinates.

        Args:
            batch: mutable mapping. ``batch[self.input_key]`` is a tensor whose
                last axis holds latitude at index 0 and longitude at index 1
                (radians); it may have any number of leading dimensions,
                including none.

        Returns:
            The same ``batch`` object, with ``batch[self.output_key]`` set to a
            tensor whose last axis has size 3 and whose leading dimensions
            match the input's.
        """
        coords = batch[self.input_key]
        # Index the last axis so (2,), (B, 2) and (B, N, 2) inputs all work.
        lat, lon = coords[..., 0], coords[..., 1]
        cos_lat = lat.cos()  # shared by the x and y components
        batch[self.output_key] = torch.stack(
            [cos_lat * lon.cos(), cos_lat * lon.sin(), lat.sin()], dim=-1
        )
        return batch
class PrecomputedPreconditioning:
    """Pass a precomputed value through from one batch key to another.

    No transformation is applied: ``batch[output_key]`` is bound to the very
    object stored at ``batch[input_key]``. With the default keys (both
    ``"emb"``) the entry is simply reassigned to itself.
    """

    def __init__(self, input_key="emb", output_key="emb"):
        self.input_key, self.output_key = input_key, output_key

    def __call__(self, batch, device=None):
        """Alias ``batch[input_key]`` under ``output_key`` and return ``batch``.

        ``device`` is accepted but ignored. Presumably it exists so this can be
        called with the same signature as other preconditioning callables
        (TODO confirm against callers). A missing ``input_key`` raises
        ``KeyError`` from the lookup.
        """
        precomputed = batch[self.input_key]
        batch[self.output_key] = precomputed
        return batch
|