from models.networks.utils import UnormGPS import torch.nn as nn from torch.nn.functional import tanh import torch class RegressionHead(nn.Module): def __init__(self, use_tanh=False): super().__init__() self.unorm = UnormGPS() self.use_tanh = use_tanh def forward(self, x): """Forward pass of the network. x : Union[torch.Tensor, dict] with the output of the backbone. """ if self.use_tanh: x = tanh(x) gps = self.unorm(x) return {"gps": gps} class RegressionHeadAngle(nn.Module): def __init__(self): super().__init__() self.unorm = UnormGPS() def forward(self, x): """Forward pass of the network. x : Union[torch.Tensor, dict] with the output of the backbone. """ x1 = x[:, 0].pow(2) x2 = x[:, 1].pow(2) x3 = x[:, 2].pow(2) x4 = x[:, 3].pow(2) cos_lambda = x1 / (x1 + x2) sin_lambda = x2 / (x1 + x2) cos_phi = x3 / (x3 + x4) sin_phi = x4 / (x3 + x4) lbd = torch.atan2(sin_lambda, cos_lambda) phi = torch.atan2(sin_phi, cos_phi) gps = torch.cat((lbd.unsqueeze(1), phi.unsqueeze(1)), dim=1) # gps = self.unorm(x) return {"gps": gps}