# TEMPS / notebooks / nz.py
# Author: lauracabayol — commit b25063d ("new notebook versions"), 7.15 kB
# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#       jupytext_version: 1.16.2
#   kernelspec:
#     display_name: temps
#     language: python
#     name: temps
# ---
# # FIGURE 5 IN THE PAPER
# ## n(z) distributions
# %load_ext autoreload
# %autoreload 2
import pandas as pd
import numpy as np
from astropy.io import fits
from astropy.table import Table
import torch
from pathlib import Path
#matplotlib settings
from matplotlib import rcParams
import matplotlib.pyplot as plt
# Use STIX fonts for both math text and regular text so labels match.
rcParams.update({
    "mathtext.fontset": "stix",
    "font.family": "STIXGeneral",
})
from temps.archive import Archive
from temps.utils import nmad
from temps.temps_arch import EncoderPhotometry, MeasureZ
from temps.temps import TempsModule
from temps.plots import plot_nz
# Switch for the two branches below: True re-evaluates the trained networks
# stored in `modules_dir`; False loads the pre-computed prediction CSVs.
eval_methods = False
# ### LOAD DATA
# Directory containing the photometric catalogues (site-specific absolute
# path; point it at a local copy when running elsewhere).
parent_dir = Path('/data/astro/scratch/lcabayol/insight/data') / 'Euclid_EXT_MER_PHZ_DC2_v1.5'
# Trained model weights, relative to the notebook's working directory.
modules_dir = Path('..', 'data', 'models')
# +
filename_valid = 'euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
# Read the validation catalogue inside a context manager so the FITS handle is
# released. astropy's Table() copies the HDU data by default (copy=True), so
# the resulting DataFrame stays valid after the file is closed.
with fits.open(parent_dir / filename_valid) as hdu_list:
    cat = Table(hdu_list[1].data).to_pandas()
# Selection cuts, applied as one boolean mask: FLAG_PHOT == 0,
# mu_class_L07 == 1, at least one positive redshift (S15 spec-z or L15
# photo-z), and MAG_VIS < 25.
cat = cat[
    (cat['FLAG_PHOT'] == 0)
    & (cat['mu_class_L07'] == 1)
    & ((cat['z_spec_S15'] > 0) | (cat['photo_z_L15'] > 0))
    & (cat['MAG_VIS'] < 25)
]
# -
# Target redshift per galaxy: the S15 spec-z where it is > 0, otherwise the
# L15 photo-z. Vectorised with np.where (NaN spec-z compares False, so it
# falls back to the photo-z).
zspec_s15 = cat['z_spec_S15'].to_numpy()
has_specz = zspec_s15 > 0
ztarget = np.where(has_specz, zspec_s15, cat['photo_z_L15'].to_numpy())
# Provenance of ztarget: 0 = S15 spec-z, 1 = L15 photo-z.
specz_or_photo = np.where(has_specz, 0, 1)
ID = cat['ID']
VISmag = cat['MAG_VIS']
zsflag = cat['reliable_S15']
# Fluxes and colours are produced by the project's Archive helper.
# NOTE(review): `_extract_fluxes` / `_to_colors` are private Archive methods.
photoz_archive = Archive(path=parent_dir, only_zspec=False)
f, ferr = photoz_archive._extract_fluxes(catalogue=cat)
col, colerr = photoz_archive._to_colors(f, ferr)
# ### LOAD TRAINED MODELS AND EVALUATE PDFs AND REDSHIFT
if eval_methods:
    dfs = {}
    for lab in ['z', 'L15', 'DA']:
        # Each label selects a pair of weight files in modules_dir:
        # modelF_<lab>.pt for the photometry encoder and modelZ_<lab>.pt for
        # the redshift network; both are mapped onto the CPU.
        nn_features = EncoderPhotometry()
        nn_features.load_state_dict(torch.load(modules_dir / f'modelF_{lab}.pt', map_location=torch.device('cpu')))
        nn_z = MeasureZ(num_gauss=6)
        nn_z.load_state_dict(torch.load(modules_dir / f'modelZ_{lab}.pt', map_location=torch.device('cpu')))
        temps_module = TempsModule(nn_features, nn_z)
        z, pz, odds = temps_module.get_pz(input_data=torch.Tensor(col),
                                          return_pz=True)
        # Build the table column by column: np.c_ coerces every column to one
        # common dtype (float64 here), turning integer IDs/flags into floats
        # and losing precision for IDs above 2**53. Float columns are kept as
        # float64 so the residual below is computed at the same precision.
        df = pd.DataFrame({
            'ID': np.asarray(ID),
            'VISmag': np.asarray(VISmag, dtype=np.float64),
            'z': np.asarray(z, dtype=np.float64).ravel(),
            'odds': np.asarray(odds, dtype=np.float64).ravel(),
            'ztarget': np.asarray(ztarget, dtype=np.float64),
            'zsflag': np.asarray(zsflag),
            'S15_L15_flag': np.asarray(specz_or_photo),
        })
        # Normalised residual (z - ztarget) / (1 + ztarget).
        df['zwerr'] = (df.z - df.ztarget) / (1 + df.ztarget)
        # Drop any rows containing NaN values.
        df = df.dropna()
        dfs[lab] = df
