TEMPS / insight /utils.py
lauracabayol's picture
optimized version working at low z
696a020
raw
history blame
2.39 kB
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import stats
def nmad(data):
return 1.4826 * np.median(np.abs(data - np.median(data)))
def sigma68(data): return 0.5*(pd.Series(data).quantile(q = 0.84) - pd.Series(data).quantile(q = 0.16))
def plot_photoz_estimates(df, nbins,xvariable,metric, type_bin='bin'):
bin_edges = stats.mstats.mquantiles(df[xvariable].values, np.linspace(0.1,1,nbins))
ydata,xdata = [],[]
for k in range(len(bin_edges)-1):
edge_min = bin_edges[k]
edge_max = bin_edges[k+1]
mean_mag = (edge_max + edge_min) / 2
if type_bin=='bin':
df_plot = df[(df[xvariable] > edge_min) & (df[xvariable] < edge_max)]
elif type_bin=='cum':
df_plot = df[(df[xvariable] < edge_max)]
else:
raise ValueError("Only type_bin=='bin' for binned and 'cum' for cumulative are supported")
xdata.append(mean_mag)
if metric=='sig68':
ydata.append(sigma68(df_plot.zwerr))
ylab=r'$\sigma_{\rm NMAD} [\Delta z]$'
elif metric=='bias':
ydata.append(np.median(df_plot.zwerr))
ylab=r'Median $[\Delta z]$'
elif metric=='nmad':
ydata.append(nmad(df_plot.zwerr))
ylab=r'$\sigma_{\rm NMAD} [\Delta z]$'
elif metric=='outliers':
ydata.append(len(df_plot[np.abs(df_plot.zwerr)>0.15])/len(df_plot) *100)
ylab=r'$\eta$ [%]'
if xvariable=='VISmag':
xlab='VIS'
elif xvariable=='zs':
xlab=r'$z_{\rm spec}$'
elif xvariable=='z':
xlab=r'$z$'
plt.plot(xdata,ydata, ls = '-', marker = '.', color = 'navy',lw = 1, label = '')
plt.ylabel(f'{ylab}', fontsize = 18)
plt.xlabel(f'{xlab}', fontsize = 16)
plt.xticks(fontsize = 14)
plt.yticks(fontsize = 14)
plt.grid(False)
plt.show()
return
def plot_nz(df, bins=np.arange(0,5,0.2)):
kwargs=dict( bins=bins,alpha=0.5)
plt.hist(df.zs.values, color='grey', ls='-' ,**kwargs)
counts, _, =np.histogram(df.z.values, bins=bins)
plt.plot((bins[:-1]+bins[1:])*0.5,counts, color ='purple')
#plt.legend(fontsize=14)
plt.xlabel(r'Redshift', fontsize=14)
plt.ylabel(r'Counts', fontsize=14)
plt.yscale('log')
plt.show()
return