TEMPS / notebooks /PIT_CRPS.py
lauracabayol's picture
new notebook versions
b25063d
raw
history blame
2.89 kB
# ---
# jupyter:
# jupytext:
# text_representation:
# extension: .py
# format_name: light
# format_version: '1.5'
# jupytext_version: 1.16.2
# kernelspec:
# display_name: temps
# language: python
# name: temps
# ---
# # $p(z)$ DISTRIBUTIONS
# ## PIT AND CRPS FOR THE THREE METHODS
# ### LOAD PYTHON MODULES
# %load_ext autoreload
# %autoreload 2
import temps
import pandas as pd
import numpy as np
import os
from astropy.io import fits
from astropy.table import Table
import torch
from pathlib import Path
#matplotlib settings
from matplotlib import rcParams
import matplotlib.pyplot as plt
rcParams["mathtext.fontset"] = "stix"
rcParams["font.family"] = "STIXGeneral"
# +
from temps.temps import TempsModule
from temps.archive import Archive
from temps.utils import nmad
from temps.temps_arch import EncoderPhotometry, MeasureZ
from temps.plots import plot_photoz, plot_PIT, plot_crps
# -
# ### LOAD DATA
# Location of the photometric catalogues; edit this path for the local setup.
parent_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')
# Folder from which the pre-trained model weights are read later in this notebook.
modules_dir = Path('../data/models/')
# Build the archive, keeping only the listed flag values and using 'L15' as test target.
photoz_archive = Archive(
    path=parent_dir,
    only_zspec=False,
    flags_kept=[
        1.0, 1.1, 1.4, 1.5,
        2, 2.1, 2.4, 2.5,
        3.0, 3.1, 3.4, 3.5,
        4.0,
        9.0, 9.1, 9.3, 9.4, 9.5,
        11.1, 11.5,
        12.1, 12.5,
        13.0, 13.1, 13.5,
        14,
    ],
    target_test='L15',
)
# Unpack the four test-set arrays in the order the archive returns them.
f_test, ferr_test, specz_test, VIS_mag_test = photoz_archive.get_testing_data()
# ## CREATE PIT AND CRPS: SPECTROSCOPIC SAMPLE
# This loads pre-trained models (for the sake of time). You can learn how to train the models in the Tutorial notebook.
# Per-method results, keyed by the label of the pre-trained model pair.
crps_dict = {}
pit_dict = {}
# The test inputs are the same for every method, so build the tensors once
# instead of on every iteration.
input_tensor = torch.Tensor(f_test)
target_tensor = torch.Tensor(specz_test)
cpu_device = torch.device('cpu')
for lab in ('z', 'L15', 'DA'):
    # Rebuild both networks and restore their saved weights onto the CPU.
    nn_features = EncoderPhotometry()
    nn_features.load_state_dict(torch.load(modules_dir / f'modelF_{lab}.pt', map_location=cpu_device))
    nn_z = MeasureZ(num_gauss=6)
    nn_z.load_state_dict(torch.load(modules_dir / f'modelZ_{lab}.pt', map_location=cpu_device))
    temps_module = TempsModule(nn_features, nn_z)
    pit_list = temps_module.calculate_pit(input_data=input_tensor, target_data=target_tensor)
    # NOTE(review): CRPS is given `specz_test` as returned by the archive, while
    # PIT is given a tensor; kept exactly as in the original calls -- confirm
    # against the TempsModule signatures.
    crps_list = temps_module.calculate_crps(input_data=input_tensor, target_data=specz_test)
    # Store this method's results under its label.
    crps_dict[lab] = crps_list
    pit_dict[lab] = pit_list
# +
# PIT comparison for the three methods, passed in the order z, L15, DA.
plot_PIT(pit_dict['z'],
         pit_dict['L15'],
         pit_dict['DA'],
         # `\rm` needs its backslash; without it mathtext typesets the letters
         # "rm" as ordinary math symbols instead of switching to roman font.
         labels=[r'$z_{\rm s}$', 'L15', 'TEMPS'],
         sample='L15',
         save=True)
# +
# CRPS comparison for the three methods, passed in the order z, L15, DA.
crps_per_method = [crps_dict[key] for key in ('z', 'L15', 'DA')]
plot_crps(*crps_per_method,
          labels=[r'$z_{\rm s}$', 'L15', 'TEMPS'],
          sample='L15',
          save=True)
# -