TEMPS / notebooks /PIT_CRPS.py
Laura Cabayol Garcia
running precommit
668e440
raw
history blame
3.05 kB
# ---
# jupyter:
# jupytext:
# text_representation:
# extension: .py
# format_name: light
# format_version: '1.5'
# jupytext_version: 1.16.2
# kernelspec:
# display_name: temps
# language: python
# name: temps
# ---
# # $p(z)$ DISTRIBUTIONS
# ## PIT AND CRPS FOR THE THREE METHODS
# ### LOAD PYTHON MODULES
# %load_ext autoreload
# %autoreload 2
import temps
import pandas as pd
import numpy as np
import os
from astropy.io import fits
from astropy.table import Table
import torch
from pathlib import Path
# matplotlib settings
from matplotlib import rcParams
import matplotlib.pyplot as plt
rcParams["mathtext.fontset"] = "stix"
rcParams["font.family"] = "STIXGeneral"
# +
from temps.temps import TempsModule
from temps.archive import Archive
from temps.utils import nmad
from temps.temps_arch import EncoderPhotometry, MeasureZ
from temps.plots import plot_photoz, plot_PIT, plot_crps
# -
# ### LOAD DATA
# define here the directory containing the photometric catalogues
parent_dir = Path(
"/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5"
)
modules_dir = Path("../data/models/")
photoz_archive = Archive(
path=parent_dir,
only_zspec=False,
flags_kept=[
1.0,
1.1,
1.4,
1.5,
2,
2.1,
2.4,
2.5,
3.0,
3.1,
3.4,
3.5,
4.0,
9.0,
9.1,
9.3,
9.4,
9.5,
11.1,
11.5,
12.1,
12.5,
13.0,
13.1,
13.5,
14,
],
target_test="L15",
)
f_test, ferr_test, specz_test, VIS_mag_test = photoz_archive.get_testing_data()
# ## CREATE PIT; CRPS; SPECTROSCOPIC SAMPLE
# This loads pre-trained models (for the sake of time). You can learn how to train the models in the Tutorial notebook.
# Initialize an empty dictionary to store DataFrames
crps_dict = {}
pit_dict = {}
for il, lab in enumerate(["z", "L15", "DA"]):
nn_features = EncoderPhotometry()
nn_features.load_state_dict(
torch.load(modules_dir / f"modelF_{lab}.pt", map_location=torch.device("cpu"))
)
nn_z = MeasureZ(num_gauss=6)
nn_z.load_state_dict(
torch.load(modules_dir / f"modelZ_{lab}.pt", map_location=torch.device("cpu"))
)
temps_module = TempsModule(nn_features, nn_z)
pit_list = temps_module.calculate_pit(
input_data=torch.Tensor(f_test), target_data=torch.Tensor(specz_test)
)
crps_list = temps_module.calculate_crps(
input_data=torch.Tensor(f_test), target_data=specz_test
)
# Assign the DataFrame to a key in the dictionary
crps_dict[lab] = crps_list
pit_dict[lab] = pit_list
# +
plot_PIT(
pit_dict["z"],
pit_dict["L15"],
pit_dict["DA"],
labels=[r"$z_{rm s}$", "L15", "TEMPS"],
sample="L15",
save=True,
)
# +
plot_crps(
crps_dict["z"],
crps_dict["L15"],
crps_dict["DA"],
labels=[r"$z_{\rm s}$", "L15", "TEMPS"],
sample="L15",
save=True,
)
# -