File size: 3,966 Bytes
26b28ad
c9354dd
d307831
 
 
 
c9354dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d307831
c212435
 
d307831
c212435
 
696a020
c212435
 
696a020
c212435
 
696a020
c212435
 
57fa8fc
c212435
57fa8fc
c9354dd
 
 
 
 
 
 
 
 
d307831
57fa8fc
d307831
 
 
 
c9354dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d307831
57fa8fc
 
c212435
696a020
c212435
 
57fa8fc
c212435
 
 
696a020
c212435
 
57fa8fc
c212435
 
 
696a020
c212435
 
57fa8fc
c212435
57fa8fc
c9354dd
 
 
 
 
 
 
 
 
 
 
 
 
 
c212435
 
 
26b28ad
c9354dd
 
c212435
57fa8fc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import torch
from torch import nn
import torch.nn.functional as F


class EncoderPhotometry(nn.Module):
    """Encoder for photometric data.

    A small multilayer perceptron that maps ``input_dim`` photometric features
    to a 10-dimensional vector of log-probabilities (a log-softmax is applied
    to the output of the last linear layer).

    Attributes:
        features (nn.Sequential): Stack of ``Linear -> Dropout -> ReLU`` stages
            with widths ``input_dim -> 10 -> 20 -> 50 -> 20`` followed by a
            final bare ``Linear(20, 10)``.
    """

    def __init__(self, input_dim: int = 6, dropout_prob: float = 0) -> None:
        """Initializes the EncoderPhotometry module.

        Args:
            input_dim (int): Number of input features (default is 6).
            dropout_prob (float): Probability of dropout applied after every
                hidden linear layer (default is 0, i.e. no dropout).
        """
        super().__init__()

        # Layer order (and therefore the state_dict keys ``features.<idx>``)
        # is kept exactly as before so existing checkpoints still load.
        self.features = nn.Sequential(
            nn.Linear(input_dim, 10),
            nn.Dropout(dropout_prob),
            nn.ReLU(),
            nn.Linear(10, 20),
            nn.Dropout(dropout_prob),
            nn.ReLU(),
            nn.Linear(20, 50),
            nn.Dropout(dropout_prob),
            nn.ReLU(),
            nn.Linear(50, 20),
            nn.Dropout(dropout_prob),
            nn.ReLU(),
            nn.Linear(20, 10),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the encoder.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, input_dim).
                Unbatched ``(input_dim,)`` input and inputs with additional
                leading dimensions ``(..., input_dim)`` are also accepted.

        Returns:
            torch.Tensor: Log softmax output of shape (batch_size, 10), or
            ``(..., 10)`` for the other accepted input shapes.
        """
        f = self.features(x)
        # nn.Linear operates on the last axis, so the 10 encoded features
        # always live there. Normalising over dim=-1 is identical to the
        # former dim=1 for 2-D input, and additionally works for 1-D input
        # (which used to raise) and for inputs with extra leading dims.
        return F.log_softmax(f, dim=-1)


class MeasureZ(nn.Module):
    """Model to measure redshift parameters.

    This model estimates the parameters of a mixture of Gaussians used for measuring redshift.
    Three independent two-layer heads read the same 10-dimensional input and
    each emit one value per Gaussian component.

    Attributes:
        ngaussians (int): Number of Gaussian components in the mixture.
        measure_mu (nn.Sequential): Sequential model to measure the mean (mu).
        measure_coeffs (nn.Sequential): Sequential model to measure the mixing coefficients.
        measure_sigma (nn.Sequential): Sequential model to measure the standard deviation (sigma).
    """

    def __init__(self, num_gauss: int = 10, dropout_prob: float = 0) -> None:
        """Initializes the MeasureZ module.

        Args:
            num_gauss (int): Number of Gaussian components (default is 10).
            dropout_prob (float): Probability of dropout inside each head
                (default is 0, i.e. no dropout).
        """
        super().__init__()

        self.ngaussians = num_gauss

        # The heads are created in the same order as before (mu, coeffs,
        # sigma) so that seeded initialisation and state_dict layout are
        # unchanged.
        self.measure_mu = nn.Sequential(
            nn.Linear(10, 20),
            nn.Dropout(dropout_prob),
            nn.ReLU(),
            nn.Linear(20, num_gauss),
        )

        self.measure_coeffs = nn.Sequential(
            nn.Linear(10, 20),
            nn.Dropout(dropout_prob),
            nn.ReLU(),
            nn.Linear(20, num_gauss),
        )

        self.measure_sigma = nn.Sequential(
            nn.Linear(10, 20),
            nn.Dropout(dropout_prob),
            nn.ReLU(),
            nn.Linear(20, num_gauss),
        )

    def forward(
        self, f: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward pass to measure redshift parameters.

        Args:
            f (torch.Tensor): Input tensor of shape (batch_size, 10).
                Unbatched ``(10,)`` input and inputs with additional leading
                dimensions ``(..., 10)`` are also accepted.

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing:
                - mu (torch.Tensor): Mean parameters of shape (batch_size, num_gauss).
                - sigma (torch.Tensor): Standard deviation parameters of shape
                  (batch_size, num_gauss). This is the raw output of the last
                  linear layer; no positivity constraint is applied here.
                - logmix_coeff (torch.Tensor): Log mixing coefficients of shape
                  (batch_size, num_gauss), normalised so that their
                  exponentials sum to 1 over the component axis.
        """
        mu = self.measure_mu(f)
        sigma = self.measure_sigma(f)
        logmix_coeff = self.measure_coeffs(f)

        # Normalize logmix_coeff to get valid mixture coefficients.
        # nn.Linear puts the component axis last, so normalise over dim=-1:
        # identical to the former dim=1 for 2-D input, but also correct for
        # unbatched input (which used to raise) and for extra leading dims.
        logmix_coeff = logmix_coeff - torch.logsumexp(
            logmix_coeff, dim=-1, keepdim=True
        )

        return mu, sigma, logmix_coeff