# TEMPS / notebooks / nz.py
# Last commit 668e440 by Laura Cabayol Garcia: "running precommit"
# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#       jupytext_version: 1.16.2
#   kernelspec:
#     display_name: temps
#     language: python
#     name: temps
# ---
# # FIGURE 5 IN THE PAPER
# ## n(z) distributions
# %load_ext autoreload
# %autoreload 2
import pandas as pd
import numpy as np
from astropy.io import fits
from astropy.table import Table
import torch
from pathlib import Path
# matplotlib settings: STIX fonts for both regular text and mathtext.
import matplotlib.pyplot as plt
from matplotlib import rcParams

rcParams.update({"mathtext.fontset": "stix", "font.family": "STIXGeneral"})
from temps.archive import Archive
from temps.utils import nmad
from temps.temps_arch import EncoderPhotometry, MeasureZ
from temps.temps import TempsModule
from temps.plots import plot_nz
# If True, run the trained models on the catalogue below; otherwise the
# precomputed prediction CSVs are read instead.
eval_methods = False
# ### LOAD DATA
# define here the directory containing the photometric catalogues
parent_dir = Path(
    "/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5"
)
modules_dir = Path("../data/models/")
# +
filename_valid = "euclid_cosmos_DC2_S1_v2.1_valid_matched.fits"
# Open the FITS file in a context manager so the handle is released once the
# table has been read. ``Table(...)`` copies the HDU data (astropy's default
# ``copy=True``), so ``cat`` remains valid after the file is closed.
with fits.open(parent_dir / filename_valid) as hdu_list:
    cat = Table(hdu_list[1].data).to_pandas()
# Row selection as one combined mask (selects the same rows as applying each
# cut in turn): FLAG_PHOT == 0, mu_class_L07 == 1, a positive S15 spec-z or
# L15 photo-z, and MAG_VIS < 25.
selection = (
    (cat["FLAG_PHOT"] == 0)
    & (cat["mu_class_L07"] == 1)
    & ((cat["z_spec_S15"] > 0) | (cat["photo_z_L15"] > 0))
    & (cat["MAG_VIS"] < 25)
)
cat = cat[selection]
# -
# Redshift target per object: the S15 spectroscopic redshift where it is
# positive, otherwise the L15 photometric redshift. Vectorised with np.where
# instead of indexing the pandas columns row by row.
zspec_s15 = cat["z_spec_S15"].to_numpy()
has_specz = zspec_s15 > 0
ztarget = np.where(has_specz, zspec_s15, cat["photo_z_L15"].to_numpy())
# Provenance of ztarget: 0 -> S15 spec-z, 1 -> L15 photo-z.
specz_or_photo = np.where(has_specz, 0, 1)
ID = cat["ID"]
VISmag = cat["MAG_VIS"]
zsflag = cat["reliable_S15"]
# Fluxes and colours come from the project's Archive helpers (private methods;
# their output layout is defined in temps.archive, not visible here).
photoz_archive = Archive(path=parent_dir, only_zspec=False)
f, ferr = photoz_archive._extract_fluxes(catalogue=cat)
col, colerr = photoz_archive._to_colors(f, ferr)
# ### LOAD TRAINED MODELS AND EVALUATE PDFs AND REDSHIFT
if eval_methods:
    # Every checkpoint is mapped onto the CPU, whatever device it was saved on.
    cpu = torch.device("cpu")
    dfs = {}
    # One model pair per label. The labels are the same keys used for the
    # precomputed CSV predictions (presumably the training-set variant --
    # confirm against the training scripts).
    for lab in ["z", "L15", "DA"]:
        # Photometry encoder (modelF_*) and redshift estimator (modelZ_*).
        nn_features = EncoderPhotometry()
        nn_features.load_state_dict(
            torch.load(modules_dir / f"modelF_{lab}.pt", map_location=cpu)
        )
        nn_z = MeasureZ(num_gauss=6)
        nn_z.load_state_dict(
            torch.load(modules_dir / f"modelZ_{lab}.pt", map_location=cpu)
        )
        temps_module = TempsModule(nn_features, nn_z)
        z, pz, odds = temps_module.get_pz(
            input_data=torch.Tensor(col), return_pz=True
        )
        # Create a DataFrame with the desired columns
        df = pd.DataFrame(
            np.c_[ID, VISmag, z, odds, ztarget, zsflag, specz_or_photo],
            columns=["ID", "VISmag", "z", "odds", "ztarget", "zsflag", "S15_L15_flag"],
        )
        # Normalised residual (z - ztarget) / (1 + ztarget).
        df["zwerr"] = (df.z - df.ztarget) / (1 + df.ztarget)
        # Drop any rows with NaN values
        df = df.dropna()
        dfs[lab] = df
