Spaces:
Runtime error
Runtime error
# --- | |
# jupyter: | |
# jupytext: | |
# text_representation: | |
# extension: .py | |
# format_name: light | |
# format_version: '1.5' | |
# jupytext_version: 1.16.2 | |
# kernelspec: | |
# display_name: temps | |
# language: python | |
# name: temps | |
# --- | |
# # FIGURE 5 IN THE PAPER | |
# ## n(z) distributions | |
# %load_ext autoreload | |
# %autoreload 2 | |
import pandas as pd | |
import numpy as np | |
from astropy.io import fits | |
from astropy.table import Table | |
import torch | |
from pathlib import Path | |
# matplotlib settings | |
from matplotlib import rcParams | |
import matplotlib.pyplot as plt | |
rcParams["mathtext.fontset"] = "stix" | |
rcParams["font.family"] = "STIXGeneral" | |
from temps.archive import Archive | |
from temps.utils import nmad | |
from temps.temps_arch import EncoderPhotometry, MeasureZ | |
from temps.temps import TempsModule | |
from temps.plots import plot_nz | |
# Switch between re-running the trained networks (True) and reading
# previously saved prediction tables from disk (False).
eval_methods = False
# ### LOAD DATA
# define here the directory containing the photometric catalogues
parent_dir = Path(
    "/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5"
)
modules_dir = Path("../data/models/")
# +
filename_valid = "euclid_cosmos_DC2_S1_v2.1_valid_matched.fits"
# Open the FITS file in a context manager so the handle is released as soon
# as the table has been converted to a DataFrame.
with fits.open(parent_dir / filename_valid) as hdu_list:
    cat = Table(hdu_list[1].data).to_pandas()
# Apply all selection cuts with a single mask:
#   FLAG_PHOT == 0, mu_class_L07 == 1, MAG_VIS < 25, and at least one
#   positive redshift estimate (z_spec_S15 or photo_z_L15).
selection = (
    (cat["FLAG_PHOT"] == 0)
    & (cat["mu_class_L07"] == 1)
    & ((cat["z_spec_S15"] > 0) | (cat["photo_z_L15"] > 0))
    & (cat["MAG_VIS"] < 25)
)
cat = cat[selection]
# -
# Target redshift: use z_spec_S15 where it is positive, otherwise fall back
# to photo_z_L15. Computed in one vectorised pass instead of indexing
# `.values` once per row.
_zspec = cat["z_spec_S15"].values
_has_zspec = _zspec > 0
ztarget = np.where(_has_zspec, _zspec, cat["photo_z_L15"].values)
# Provenance flag of the target: 0 -> taken from z_spec_S15,
# 1 -> taken from photo_z_L15.
specz_or_photo = np.where(_has_zspec, 0, 1)
# Per-object columns that are carried through to the prediction tables.
ID, VISmag, zsflag = cat["ID"], cat["MAG_VIS"], cat["reliable_S15"]

