File size: 3,049 Bytes
b25063d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
668e440
b25063d
 
668e440
b25063d
 
 
 
 
668e440
b25063d
 
 
 
 
 
 
 
 
668e440
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b25063d
 
 
 
 
 
 
 
 
668e440
 
b25063d
668e440
 
 
b25063d
668e440
 
 
 
b25063d
 
668e440
 
 
 
 
 
 
b25063d
 
 
 
 
 
668e440
 
 
 
 
 
 
 
b25063d
 
 
668e440
 
 
 
 
 
 
 
b25063d
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#       jupytext_version: 1.16.2
#   kernelspec:
#     display_name: temps
#     language: python
#     name: temps
# ---

# # $p(z)$ DISTRIBUTIONS

# ## PIT AND CRPS FOR THE THREE METHODS

# ### LOAD PYTHON MODULES

# %load_ext autoreload
# %autoreload 2

import temps

import pandas as pd
import numpy as np
import os
from astropy.io import fits
from astropy.table import Table
import torch
from pathlib import Path

# matplotlib settings
from matplotlib import rcParams
import matplotlib.pyplot as plt

rcParams["mathtext.fontset"] = "stix"
rcParams["font.family"] = "STIXGeneral"

# +
from temps.temps import TempsModule
from temps.archive import Archive
from temps.utils import nmad
from temps.temps_arch import EncoderPhotometry, MeasureZ
from temps.plots import plot_photoz, plot_PIT, plot_crps


# -

# ### LOAD DATA

# Input locations used by the rest of the notebook:
#   parent_dir  - absolute path of the directory holding the photometric
#                 catalogues (adjust to your own setup)
#   modules_dir - directory with the pre-trained model weights, given
#                 relative to the current working directory
parent_dir = Path("/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5")
modules_dir = Path("../data/models/")

# Build the catalogue archive from `parent_dir`, keeping only the flag values
# listed in `flags_kept` and using "L15" as the test target.
photoz_archive = Archive(
    path=parent_dir,
    only_zspec=False,
    # fmt: off
    flags_kept=[
        1.0, 1.1, 1.4, 1.5,
        2, 2.1, 2.4, 2.5,
        3.0, 3.1, 3.4, 3.5,
        4.0,
        9.0, 9.1, 9.3, 9.4, 9.5,
        11.1, 11.5,
        12.1, 12.5,
        13.0, 13.1, 13.5,
        14,
    ],
    # fmt: on
    target_test="L15",
)

# Unpack the test split returned by the archive.
# NOTE(review): the names suggest fluxes, flux errors, spectroscopic redshifts
# and VIS magnitudes - confirm against Archive.get_testing_data.
(
    f_test,
    ferr_test,
    specz_test,
    VIS_mag_test,
) = photoz_archive.get_testing_data()


# ## CREATE PIT; CRPS; SPECTROSCOPIC SAMPLE

# To save time, this section loads pre-trained models. See the Tutorial notebook to learn how to train them.

# Containers for the per-model PIT and CRPS results, keyed by model label.
crps_dict = {}
pit_dict = {}
for lab in ["z", "L15", "DA"]:
    # Load the saved weights of both networks for this label; map_location
    # forces the tensors onto the CPU regardless of where they were saved.
    nn_features = EncoderPhotometry()
    nn_features.load_state_dict(
        torch.load(modules_dir / f"modelF_{lab}.pt", map_location=torch.device("cpu"))
    )
    nn_z = MeasureZ(num_gauss=6)
    nn_z.load_state_dict(
        torch.load(modules_dir / f"modelZ_{lab}.pt", map_location=torch.device("cpu"))
    )

    temps_module = TempsModule(nn_features, nn_z)

    # NOTE(review): calculate_pit receives the targets as a torch.Tensor while
    # calculate_crps receives specz_test unconverted - confirm that both
    # signatures really expect these different types.
    pit_list = temps_module.calculate_pit(
        input_data=torch.Tensor(f_test), target_data=torch.Tensor(specz_test)
    )
    crps_list = temps_module.calculate_crps(
        input_data=torch.Tensor(f_test), target_data=specz_test
    )

    # Store this model's results under its label.
    crps_dict[lab] = crps_list
    pit_dict[lab] = pit_list


# +
# PIT comparison of the three models. The first label uses the mathtext
# command `\rm` (with the backslash) so the subscript "s" is set in roman
# type, matching the label passed to plot_crps.
plot_PIT(
    pit_dict["z"],
    pit_dict["L15"],
    pit_dict["DA"],
    labels=[r"$z_{\rm s}$", "L15", "TEMPS"],
    sample="L15",
    save=True,
)


# +
# CRPS comparison of the three models, passed positionally in the order
# "z", "L15", "DA" so they line up with the entries of `labels`.
plot_crps(
    *(crps_dict[key] for key in ("z", "L15", "DA")),
    labels=[r"$z_{\rm s}$", "L15", "TEMPS"],
    sample="L15",
    save=True,
)


# -