# ### LOAD CATALOGUES IF AVAILABLE
if not eval_methods:
    # Pre-computed predictions, one CSV per training configuration, keyed by
    # the same labels ('z', 'L15', 'DA') used when evaluating the models.
    df_zs = pd.read_csv(parent_dir / 'predictions_specztraining.csv', header=0)
    df_zsL15 = pd.read_csv(parent_dir / 'predictions_speczL15training.csv', header=0)
    df_DA = pd.read_csv(parent_dir / 'predictions_speczDAtraining.csv', header=0)
    dfs = {'z': df_zs, 'L15': df_zsL15, 'DA': df_DA}
# +
import matplotlib.pyplot as plt
from matplotlib import gridspec

# Redshift bins shown in every panel. Alternative binnings tried earlier:
#z_ranges = [[0.15, 0.35], [0.35, 0.55], [0.55, 0.85], [0.85, 1.05], [1.05, 1.35],
#            [1.35, 1.55], [1.55, 1.85], [1.85, 2], [2, 2.5], [2.5, 3], [3, 4]]
z_ranges = [[0.15, 0.5], [0.5, 1], [1, 1.5], [1.5, 2]]
colors = ['deepskyblue', 'forestgreen', 'coral', 'grey', 'pink', 'goldenrod',
          'cyan', 'seagreen', 'salmon', 'steelblue', 'orange']


def _hist_binned_nz(ax, df, select_col, ranges, palette):
    """Overlay on ``ax`` one normalised histogram of ``df.ztarget`` per bin.

    Rows are assigned to bin ``i`` = ``[start, end)`` of ``ranges`` according
    to ``df[select_col]`` and drawn in ``palette[i]``. Half-open bins keep
    objects lying exactly on a shared edge (strict ``>``/``<`` on both sides
    would drop them from every bin).
    """
    for i, (start, end) in enumerate(ranges):
        in_bin = df[(df[select_col] >= start) & (df[select_col] < end)]
        ax.hist(in_bin.ztarget, bins=50, color=palette[i], histtype='step',
                linestyle='-', density=True, range=(0, 4))


# Figure: a thin strip showing the bins, then four equally tall panels.
fig = plt.figure(figsize=(8, 10))
gs = gridspec.GridSpec(5, 1, height_ratios=[0.1, 1, 1, 1, 1])

# Top strip: shaded band for each redshift bin.
ax1 = plt.subplot(gs[0])
ax1.set_yticks([])
ax1.set_ylabel('Bins', fontsize=10)
x_values = [0, 1, 2]  # vertical extent of the bands drawn by fill_betweenx
for i, (start, end) in enumerate(z_ranges):
    ax1.fill_betweenx(x_values, start, end, color=colors[i], alpha=0.5)

