# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#       jupytext_version: 1.16.2
#   kernelspec:
#     display_name: temps
#     language: python
#     name: temps
# ---

# # $p(z)$ examples

# ## IMPACT OF TEMPS ON CONCRETE P(Z) EXAMPLES

# ### LOAD PYTHON MODULES

# %load_ext autoreload
# %autoreload 2

import pandas as pd
import numpy as np
from astropy.io import fits
from astropy.table import Table
import torch
from pathlib import Path

#matplotlib settings
from matplotlib import rcParams
import matplotlib.pyplot as plt
rcParams["mathtext.fontset"] = "stix"
rcParams["font.family"] = "STIXGeneral"

from temps.archive import Archive 
from temps.utils import nmad
from temps.temps_arch import EncoderPhotometry, MeasureZ
from temps.temps import TempsModule


# ### LOAD DATA

#define here the directory containing the photometric catalogues
parent_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')
modules_dir = Path('../data/models/')

filename_valid='euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
path_file = parent_dir / filename_valid  # Creating the path to the file
# Read the catalogue inside a context manager so the FITS file is closed afterwards
with fits.open(path_file) as hdu_list:
    cat = Table(hdu_list[1].data).to_pandas()
cat = cat[cat['FLAG_PHOT']==0]
cat = cat[cat['mu_class_L07']==1]
cat = cat[(cat['z_spec_S15'] > 0) | (cat['photo_z_L15'] > 0)]
cat = cat[cat['MAG_VIS']<25]


# Use the spectroscopic redshift where available, otherwise fall back to the L15 photo-z
has_specz = cat['z_spec_S15'].values > 0
ztarget = np.where(has_specz, cat['z_spec_S15'].values, cat['photo_z_L15'].values)
specz_or_photo = np.where(has_specz, 0, 1)  # 0: spec-z target, 1: photo-z target
ID = cat['ID']
VISmag = cat['MAG_VIS']
zsflag = cat['reliable_S15']

photoz_archive = Archive(path = parent_dir,only_zspec=False)
f, ferr = photoz_archive._extract_fluxes(catalogue= cat)
col, colerr = photoz_archive._to_colors(f, ferr)

# ### LOAD TRAINED MODELS AND EVALUATE PDF OF RANDOM EXAMPLES

# The notebook 'Tutorial_temps' gives an example of how to train and save models.

# Pick one random galaxy and store the p(z) predicted by each model, keyed by model label
ii = np.random.randint(0,len(col),1)
pz_dict = {}
for il, lab in enumerate(['z','L15','DA']):
    
    nn_features = EncoderPhotometry()
    nn_features.load_state_dict(torch.load(modules_dir / f'modelF_{lab}.pt',map_location=torch.device('cpu')))
    nn_z = MeasureZ(num_gauss=6)
    nn_z.load_state_dict(torch.load(modules_dir / f'modelZ_{lab}.pt',map_location=torch.device('cpu')))
    
    temps_module = TempsModule(nn_features, nn_z)
    
    
    z, pz, fodds = temps_module.get_pz(input_data=torch.Tensor(col[ii]),return_pz=True)

    
    # Store the p(z) of this model under its label
    pz_dict[lab] = pz


# +
cmap = plt.get_cmap('Dark2') 

plt.plot(np.linspace(0,5,1000),pz_dict['z'][0],label='z', color = cmap(0), ls ='--')
plt.plot(np.linspace(0,5,1000),pz_dict['L15'][0],label='L15', color = cmap(1), ls =':')
plt.plot(np.linspace(0,5,1000),pz_dict['DA'][0],label='TEMPS', color = cmap(2), ls ='-')
plt.axvline(x=np.array(ztarget)[ii][0],ls='-.',color='black')
#plt.xlim(0,2)
plt.legend()

plt.xlabel(r'$z$', fontsize=14)
plt.ylabel('Probability', fontsize=14)
#plt.savefig(f'pz_{ii[0]}.pdf', bbox_inches='tight')
# -