File size: 2,891 Bytes
b25063d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#       jupytext_version: 1.16.2
#   kernelspec:
#     display_name: temps
#     language: python
#     name: temps
# ---

# # $p(z)$ DISTRIBUTIONS

# ## PIT AND CRPS FOR THE THREE METHODS

# ### LOAD PYTHON MODULES

# %load_ext autoreload
# %autoreload 2

import temps

import pandas as pd
import numpy as np
import os
from astropy.io import fits
from astropy.table import Table
import torch
from pathlib import Path

#matplotlib settings
from matplotlib import rcParams
import matplotlib.pyplot as plt
# Use the STIX fonts for both math text and regular text in every figure.
rcParams.update({
    "mathtext.fontset": "stix",
    "font.family": "STIXGeneral",
})

# +
from temps.temps import TempsModule
from temps.archive import Archive 
from temps.utils import nmad
from temps.temps_arch import EncoderPhotometry, MeasureZ
from temps.plots import plot_photoz, plot_PIT, plot_crps


# -

# ### LOAD DATA

# Set here the directory holding the photometric catalogues (parent_dir)
# and the directory holding the saved model weights (modules_dir).
modules_dir = Path('../data/models/')
parent_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')

# Build the archive for the 'L15' test target; flags_kept is the list of
# flag values handed to Archive (their meaning is defined in temps.archive).
photoz_archive = Archive(
    path=parent_dir,
    only_zspec=False,
    flags_kept=[
        1.0, 1.1, 1.4, 1.5,
        2, 2.1, 2.4, 2.5,
        3.0, 3.1, 3.4, 3.5,
        4.0,
        9.0, 9.1, 9.3, 9.4, 9.5,
        11.1, 11.5,
        12.1, 12.5,
        13.0, 13.1, 13.5,
        14,
    ],
    target_test='L15',
)

# get_testing_data() returns four objects, unpacked in this order.
f_test, ferr_test, specz_test, VIS_mag_test = photoz_archive.get_testing_data()


# ## CREATE PIT; CRPS; SPECTROSCOPIC SAMPLE

# This loads pre-trained models (for the sake of time). You can learn how to train the models in the Tutorial notebook.

# Per-model results, keyed by the label of the pre-trained model.
# Each value is whatever calculate_crps / calculate_pit returns for that model.
crps_dict = {}
pit_dict = {}
for lab in ['z', 'L15', 'DA']:

    # Rebuild both networks and restore their saved weights onto the CPU;
    # the weight files are looked up by label inside modules_dir.
    nn_features = EncoderPhotometry()
    nn_features.load_state_dict(torch.load(modules_dir / f'modelF_{lab}.pt',map_location=torch.device('cpu')))
    nn_z = MeasureZ(num_gauss=6)
    nn_z.load_state_dict(torch.load(modules_dir / f'modelZ_{lab}.pt',map_location=torch.device('cpu')))

    temps_module = TempsModule(nn_features, nn_z)

    # Evaluate both metrics on the same test inputs.
    # NOTE(review): calculate_pit receives the targets as a torch.Tensor while
    # calculate_crps receives specz_test unconverted - confirm this is intended.
    pit_list = temps_module.calculate_pit(input_data=torch.Tensor(f_test), target_data=torch.Tensor(specz_test))
    crps_list = temps_module.calculate_crps(input_data=torch.Tensor(f_test), target_data=specz_test)

    # Store this model's results under its label.
    crps_dict[lab] = crps_list
    pit_dict[lab] = pit_list


# +
# Plot the PIT values of the three models for the 'L15' sample.
# The first label needs the backslash in \rm: without it mathtext typesets
# the subscript as the italic letters "rms" instead of a roman "s".
plot_PIT(pit_dict['z'],
         pit_dict['L15'],
         pit_dict['DA'],
         labels=[r'$z_{\rm s}$', 'L15', 'TEMPS'],
         sample='L15',
         save=True)




# +
# Plot the CRPS values of the three models for the 'L15' sample; the
# per-model results are passed positionally in the order z, L15, DA.
plot_crps(
    *(crps_dict[key] for key in ('z', 'L15', 'DA')),
    sample='L15',
    labels=[r'$z_{\rm s}$', 'L15', 'TEMPS'],
    save=True,
)



# -