# Panel 2: galaxies binned on their target redshift ('z' catalogue).
ax2 = plt.subplot(gs[1])
_hist_binned_nz(ax2, dfs['z'], 'ztarget', z_ranges, colors)

# Panels 3-5: galaxies binned on each model's predicted z; the histogram is
# still of the target redshift.
ax3 = plt.subplot(gs[2])
_hist_binned_nz(ax3, dfs['z'], 'z', z_ranges, colors)
ax4 = plt.subplot(gs[3])
_hist_binned_nz(ax4, dfs['L15'], 'z', z_ranges, colors)
ax5 = plt.subplot(gs[4])
_hist_binned_nz(ax5, dfs['DA'], 'z', z_ranges, colors)

plt.tight_layout()
plt.show()
# -
def plot_nz(df_list,
            zcuts=(0.1, 0.5, 1, 1.5, 2, 3, 4),
            save=False,
            outfile='nz_hist.pdf'):
    """Plot predicted-z histograms in target-redshift bins, one panel per catalogue.

    For every DataFrame in ``df_list`` a panel is drawn. Rows are split into
    half-open bins ``[zcuts[k], zcuts[k+1])`` of ``ztarget``; for each
    non-empty bin the normalised histogram of ``z`` is overlaid, together with
    a solid vertical line at the median ``ztarget`` and a dashed one at the
    median ``z``.

    This local definition shadows the ``plot_nz`` imported from
    ``temps.plots`` at the top of the notebook.

    Parameters
    ----------
    df_list : iterable of pandas.DataFrame
        Catalogues with ``ztarget`` and ``z`` columns.
    zcuts : sequence of float
        Increasing bin edges in ``ztarget``.
    save : bool
        If True, save the figure to ``outfile`` before showing it.
    outfile : str or path-like
        Output path used when ``save`` is True.

    Side effects: sets ``plt.rcParams['font.family']`` and
    ``plt.rcParams['font.size']`` globally.
    """
    df_list = list(df_list)
    plt.rcParams['font.family'] = 'serif'
    plt.rcParams['font.size'] = 16
    cmap = plt.get_cmap('Dark2')  # one colour per redshift bin

    # One panel per catalogue (previously hard-coded to 3). squeeze=False keeps
    # ``axs`` 2-D even for a single panel, so indexing is uniform.
    n_panels = max(len(df_list), 1)
    fig, axs = plt.subplots(n_panels, 1, figsize=(20, 8), sharex=True, squeeze=False)
    axs = axs[:, 0]

    for ax, df in zip(axs, df_list):
        for iz, (zlo, zhi) in enumerate(zip(zcuts, zcuts[1:])):
            in_bin = df[(df['ztarget'] >= zlo) & (df['ztarget'] < zhi)]
            if in_bin.empty:
                # np.median of an empty bin is NaN (with a RuntimeWarning).
                continue
            color = cmap(iz)
            zt_median = np.median(in_bin.ztarget.values)
            zp_median = np.median(in_bin.z.values)
            ax.hist(in_bin.z, bins=50, color=color, histtype='step',
                    linestyle='-', density=True, range=(0, 4))
            ax.axvline(zt_median, color=color, linestyle='-', lw=2)
            ax.axvline(zp_median, color=color, linestyle='--', lw=2)
        ax.set_ylabel('Frequency', fontsize=14)
        ax.grid(False)
        ax.set_xlim(0, 3.5)

    axs[-1].set_xlabel('$z$', fontsize=18)
    if save:
        plt.savefig(outfile, dpi=300, bbox_inches='tight')
    plt.show()
# `df_list` was never defined (NameError); plot the three catalogues in the
# same order as `dfs` is built in both branches above: 'z', 'L15', 'DA'.
plot_nz([dfs['z'], dfs['L15'], dfs['DA']])