TEMPS / insight /insight.py
lauracabayol's picture
optimized version working at low z
696a020
raw
history blame
5.89 kB
import torch
from torch.utils.data import DataLoader, dataset, TensorDataset
from torch import nn, optim
from torch.optim import lr_scheduler
import numpy as np
import pandas as pd
from astropy.io import fits
import os
from astropy.table import Table
from scipy.spatial import KDTree
from scipy.special import erf
class Insight_module():
    """Mixture-density-network photo-z estimator.

    Wraps a PyTorch module that, given photometric features, predicts the
    parameters of a Gaussian mixture over redshift — per-component means,
    log standard deviations and log mixture weights — and provides
    training, inference and diagnostic-plot helpers.
    """

    def __init__(self, model, batch_size):
        """Store the network and batch size; select GPU when available.

        Args:
            model: torch.nn.Module whose forward returns
                ``(mu, logsig, logmix_coeff)``, each of shape (batch, n_mix).
            batch_size: mini-batch size for the training DataLoader.
        """
        self.model = model
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.batch_size = batch_size

    def _get_dataloaders(self, input_data, target_data, val_fraction=0.1):
        """Split (features, targets) into train/validation DataLoaders.

        Args:
            input_data: array-like features, shape (N, n_features).
            target_data: array-like targets, shape (N,).
            val_fraction: fraction of samples held out for validation.

        Returns:
            ``(loader_train, loader_val)`` DataLoaders.
        """
        input_data = torch.Tensor(input_data)
        target_data = torch.Tensor(target_data)
        dataset = TensorDataset(input_data, target_data)
        # BUG FIX: the original used [int(n*(1-f)), int(n*f)+1], which in
        # general does not sum to len(dataset) (e.g. n=100, f=0.1 -> 101)
        # and makes random_split raise. Derive the train size from the
        # validation size so the two always sum exactly to len(dataset).
        n_val = int(len(dataset) * val_fraction)
        n_train = len(dataset) - n_val
        training_dataset, val_dataset = torch.utils.data.random_split(
            dataset, [n_train, n_val])
        loader_train = DataLoader(training_dataset, batch_size=self.batch_size, shuffle=True)
        loader_val = DataLoader(val_dataset, batch_size=64, shuffle=True)
        return loader_train, loader_val

    def _loss_function(self, mean, std, logmix, true):
        """Mean negative log-likelihood of a Gaussian mixture.

        Args:
            mean: (batch, n_mix) component means.
            std: (batch, n_mix) component standard deviations (positive).
            logmix: (batch, n_mix) log mixture weights.
            true: (batch,) observed target values.

        Returns:
            Scalar tensor: mean NLL over the batch (up to an additive
            constant — the 0.5*log(2*pi) term is omitted, as in the
            original, which does not affect optimization).
        """
        # Per-component log density: log w_k - (x-mu_k)^2/(2 sig_k^2) - log sig_k
        log_prob = logmix - 0.5 * (mean - true[:, None]).pow(2) / std.pow(2) - torch.log(std)
        # NOTE(review): the original computed a truncated-Gaussian log-erf
        # correction whose use was commented out; it was dead code doing a
        # CPU round-trip every batch and has been removed.
        log_prob = torch.logsumexp(log_prob, 1)
        return -log_prob.mean()

    def _to_numpy(self, x):
        """Detach a tensor, move it to CPU and convert to a numpy array."""
        return x.detach().cpu().numpy()

    def train(self, input_data, target_data, nepochs=10, step_size=100, val_fraction=0.1, lr=1e-3):
        """Train the mixture-density network.

        Args:
            input_data: features, shape (N, n_features).
            target_data: targets, shape (N,).
            nepochs: number of passes over the training set.
            step_size: epochs between learning-rate decays (gamma=0.1).
            val_fraction: fraction of samples held out for validation.
            lr: initial Adam learning rate.

        Side effects:
            Fills ``self.loss_train`` / ``self.loss_validation`` with the
            per-epoch mean training / validation losses.
        """
        # BUG FIX: the original passed the hard-coded 0.1 instead of the
        # val_fraction argument.
        loader_train, loader_val = self._get_dataloaders(
            input_data, target_data, val_fraction=val_fraction)
        optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-4)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)
        self.model = self.model.to(self.device)
        self.loss_train, self.loss_validation = [], []
        for epoch in range(nepochs):
            # BUG FIX: these accumulators were re-initialised inside the
            # batch loop, so the recorded epoch loss was only the last
            # batch's value.
            _loss_train, _loss_validation = [], []
            self.model.train()
            for features, targets in loader_train:
                features = features.to(self.device)
                targets = targets.to(self.device)
                optimizer.zero_grad()
                mu, logsig, logmix_coeff = self.model(features)
                # Clamp log-sigma before exponentiating for numerical stability.
                logsig = torch.clamp(logsig, -6, 2)
                sig = torch.exp(logsig)
                loss = self._loss_function(mu, sig, logmix_coeff, targets)
                _loss_train.append(loss.item())
                loss.backward()
                optimizer.step()
            # StepLR counts epochs; stepping once per epoch matches the
            # step_size=100 default. (The original's placement was ambiguous
            # in the mangled source.)
            scheduler.step()
            self.loss_train.append(np.mean(_loss_train))
            # Validation pass: no gradients needed, and eval() disables
            # dropout/batch-norm updates if the model uses them.
            self.model.eval()
            with torch.no_grad():
                for features, targets in loader_val:
                    features = features.to(self.device)
                    targets = targets.to(self.device)
                    mu, logsig, logmix_coeff = self.model(features)
                    logsig = torch.clamp(logsig, -6, 2)
                    sig = torch.exp(logsig)
                    loss_val = self._loss_function(mu, sig, logmix_coeff, targets)
                    _loss_validation.append(loss_val.item())
            self.loss_validation.append(np.mean(_loss_validation))
        # Leave the model in train mode, as the original did.
        self.model.train()

    def get_photoz(self, input_data, target_data):
        """Predict point-estimate photo-z and an associated uncertainty.

        Args:
            input_data: torch.Tensor of features, shape (N, n_features).
            target_data: torch.Tensor of reference redshifts, shape (N,).

        Returns:
            ``(z, zerr)`` numpy arrays of shape (N,): the mixture-mean
            prediction and its spread.
        """
        self.model = self.model.eval()
        self.model = self.model.to(self.device)
        input_data = input_data.to(self.device)
        target_data = target_data.to(self.device)
        # BUG FIX: the original looped len(input_data) times while pushing
        # the *entire* batch through the network on every iteration; a
        # single forward pass gives the identical result. Gradients are
        # not needed at inference time.
        with torch.no_grad():
            mu, logsig, logmix_coeff = self.model(input_data)
            logsig = torch.clamp(logsig, -6, 2)
            sig = torch.exp(logsig)
            mix_coeff = torch.exp(logmix_coeff)
            # Point estimate: mixture mean.
            z = (mix_coeff * mu).sum(1)
            # Spread: within-component variance plus weighted scatter of the
            # component means about the reference redshift (as in the original).
            zerr = torch.sqrt((mix_coeff * sig ** 2).sum(1)
                              + (mix_coeff * (mu - target_data[:, None]) ** 2).sum(1))
        return self._to_numpy(z), self._to_numpy(zerr)

    def plot_photoz(self, df, nbins, xvariable, metric, type_bin='bin'):
        """Plot a photo-z quality metric in quantile bins of *xvariable*.

        Args:
            df: DataFrame with columns ``imag``, ``zwerr`` and *xvariable*.
            nbins: number of quantile edges used for binning.
            xvariable: column used on the x axis.
            metric: one of 'sig68', 'bias', 'nmad', 'outliers'.
            type_bin: 'bin' for differential bins, 'cum' for cumulative.

        Raises:
            ValueError: if *type_bin* is not 'bin' or 'cum'.
        """
        # These were referenced but never imported anywhere in the module;
        # import locally so the rest of the class works without matplotlib.
        import matplotlib.pyplot as plt
        from scipy import stats

        bin_edges = stats.mstats.mquantiles(df[xvariable].values, np.linspace(0.1, 1, nbins))
        ydata, xlab = [], []
        for k in range(len(bin_edges) - 1):
            edge_min = bin_edges[k]
            edge_max = bin_edges[k + 1]
            # BUG FIX: the original filtered an undefined global `df_test`;
            # bin on the DataFrame that was actually passed in.
            if type_bin == 'bin':
                df_plot = df[(df.imag > edge_min) & (df.imag < edge_max)]
            elif type_bin == 'cum':
                df_plot = df[(df.imag < edge_max)]
            else:
                raise ValueError("Only type_bin=='bin' for binned and 'cum' for cumulative are supported")
            xlab.append((edge_max + edge_min) / 2)
            # NOTE(review): sigma68 / nmad are presumably helper functions
            # defined elsewhere in the project — confirm before use.
            if metric == 'sig68':
                ydata.append(sigma68(df_plot.zwerr))
            elif metric == 'bias':
                ydata.append(np.mean(df_plot.zwerr))
            elif metric == 'nmad':
                ydata.append(nmad(df_plot.zwerr))
            elif metric == 'outliers':
                ydata.append(len(df_plot[np.abs(df_plot.zwerr) > 0.15]) / len(df_plot))
        plt.plot(xlab, ydata, ls='-', marker='.', color='navy', lw=1, label='')
        plt.ylabel(f'{metric}$[\\Delta z]$', fontsize=18)
        plt.xlabel(f'{xvariable}', fontsize=16)
        plt.xticks(fontsize=14)
        plt.yticks(fontsize=14)
        plt.grid(False)
        plt.show()