# Archive helper used to turn the catalogue into network inputs.
photoz_archive = Archive(path=parent_dir, only_zspec=False)
# Two-step conversion: catalogue -> (f, ferr) -> (col, colerr).
# NOTE(review): presumably fluxes with errors, then colours with errors --
# confirm against temps.archive.Archive.
f, ferr = photoz_archive._extract_fluxes(catalogue=cat)
col, colerr = photoz_archive._to_colors(f, ferr)
# ### LOAD TRAINED MODELS AND EVALUATE PDFs AND REDSHIFT | |
if eval_methods:
    # One prediction table per trained model variant, keyed by its label.
    dfs = {}
    cpu = torch.device("cpu")
    output_columns = [
        "ID",
        "VISmag",
        "z",
        "odds",
        "ztarget",
        "zsflag",
        "S15_L15_flag",
    ]
    for lab in ["z", "L15", "DA"]:
        # Feature extractor: restore the weights saved as modelF_<lab>.pt.
        # NOTE(review): torch.load unpickles the file -- only point
        # modules_dir at checkpoints you trust.
        nn_features = EncoderPhotometry()
        features_state = torch.load(
            modules_dir / f"modelF_{lab}.pt", map_location=cpu
        )
        nn_features.load_state_dict(features_state)

        # Redshift head: restore the weights saved as modelZ_<lab>.pt.
        nn_z = MeasureZ(num_gauss=6)
        z_state = torch.load(modules_dir / f"modelZ_{lab}.pt", map_location=cpu)
        nn_z.load_state_dict(z_state)

        # Run both networks on the colours of the validation catalogue.
        temps_module = TempsModule(nn_features, nn_z)
        z, pz, odds = temps_module.get_pz(
            input_data=torch.Tensor(col), return_pz=True
        )

        # Stack predictions next to the catalogue columns, one row per object.
        stacked = np.c_[ID, VISmag, z, odds, ztarget, zsflag, specz_or_photo]
        df = pd.DataFrame(stacked, columns=output_columns)
        # Redshift residual scaled by (1 + ztarget).
        df["zwerr"] = (df["z"] - df["ztarget"]) / (1 + df["ztarget"])
        # Discard every row that contains a NaN before storing the table.
        df = df.dropna()
        dfs[lab] = df
# ### LOAD CATALOGUES IF AVAILABLE | |
if not eval_methods:
    # Read the prediction tables saved by an earlier evaluation run; the
    # first row of each CSV is the header.
    df_zs = pd.read_csv(parent_dir / "predictions_specztraining.csv", header=0)
    df_zsL15 = pd.read_csv(
        parent_dir / "predictions_speczL15training.csv", header=0
    )
    df_DA = pd.read_csv(parent_dir / "predictions_speczDAtraining.csv", header=0)
    # Same keys as the ones produced when the models are evaluated directly.
    dfs = {"z": df_zs, "L15": df_zsL15, "DA": df_DA}
# +
import matplotlib.pyplot as plt
from matplotlib import gridspec

# Redshift bins drawn in every panel, and the colour assigned to each bin
# (only the first len(z_ranges) colours are used).
z_ranges = [[0.15, 0.5], [0.5, 1], [1, 1.5], [1.5, 2]]
colors = [
    "deepskyblue",
    "forestgreen",
    "coral",
    "grey",
    "pink",
    "goldenrod",
    "cyan",
    "seagreen",
    "salmon",
    "steelblue",
    "orange",
]

# Create figure and grid specification: one very thin strip on top,
# followed by four equally tall histogram panels.
fig = plt.figure(figsize=(8, 10))
gs = gridspec.GridSpec(5, 1, height_ratios=[0.1, 1, 1, 1, 1])

# Upper panel (very thin): shaded bands showing the extent of each bin.
ax1 = plt.subplot(gs[0])
ax1.set_yticks([])
ax1.set_ylabel("Bins", fontsize=10)
x_values = [0, 1, 2]  # vertical extent of the bands; the y ticks are hidden
for (start, end), color in zip(z_ranges, colors):
    ax1.fill_betweenx(x_values, start, end, color=color, alpha=0.5)

# Histogram panels, top to bottom. Each entry is
# (key into dfs, column used to assign objects to a bin); every panel
# shows the distribution of `ztarget` for the objects selected that way.
panels = [
    ("z", "ztarget"),  # bins defined on the target redshift itself
    ("z", "z"),  # bins defined on the predicted z of dfs["z"]
    ("L15", "z"),  # bins defined on the predicted z of dfs["L15"]
    ("DA", "z"),  # bins defined on the predicted z of dfs["DA"]
]
panel_axes = []
for row, (label, bin_column) in enumerate(panels, start=1):
    ax = plt.subplot(gs[row])
    catalogue = dfs[label]
    for (start, end), color in zip(z_ranges, colors):
        # Strict inequalities: objects exactly on a bin edge are excluded.
        in_bin = (catalogue[bin_column] > start) & (catalogue[bin_column] < end)
        dfplot_z = catalogue[in_bin]
        ax.hist(
            dfplot_z.ztarget,
            bins=50,
            color=color,
            histtype="step",
            linestyle="-",
            density=True,
            range=(0, 4),
        )
    panel_axes.append(ax)
