Spaces:
Runtime error
Runtime error
# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#       jupytext_version: 1.16.2
#   kernelspec:
#     display_name: temps
#     language: python
#     name: temps
# ---
# # $p(z)$ DISTRIBUTIONS | |
# ## PIT AND CRPS FOR THE THREE METHODS | |
# ### LOAD PYTHON MODULES | |
# %load_ext autoreload | |
# %autoreload 2 | |
import temps | |
import pandas as pd | |
import numpy as np | |
import os | |
from astropy.io import fits | |
from astropy.table import Table | |
import torch | |
from pathlib import Path | |
#matplotlib settings | |
from matplotlib import rcParams | |
import matplotlib.pyplot as plt | |
# Use STIX fonts for both math text and regular text in every figure.
rcParams.update(
    {
        "mathtext.fontset": "stix",
        "font.family": "STIXGeneral",
    }
)
# + | |
from temps.temps import TempsModule | |
from temps.archive import Archive | |
from temps.utils import nmad | |
from temps.temps_arch import EncoderPhotometry, MeasureZ | |
from temps.plots import plot_photoz, plot_PIT, plot_crps | |
# - | |
# ### LOAD DATA | |
# Directory holding the pre-trained model checkpoints (relative to this notebook).
modules_dir = Path('../data/models/')
# Directory containing the photometric catalogues; adjust to your local copy.
parent_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')

# Build the archive over the catalogue directory.
# NOTE(review): `flags_kept` is forwarded verbatim to Archive; presumably these
# are the catalogue quality flags to retain -- confirm against temps.archive.
photoz_archive = Archive(
    path=parent_dir,
    only_zspec=False,
    flags_kept=[
        1.0, 1.1, 1.4, 1.5,
        2, 2.1, 2.4, 2.5,
        3.0, 3.1, 3.4, 3.5,
        4.0,
        9.0, 9.1, 9.3, 9.4, 9.5,
        11.1, 11.5,
        12.1, 12.5,
        13.0, 13.1, 13.5,
        14,
    ],
    target_test='L15',
)

# Test split returned by the archive: fluxes, flux errors, spectroscopic
# redshifts and VIS magnitudes (meaning inferred from the names -- confirm).
(
    f_test,
    ferr_test,
    specz_test,
    VIS_mag_test,
) = photoz_archive.get_testing_data()
# ## CREATE PIT; CRPS; SPECTROSCOPIC SAMPLE
# This loads pre-trained models (for the sake of time). You can learn how to train the models in the Tutorial notebook.
# Per-model results, keyed by the label of the pre-trained model they came from.
crps_dict = {}
pit_dict = {}

# One pass per pre-trained model; each label selects the checkpoint pair
# modelF_<label>.pt / modelZ_<label>.pt inside `modules_dir`.
for lab in ['z', 'L15', 'DA']:
    # NOTE(review): torch.load unpickles the checkpoint file -- only load
    # checkpoints that come from a trusted source.
    nn_features = EncoderPhotometry()
    nn_features.load_state_dict(
        torch.load(modules_dir / f'modelF_{lab}.pt', map_location=torch.device('cpu'))
    )

    nn_z = MeasureZ(num_gauss=6)
    nn_z.load_state_dict(
        torch.load(modules_dir / f'modelZ_{lab}.pt', map_location=torch.device('cpu'))
    )

    # Combine both networks and evaluate them on the test sample.
    temps_module = TempsModule(nn_features, nn_z)

    # NOTE(review): calculate_pit receives the target as a torch.Tensor while
    # calculate_crps receives the raw `specz_test`; kept as in the original --
    # confirm that both methods accept the type they are given.
    pit_list = temps_module.calculate_pit(
        input_data=torch.Tensor(f_test),
        target_data=torch.Tensor(specz_test),
    )
    crps_list = temps_module.calculate_crps(
        input_data=torch.Tensor(f_test),
        target_data=specz_test,
    )

    # Store this model's results under its label.
    crps_dict[lab] = crps_list
    pit_dict[lab] = pit_list
# + | |
# Compare the PIT results of the three models on the 'L15' sample.
# Labels are positional: 'z' -> $z_{\rm s}$, 'L15' -> L15, 'DA' -> TEMPS.
# The first label needs the backslash in `\rm`; without it mathtext renders
# the literal letters "rm s" as the subscript.
plot_PIT(pit_dict['z'],
         pit_dict['L15'],
         pit_dict['DA'],
         labels=[r'$z_{\rm s}$', 'L15', 'TEMPS'],
         sample='L15',
         save=True)
# + | |
# Compare the CRPS results of the three models on the 'L15' sample.
# The positional arguments follow the order of the model labels below, which
# pair with the legend labels: 'z' -> $z_{\rm s}$, 'L15' -> L15, 'DA' -> TEMPS.
plot_crps(
    *(crps_dict[model_label] for model_label in ('z', 'L15', 'DA')),
    labels=[r'$z_{\rm s}$', 'L15', 'TEMPS'],
    sample='L15',
    save=True,
)
# - | |