# ### LOAD CATALOGUES IF AVAILABLE
if not eval_methods:
    # Precomputed predictions, one CSV per label, read in the order z, L15, DA.
    df_zs, df_zsL15, df_DA = (
        pd.read_csv(parent_dir / csv_name, header=0)
        for csv_name in (
            "predictions_specztraining.csv",
            "predictions_speczL15training.csv",
            "predictions_speczDAtraining.csv",
        )
    )
    dfs = dict(z=df_zs, L15=df_zsL15, DA=df_DA)
# +
import matplotlib.pyplot as plt
from matplotlib import gridspec


def _hist_ztarget_in_bins(ax, df, select_col, z_ranges, colors):
    """Overlay one normalised ``ztarget`` histogram per redshift bin on ``ax``.

    Rows go into the half-open bin ``[start, end)`` according to
    ``df[select_col]``. With half-open bins, an object lying exactly on an edge
    (e.g. 0.5 or 1.0) lands in exactly one bin. Strict inequalities on both
    sides would drop it from every bin. The plotted quantity is always
    ``df["ztarget"]``.
    """
    selector = df[select_col]
    for (start, end), color in zip(z_ranges, colors):
        in_bin = (selector >= start) & (selector < end)
        ax.hist(
            df.loc[in_bin, "ztarget"],
            bins=50,
            color=color,
            histtype="step",
            linestyle="-",
            density=True,
            range=(0, 4),
        )


# Create figure and grid specification: a thin strip on top plus four panels.
fig = plt.figure(figsize=(8, 10))
gs = gridspec.GridSpec(5, 1, height_ratios=[0.1, 1, 1, 1, 1])
# Redshift bins. A finer alternative binning:
# [[0.15, 0.35], [0.35, 0.55], [0.55, 0.85], [0.85, 1.05], [1.05, 1.35],
#  [1.35, 1.55], [1.55, 1.85], [1.85, 2], [2, 2.5], [2.5, 3], [3, 4]]
z_ranges = [[0.15, 0.5], [0.5, 1], [1, 1.5], [1.5, 2]]
# Bin colours; only the first len(z_ranges) entries are used.
colors = [
    "deepskyblue",
    "forestgreen",
    "coral",
    "grey",
    "pink",
    "goldenrod",
    "cyan",
    "seagreen",
    "salmon",
    "steelblue",
    "orange",
]
# Top strip: shaded extent of each bin.
ax1 = plt.subplot(gs[0])
ax1.set_yticks([])
ax1.set_ylabel("Bins", fontsize=10)
x_values = [0, 1, 2]  # vertical extent of the shading; y ticks are hidden
for (start, end), color in zip(z_ranges, colors):
    ax1.fill_betweenx(x_values, start, end, color=color, alpha=0.5)