# Keep the individual axis names available for later tweaking.
ax2, ax3, ax4, ax5 = panel_axes

plt.tight_layout()
plt.show()
# -
def plot_nz(df_list, zcuts=(0.1, 0.5, 1, 1.5, 2, 3, 4), save=False):
    """Plot predicted-redshift histograms in bins of target redshift.

    One panel is drawn per DataFrame in ``df_list``. Within a panel, the
    objects are split into bins of ``ztarget`` delimited by consecutive
    values of ``zcuts``; for every bin the normalised histogram of the
    predicted ``z`` is drawn, together with a solid vertical line at the
    median ``ztarget`` and a dashed vertical line at the median ``z`` of
    that bin.

    Note: this definition replaces the ``plot_nz`` name imported from
    ``temps.plots`` at the top of the file.

    Parameters
    ----------
    df_list : sequence of pandas.DataFrame
        Each frame must provide the columns ``"ztarget"`` and ``"z"``.
    zcuts : sequence of float
        Bin edges in ``ztarget``. Edges are exclusive on both sides, so
        objects lying exactly on an edge are not plotted.
    save : bool
        If true, write the figure to ``nz_hist.pdf`` in the working
        directory before showing it.

    Side effects
    ------------
    Sets ``font.family`` and ``font.size`` in the global matplotlib
    rcParams and calls ``plt.show()``.
    """
    # Plot properties (global rcParams, kept for backward compatibility)
    plt.rcParams["font.family"] = "serif"
    plt.rcParams["font.size"] = 16
    cmap = plt.get_cmap("Dark2")  # one colour per redshift bin

    # One row per input frame. squeeze=False keeps a 2-D array of axes even
    # for a single row; at least one row is always created so that an empty
    # df_list still yields a (blank) labelled figure.
    n_panels = max(len(df_list), 1)
    fig, axs = plt.subplots(
        n_panels, 1, figsize=(20, 8), sharex=True, squeeze=False
    )
    axs = axs[:, 0]

    for ax, df in zip(axs, df_list):
        for iz, (zmin, zmax) in enumerate(zip(zcuts[:-1], zcuts[1:])):
            dfplot_z = df[(df["ztarget"] > zmin) & (df["ztarget"] < zmax)]
            if dfplot_z.empty:
                # Nothing to draw; also avoids NaN medians and a 0/0
                # normalisation in the density histogram.
                continue
            # Wrap around so that more bins than palette entries never
            # index past the end of the colormap.
            color = cmap(iz % cmap.N)
            zt_median = np.median(dfplot_z.ztarget.values)
            zp_median = np.median(dfplot_z.z.values)
            # Plot histogram on the selected subplot
            ax.hist(
                dfplot_z.z,
                bins=50,
                color=color,
                histtype="step",
                linestyle="-",
                density=True,
                range=(0, 4),
            )
            ax.axvline(zt_median, color=color, linestyle="-", lw=2)
            ax.axvline(zp_median, color=color, linestyle="--", lw=2)
        ax.set_ylabel("Frequency", fontsize=14)
        ax.grid(False)
        ax.set_xlim(0, 3.5)

    # Only the bottom panel carries the shared x label.
    axs[-1].set_xlabel("$z$", fontsize=18)
    if save:
        plt.savefig("nz_hist.pdf", dpi=300, bbox_inches="tight")
    plt.show()
# One panel per model variant, in the order the tables were stored in `dfs`.
plot_nz([dfs["z"], dfs["L15"], dfs["DA"]])