import torch
from torch.utils.data import DataLoader, dataset, TensorDataset
from torch import nn, optim
from torch.optim import lr_scheduler
import numpy as np
import pandas as pd
from astropy.io import fits
import os
from astropy.table import Table
from scipy.spatial import KDTree
from scipy.special import erf
from scipy import stats


class Insight_module():
    """Train and evaluate a Gaussian mixture-density network for photometric redshifts.

    The wrapped ``model`` must map an input batch to a tuple
    ``(mu, logsig, logmix_coeff)`` of shape ``(batch, n_components)`` each
    (mixture means, log standard deviations, log mixture weights).
    """

    def __init__(self, model, batch_size):
        # Keep the model on CPU until train()/get_photoz() moves it to `device`.
        self.model = model
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.batch_size = batch_size

    def _get_dataloaders(self, input_data, target_data, val_fraction=0.1):
        """Split (input, target) arrays into shuffled train/validation DataLoaders.

        Parameters
        ----------
        input_data, target_data : array-like, convertible via ``torch.Tensor``.
        val_fraction : float, fraction of samples reserved for validation.
        """
        input_data = torch.Tensor(input_data)
        target_data = torch.Tensor(target_data)
        full_dataset = TensorDataset(input_data, target_data)
        # BUG FIX: the old sizes int(n*(1-f)) and int(n*f)+1 need not sum to
        # len(dataset), which makes random_split raise. Derive the train size
        # as the exact complement so the split always covers the dataset.
        val_size = int(len(full_dataset) * val_fraction)
        train_size = len(full_dataset) - val_size
        training_dataset, val_dataset = torch.utils.data.random_split(
            full_dataset, [train_size, val_size])
        loader_train = DataLoader(training_dataset, batch_size=self.batch_size, shuffle=True)
        loader_val = DataLoader(val_dataset, batch_size=64, shuffle=True)
        return loader_train, loader_val

    def _loss_function(self, mean, std, logmix, true):
        """Negative log-likelihood of ``true`` under the predicted Gaussian mixture.

        ``mean``/``std``/``logmix`` are (batch, n_components); ``true`` is (batch,).
        """
        # NOTE: a truncated-Gaussian correction (log erf term) existed here but
        # was commented out of the loss; the dead computation has been removed.
        log_prob = (logmix
                    - 0.5 * (mean - true[:, None]).pow(2) / std.pow(2)
                    - torch.log(std))
        log_prob = torch.logsumexp(log_prob, 1)
        return -log_prob.mean()

    def _to_numpy(self, x):
        """Detach a tensor from the graph and return it as a NumPy array."""
        return x.detach().cpu().numpy()

    def train(self, input_data, target_data, nepochs=10, step_size=100,
              val_fraction=0.1, lr=1e-3):
        """Train the model, recording per-epoch train/validation losses.

        Populates ``self.loss_train`` and ``self.loss_validation`` (one mean
        loss per epoch each).
        """
        # BUG FIX: val_fraction was hard-coded to 0.1, silently ignoring the
        # caller's argument.
        loader_train, loader_val = self._get_dataloaders(
            input_data, target_data, val_fraction=val_fraction)
        optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-4)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)
        self.model = self.model.to(self.device)
        self.loss_train, self.loss_validation = [], []

        for epoch in range(nepochs):
            # BUG FIX: these accumulators were reset inside the batch loop, so
            # the stored "epoch loss" was really just the last batch's loss.
            _loss_train, _loss_validation = [], []
            self.model.train()
            for input_batch, target_batch in loader_train:
                input_batch = input_batch.to(self.device)
                target_batch = target_batch.to(self.device)
                optimizer.zero_grad()
                mu, logsig, logmix_coeff = self.model(input_batch)
                # Clamp log-sigma to keep the mixture components numerically sane.
                logsig = torch.clamp(logsig, -6, 2)
                sig = torch.exp(logsig)
                loss = self._loss_function(mu, sig, logmix_coeff, target_batch)
                _loss_train.append(loss.item())
                loss.backward()
                optimizer.step()
                # NOTE(review): the scheduler is stepped per *batch* (as in the
                # original), so step_size counts batches, not epochs — confirm
                # this is intentional.
                scheduler.step()
            self.loss_train.append(np.mean(_loss_train))

            # BUG FIX: validation previously ran in train mode with gradients
            # enabled; evaluate without autograd and restore nothing afterwards
            # (train mode is re-entered at the top of the next epoch).
            self.model.eval()
            with torch.no_grad():
                for input_batch, target_batch in loader_val:
                    input_batch = input_batch.to(self.device)
                    target_batch = target_batch.to(self.device)
                    mu, logsig, logmix_coeff = self.model(input_batch)
                    logsig = torch.clamp(logsig, -6, 2)
                    sig = torch.exp(logsig)
                    loss_val = self._loss_function(mu, sig, logmix_coeff, target_batch)
                    _loss_validation.append(loss_val.item())
            self.loss_validation.append(np.mean(_loss_validation))

    def get_photoz(self, input_data, target_data):
        """Return point-estimate redshifts and uncertainties as NumPy arrays.

        ``z`` is the mixture mean; ``zerr`` combines the within-component
        variance and the spread of the component means around the target.
        """
        self.model = self.model.eval()
        self.model = self.model.to(self.device)
        input_data = input_data.to(self.device)
        target_data = target_data.to(self.device)
        # BUG FIX: the old per-row loop re-ran the full batched forward pass
        # len(input_data) times and returned after the first iteration anyway;
        # one batched forward pass under no_grad is equivalent and cheaper.
        with torch.no_grad():
            mu, logsig, logmix_coeff = self.model(input_data)
            logsig = torch.clamp(logsig, -6, 2)
            sig = torch.exp(logsig)
            mix_coeff = torch.exp(logmix_coeff)
            z = (mix_coeff * mu).sum(1)
            zerr = torch.sqrt(
                (mix_coeff * sig ** 2).sum(1)
                + (mix_coeff * (mu - target_data[:, None]) ** 2).sum(1))
        return self._to_numpy(z), self._to_numpy(zerr)

    def plot_photoz(self, df, nbins, xvariable, metric, type_bin='bin'):
        """Plot a photo-z quality metric binned in ``xvariable``.

        Parameters
        ----------
        df : DataFrame with at least columns ``imag``, ``zwerr`` and ``xvariable``.
        nbins : number of quantile bin edges along ``xvariable``.
        metric : one of 'sig68', 'bias', 'nmad', 'outliers'.
        type_bin : 'bin' for differential bins, 'cum' for cumulative bins.
        """
        # Lazy import: plotting is optional and matplotlib may be absent in
        # headless/compute environments.
        import matplotlib.pyplot as plt

        bin_edges = stats.mstats.mquantiles(df[xvariable].values,
                                            np.linspace(0.1, 1, nbins))
        ydata, xlab = [], []
        for k in range(len(bin_edges) - 1):
            edge_min = bin_edges[k]
            edge_max = bin_edges[k + 1]
            mean_mag = (edge_max + edge_min) / 2

            # BUG FIX: the original filtered an undefined global `df_test`
            # instead of the `df` argument.
            if type_bin == 'bin':
                df_plot = df[(df.imag > edge_min) & (df.imag < edge_max)]
            elif type_bin == 'cum':
                df_plot = df[(df.imag < edge_max)]
            else:
                raise ValueError("Only type_bin=='bin' for binned and 'cum' for cumulative are supported")

            xlab.append(mean_mag)
            if metric == 'sig68':
                # NOTE(review): sigma68/nmad are assumed to be project helpers
                # defined elsewhere — they are not visible in this file.
                ydata.append(sigma68(df_plot.zwerr))
            elif metric == 'bias':
                ydata.append(np.mean(df_plot.zwerr))
            elif metric == 'nmad':
                ydata.append(nmad(df_plot.zwerr))
            elif metric == 'outliers':
                ydata.append(len(df_plot[np.abs(df_plot.zwerr) > 0.15]) / len(df_plot))

        plt.plot(xlab, ydata, ls='-', marker='.', color='navy', lw=1, label='')
        # Raw f-string so '\D' is not treated as an (invalid) escape sequence;
        # the rendered label text is unchanged.
        plt.ylabel(rf'{metric}$[\Delta z]$', fontsize=18)
        plt.xlabel(f'{xvariable}', fontsize=16)
        plt.xticks(fontsize=14)
        plt.yticks(fontsize=14)
        plt.grid(False)
        plt.show()