Spaces:
Runtime error
Runtime error
from torch import nn, optim | |
import torch | |
class Photoz_network(nn.Module):
    """Fully connected network with three Gaussian-mixture output heads.

    A shared trunk maps 6 input features to a 10-dimensional representation.
    Three independent heads then map that representation to ``num_gauss``
    values each: component means, component widths and log mixture weights.

    NOTE(review): the class name suggests photometric-redshift estimation;
    nothing in this block depends on that interpretation.

    Args:
        num_gauss: Number of mixture components emitted by each head.
        dropout_prob: Dropout probability used after every hidden linear
            layer (``0`` disables dropout).
    """

    def __init__(self, num_gauss=10, dropout_prob=0):
        super().__init__()

        # Shared trunk: 6 -> 10 -> 20 -> 50 -> 20 -> 10. The final linear
        # layer has no activation; the heads consume its raw output.
        self.features = nn.Sequential(
            nn.Linear(6, 10),
            nn.Dropout(dropout_prob),
            nn.ReLU(),
            nn.Linear(10, 20),
            nn.Dropout(dropout_prob),
            nn.ReLU(),
            nn.Linear(20, 50),
            nn.Dropout(dropout_prob),
            nn.ReLU(),
            nn.Linear(50, 20),
            nn.Dropout(dropout_prob),
            nn.ReLU(),
            nn.Linear(20, 10),
        )

        # Heads are built in this exact order (mu, coeffs, sigma) so that
        # parameter registration and seeded initialisation stay stable.
        self.measure_mu = self._make_head(num_gauss, dropout_prob)
        self.measure_coeffs = self._make_head(num_gauss, dropout_prob)
        self.measure_sigma = self._make_head(num_gauss, dropout_prob)

    @staticmethod
    def _make_head(num_gauss, dropout_prob):
        """Build one output head: 10 -> 20 -> ``num_gauss``."""
        return nn.Sequential(
            nn.Linear(10, 20),
            nn.Dropout(dropout_prob),
            nn.ReLU(),
            nn.Linear(20, num_gauss),
        )

    def forward(self, x):
        """Run the trunk and all three heads.

        Args:
            x: Tensor whose last dimension has size 6. Any number of leading
                dimensions (including none) is accepted.

        Returns:
            Tuple ``(mu, sigma, logmix_coeff)``; each tensor has the leading
            dimensions of ``x`` and a last dimension of size ``num_gauss``.
            ``mu`` and ``sigma`` are the raw linear outputs of their heads
            (``sigma`` is NOT constrained to be positive here).
            ``logmix_coeff`` is log-normalised so that
            ``logmix_coeff.exp()`` sums to 1 over the last dimension.
        """
        f = self.features(x)
        mu = self.measure_mu(f)
        sigma = self.measure_sigma(f)
        logmix_coeff = self.measure_coeffs(f)
        # Normalise over the component axis (always the last one). Using
        # dim=-1 with keepdim=True gives the same result as dim=1 for 2-D
        # input, and also handles unbatched or extra-leading-dim inputs.
        logmix_coeff = logmix_coeff - torch.logsumexp(
            logmix_coeff, dim=-1, keepdim=True
        )
        return mu, sigma, logmix_coeff