File size: 1,548 Bytes
c4c7cee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
from torch import nn
import numpy as np


class NormGPS(nn.Module):
    """Scale (latitude, longitude) radians by (1 / (pi/2), 1 / pi).

    With latitude in [-pi/2, pi/2] and longitude in [-pi, pi] this maps both
    coordinates into [-1, 1]. When ``normalize`` is False the input tensor is
    forwarded unchanged and no scaling buffer is registered.
    """

    def __init__(self, input_key="gps", output_key="x_0", normalize=True):
        super().__init__()
        self.input_key = input_key
        self.output_key = output_key
        self.normalize = normalize
        if not normalize:
            return
        # Per-column reciprocal scale, kept as a buffer so it moves with the
        # module across devices; shape (1, 2) broadcasts over leading rows.
        half_pi_and_pi = torch.Tensor([np.pi * 0.5, np.pi])
        self.register_buffer("gps_normalize", (1 / half_pi_and_pi)[None, :])

    def forward(self, batch):
        """Normalize latitude/longitude radians to [-1, 1].

        Reads ``batch[input_key]``, writes the (optionally scaled) tensor to
        ``batch[output_key]`` and returns the same, mutated, dict.
        NOTE: the original author marked this transform as not currently used.
        """
        coords = batch[self.input_key]
        batch[self.output_key] = (
            coords * self.gps_normalize if self.normalize else coords
        )
        return batch

class GPStoCartesian(nn.Module):
    """Project (latitude, longitude) radians onto points of the unit sphere.

    Reads ``batch[input_key]`` and writes ``batch[output_key]``; the batch dict
    is modified in place and also returned.
    """

    def __init__(self, input_key="gps", output_key="x_0"):
        super().__init__()
        self.input_key = input_key
        self.output_key = output_key

    def forward(self, batch):
        """Project latitude/longitude radians to 3D Cartesian coordinates.

        Args:
            batch: dict holding, under ``input_key``, a tensor whose last axis
                is ``(latitude, longitude)`` in radians; any number of leading
                dimensions is allowed, e.g. ``(B, 2)`` or ``(B, T, 2)``.

        Returns:
            The same dict, with a tensor of shape ``(..., 3)`` stored under
            ``output_key``: ``(cos(lat)cos(lon), cos(lat)sin(lon), sin(lat))``.
        """
        gps = batch[self.input_key]
        # Index the last axis (rather than ``[:, 0]``) so inputs with extra
        # leading dims are split into lat/lon correctly instead of silently
        # slicing the wrong axis.
        lat, lon = gps[..., 0], gps[..., 1]
        cos_lat = lat.cos()  # shared by the x and y components
        batch[self.output_key] = torch.stack(
            [cos_lat * lon.cos(), cos_lat * lon.sin(), lat.sin()], dim=-1
        )
        return batch

class PrecomputedPreconditioning:
    """Forward an already-computed batch entry from one key to another.

    No transformation is applied: ``batch[output_key]`` is bound to the very
    same object found at ``batch[input_key]``.
    """

    def __init__(
        self,
        input_key="emb",
        output_key="emb",
    ):
        self.input_key, self.output_key = input_key, output_key

    def __call__(self, batch, device=None):
        """Alias ``batch[input_key]`` under ``output_key`` and return ``batch``.

        ``device`` is accepted but ignored; presumably it exists so this object
        can be called like other preconditioning steps — verify with callers.
        Raises ``KeyError`` if ``input_key`` is missing from ``batch``.
        """
        precomputed = batch[self.input_key]
        batch[self.output_key] = precomputed
        return batch