# The histogram panels share x with the strip so the shaded bins line up with
# the histograms. Without sharing, the strip autoscales to [0.15, 2] while the
# histograms span [0, 4].
# Panel 2: true n(z), objects binned by ztarget itself.
ax2 = plt.subplot(gs[1], sharex=ax1)
_hist_ztarget_in_bins(ax2, dfs["z"], "ztarget", z_ranges, colors)
# Panels 3-5: ztarget distributions of objects binned by the predicted z in
# the "z", "L15" and "DA" prediction catalogues respectively.
ax3 = plt.subplot(gs[2], sharex=ax1)
_hist_ztarget_in_bins(ax3, dfs["z"], "z", z_ranges, colors)
ax4 = plt.subplot(gs[3], sharex=ax1)
_hist_ztarget_in_bins(ax4, dfs["L15"], "z", z_ranges, colors)
ax5 = plt.subplot(gs[4], sharex=ax1)
_hist_ztarget_in_bins(ax5, dfs["DA"], "z", z_ranges, colors)
plt.tight_layout()
plt.show()
# -
def _zbin_mask(values, lo, hi):
    """Return a boolean mask selecting ``lo <= values < hi``.

    With half-open bins, an object lying exactly on an interior edge falls into
    exactly one bin. Strict inequalities on both sides would drop it.
    """
    return (values >= lo) & (values < hi)


def plot_nz(df_list, zcuts=(0.1, 0.5, 1, 1.5, 2, 3, 4), save=False):
    """Plot predicted-redshift histograms in bins of true redshift.

    Draws one panel per DataFrame in ``df_list``. Within a panel, rows are
    split into bins ``[zcuts[k], zcuts[k + 1])`` of ``ztarget``. For each
    non-empty bin it draws the normalised histogram of ``z``, a solid
    vertical line at the bin's median ``ztarget`` and a dashed line at its
    median ``z``.

    NOTE: this notebook-local definition shadows ``temps.plots.plot_nz``
    imported at the top of the notebook.

    Parameters
    ----------
    df_list : sequence of pandas.DataFrame
        Each frame needs ``z`` and ``ztarget`` columns.
    zcuts : sequence of float
        Increasing bin edges in ``ztarget``.
    save : bool
        If True, also write the figure to ``nz_hist.pdf``.

    Side effects: sets ``plt.rcParams`` font family and size globally.
    """
    # Plot properties (global; override any earlier rcParams font settings).
    plt.rcParams["font.family"] = "serif"
    plt.rcParams["font.size"] = 16
    cmap = plt.get_cmap("Dark2")  # one colour per redshift bin

    # One panel per catalogue, with at least one so an empty list still gives
    # a figure. squeeze=False keeps ``axs`` 2-D even for a single row.
    nrows = max(len(df_list), 1)
    _, axs = plt.subplots(nrows, 1, figsize=(20, 8), sharex=True, squeeze=False)
    axs = axs[:, 0]
    for ax, dfplot in zip(axs, df_list):
        for iz, (zlo, zhi) in enumerate(zip(zcuts[:-1], zcuts[1:])):
            dfplot_z = dfplot[_zbin_mask(dfplot["ztarget"], zlo, zhi)]
            if dfplot_z.empty:
                # The median of an empty bin is NaN (with a warning); skip it.
                # iz still advances, so colours stay tied to the bin index.
                continue
            color = cmap(iz)
            zt_median = np.median(dfplot_z.ztarget.values)
            zp_median = np.median(dfplot_z.z.values)
            ax.hist(
                dfplot_z.z,
                bins=50,
                color=color,
                histtype="step",
                linestyle="-",
                density=True,
                range=(0, 4),
            )
            ax.axvline(zt_median, color=color, linestyle="-", lw=2)
            ax.axvline(zp_median, color=color, linestyle="--", lw=2)
        ax.set_ylabel("Frequency", fontsize=14)
        ax.grid(False)
        ax.set_xlim(0, 3.5)
    axs[-1].set_xlabel("$z$", fontsize=18)
    if save:
        plt.savefig("nz_hist.pdf", dpi=300, bbox_inches="tight")
    plt.show()
# Plot the three prediction catalogues held in ``dfs``. ``df_list`` was never
# defined in this notebook, so the original call raised NameError.
plot_nz([dfs["z"], dfs["L15"], dfs["DA"]])