Spaces:
Runtime error
Runtime error
Commit
·
c212435
0
Parent(s):
Archive code and network training
Browse files- insight/.ipynb_checkpoints/archive-checkpoint.py +211 -0
- insight/.ipynb_checkpoints/insight-checkpoint.py +166 -0
- insight/.ipynb_checkpoints/insight_arch-checkpoint.py +81 -0
- insight/.ipynb_checkpoints/utils-checkpoint.py +51 -0
- insight/__pycache__/archive.cpython-310.pyc +0 -0
- insight/__pycache__/archive.cpython-39.pyc +0 -0
- insight/__pycache__/insight.cpython-310.pyc +0 -0
- insight/__pycache__/insight.cpython-39.pyc +0 -0
- insight/__pycache__/insight_arch.cpython-310.pyc +0 -0
- insight/__pycache__/insight_arch.cpython-39.pyc +0 -0
- insight/__pycache__/utils.cpython-310.pyc +0 -0
- insight/__pycache__/utils.cpython-39.pyc +0 -0
- insight/archive.py +211 -0
- insight/insight.py +166 -0
- insight/insight_arch.py +81 -0
- insight/utils.py +51 -0
insight/.ipynb_checkpoints/archive-checkpoint.py
ADDED
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import pandas as pd
|
3 |
+
from astropy.io import fits
|
4 |
+
import os
|
5 |
+
from astropy.table import Table
|
6 |
+
from scipy.spatial import KDTree
|
7 |
+
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
|
10 |
+
from matplotlib import rcParams
|
11 |
+
rcParams["mathtext.fontset"] = "stix"
|
12 |
+
rcParams["font.family"] = "STIXGeneral"
|
13 |
+
|
14 |
+
|
15 |
+
class archive():
|
16 |
+
def __init__(self, path, aperture=2, drop_stars=True, clean_photometry=True, convert_colors=True, extinction_corr=True, only_zspec=True, reliable_zspec=True):
|
17 |
+
|
18 |
+
self.aperture = aperture
|
19 |
+
|
20 |
+
self.weight_dict={(-99,0.99):0,
|
21 |
+
(1,1.99):0.5,
|
22 |
+
(2,2.99):0.75,
|
23 |
+
(3,4):1,
|
24 |
+
(9,9.99):0.25,
|
25 |
+
(10,10.99):0,
|
26 |
+
(11,11.99):0.5,
|
27 |
+
(12,12.99):0.75,
|
28 |
+
(13,14):1,
|
29 |
+
(14.01,40):0
|
30 |
+
}
|
31 |
+
|
32 |
+
filename_calib='euclid_cosmos_DC2_S1_v2.1_calib_clean.fits'
|
33 |
+
filename_valid='euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
|
34 |
+
filename_gold='Export_Gold_2023_07_03.csv'
|
35 |
+
|
36 |
+
hdu_list = fits.open(os.path.join(path,filename_calib))
|
37 |
+
cat = Table(hdu_list[1].data).to_pandas()
|
38 |
+
|
39 |
+
hdu_list = fits.open(os.path.join(path,filename_valid))
|
40 |
+
cat_test = Table(hdu_list[1].data).to_pandas()
|
41 |
+
|
42 |
+
gold_sample = pd.read_csv(os.path.join(path,filename_gold))
|
43 |
+
|
44 |
+
#cat_test = self._match_gold_sample(cat_test,gold_sample)
|
45 |
+
|
46 |
+
if drop_stars==True:
|
47 |
+
cat = cat[cat.mu_class_L07==1]
|
48 |
+
|
49 |
+
if clean_photometry==True:
|
50 |
+
cat = self._clean_photometry(cat)
|
51 |
+
cat_test = self._clean_photometry(cat_test)
|
52 |
+
|
53 |
+
self._get_loss_weights(cat)
|
54 |
+
|
55 |
+
cat = cat[cat.w_Q_f_S15>0]
|
56 |
+
|
57 |
+
self._set_training_data(cat, only_zspec=only_zspec, reliable_zspec=reliable_zspec, extinction_corr=extinction_corr, convert_colors=convert_colors)
|
58 |
+
|
59 |
+
|
60 |
+
self._set_testing_data(cat_test, only_zspec=only_zspec, reliable_zspec='Total', extinction_corr=extinction_corr, convert_colors=convert_colors)
|
61 |
+
|
62 |
+
self._get_loss_weights(cat)
|
63 |
+
|
64 |
+
#self.cat_test=cat_test
|
65 |
+
#self.cat_train=cat
|
66 |
+
|
67 |
+
def _extract_fluxes(self,catalogue):
|
68 |
+
columns_f = [f'FLUX_{x}_{self.aperture}' for x in ['G','R','I','Z','Y','J','H']]
|
69 |
+
columns_ferr = [f'FLUXERR_{x}_{self.aperture}' for x in ['G','R','I','Z','Y','J','H']]
|
70 |
+
|
71 |
+
f = catalogue[columns_f].values
|
72 |
+
ferr = catalogue[columns_ferr].values
|
73 |
+
return f, ferr
|
74 |
+
|
75 |
+
def _to_colors(self, flux, fluxerr):
|
76 |
+
""" Convert fluxes to colors"""
|
77 |
+
color = flux[:,:-1] / flux[:,1:]
|
78 |
+
color_err = fluxerr[:,:-1]**2 / flux[:,1:]**2 + flux[:,:-1]**2 / flux[:,1:]**4 * fluxerr[:,:-1]**2
|
79 |
+
return color,color_err
|
80 |
+
|
81 |
+
def _clean_photometry(self,catalogue):
|
82 |
+
""" Drops all object with FLAG_PHOT!=0"""
|
83 |
+
catalogue = catalogue[catalogue['FLAG_PHOT']==0]
|
84 |
+
|
85 |
+
return catalogue
|
86 |
+
|
87 |
+
def _correct_extinction(self,catalogue, f):
|
88 |
+
"""Corrects for extinction"""
|
89 |
+
ext_correction_cols = [f'EB_V_corr_FLUX_{x}' for x in ['G','R','I','Z','Y','J','H']]
|
90 |
+
ext_correction = catalogue[ext_correction_cols].values
|
91 |
+
|
92 |
+
f = f * ext_correction
|
93 |
+
return f
|
94 |
+
|
95 |
+
def _take_only_zspec(self,catalogue,cat_flag=None):
|
96 |
+
"""Selects only galaxies with spectroscopic redshift"""
|
97 |
+
if cat_flag=='Calib':
|
98 |
+
catalogue = catalogue[catalogue.z_spec_S15>0]
|
99 |
+
elif cat_flag=='Valid':
|
100 |
+
catalogue = catalogue[catalogue.z_spec_S15>0]
|
101 |
+
return catalogue
|
102 |
+
|
103 |
+
def _clean_zspec_sample(self,catalogue ,kind=None):
|
104 |
+
if kind==None:
|
105 |
+
return catalogue
|
106 |
+
elif kind=='Total':
|
107 |
+
return catalogue[catalogue['reliable_S15']>0]
|
108 |
+
elif kind=='Partial':
|
109 |
+
return catalogue[(catalogue['w_Q_f_S15']>0.5)]
|
110 |
+
|
111 |
+
def _map_weight(self,Qz):
|
112 |
+
for key, value in self.weight_dict.items():
|
113 |
+
if key[0] <= Qz <= key[1]:
|
114 |
+
return value
|
115 |
+
|
116 |
+
def _get_loss_weights(self,catalogue):
|
117 |
+
catalogue['w_Q_f_S15'] = catalogue['Q_f_S15'].apply(self._map_weight)
|
118 |
+
|
119 |
+
def _match_gold_sample(self,catalogue_valid, catalogue_gold, max_distance_arcsec=2):
|
120 |
+
max_distance_deg = max_distance_arcsec / 3600.0
|
121 |
+
|
122 |
+
gold_sample_radec = np.c_[catalogue_gold.RIGHT_ASCENSION,catalogue_gold.DECLINATION]
|
123 |
+
valid_sample_radec = np.c_[catalogue_valid['RA'],catalogue_valid['DEC']]
|
124 |
+
|
125 |
+
kdtree = KDTree(gold_sample_radec)
|
126 |
+
distances, indices = kdtree.query(valid_sample_radec, k=1)
|
127 |
+
|
128 |
+
specz_match_gold = catalogue_gold.FINAL_SPEC_Z.values[indices]
|
129 |
+
|
130 |
+
zs = [specz_match_gold[i] if distance < max_distance_deg else -99 for i, distance in enumerate(distances)]
|
131 |
+
|
132 |
+
catalogue_valid['z_spec_gold'] = zs
|
133 |
+
|
134 |
+
return catalogue_valid
|
135 |
+
|
136 |
+
|
137 |
+
def _set_training_data(self,catalogue, only_zspec=True, reliable_zspec=True, extinction_corr=True, convert_colors=True):
|
138 |
+
|
139 |
+
if only_zspec:
|
140 |
+
catalogue = self._take_only_zspec(catalogue, cat_flag='Calib')
|
141 |
+
catalogue = self._clean_zspec_sample(catalogue, kind=reliable_zspec)
|
142 |
+
|
143 |
+
self.cat_train=catalogue
|
144 |
+
f, ferr = self._extract_fluxes(catalogue)
|
145 |
+
|
146 |
+
|
147 |
+
if extinction_corr==True:
|
148 |
+
f = self._correct_extinction(catalogue,f)
|
149 |
+
|
150 |
+
if convert_colors==True:
|
151 |
+
col, colerr = self._to_colors(f, ferr)
|
152 |
+
|
153 |
+
self.phot_train = col
|
154 |
+
self.photerr_train = colerr
|
155 |
+
else:
|
156 |
+
self.phot_train = f
|
157 |
+
self.photerr_train = ferr
|
158 |
+
|
159 |
+
self.target_z_train = catalogue['z_spec_S15'].values
|
160 |
+
self.target_qz_train = catalogue['w_Q_f_S15'].values
|
161 |
+
|
162 |
+
def _set_testing_data(self,catalogue, only_zspec=True, reliable_zspec=True, extinction_corr=True, convert_colors=True):
|
163 |
+
|
164 |
+
if only_zspec:
|
165 |
+
catalogue = self._take_only_zspec(catalogue, cat_flag='Valid')
|
166 |
+
catalogue = self._clean_zspec_sample(catalogue, kind=reliable_zspec)
|
167 |
+
|
168 |
+
self.cat_test=catalogue
|
169 |
+
|
170 |
+
f, ferr = self._extract_fluxes(catalogue)
|
171 |
+
|
172 |
+
|
173 |
+
if extinction_corr==True:
|
174 |
+
f = self._correct_extinction(catalogue,f)
|
175 |
+
|
176 |
+
if convert_colors==True:
|
177 |
+
col, colerr = self._to_colors(f, ferr)
|
178 |
+
self.phot_test = col
|
179 |
+
self.photerr_test = colerr
|
180 |
+
else:
|
181 |
+
self.phot_test = f
|
182 |
+
self.photerr_test = ferr
|
183 |
+
|
184 |
+
self.target_z_test = catalogue['z_spec_S15'].values
|
185 |
+
|
186 |
+
|
187 |
+
def get_training_data(self):
|
188 |
+
return self.phot_train, self.photerr_train, self.target_z_train, self.target_qz_train
|
189 |
+
|
190 |
+
def get_testing_data(self):
|
191 |
+
return self.phot_test, self.photerr_test, self.target_z_test
|
192 |
+
|
193 |
+
def get_VIS_mag(self, catalogue):
|
194 |
+
return catalogue[['MAG_VIS']].values
|
195 |
+
|
196 |
+
def plot_zdistribution(self, plot_test=False, bins=50):
|
197 |
+
_,_,specz = photoz_archive.get_training_data()
|
198 |
+
plt.hist(specz, bins = bins, hisstype='step', color='navy', label=r'Training sample')
|
199 |
+
|
200 |
+
if plot_test:
|
201 |
+
_,_,specz_test = photoz_archive.get_training_data()
|
202 |
+
plt.hist(specz, bins = bins, hisstype='step', color='goldenrod', label=r'Test sample',ls='--')
|
203 |
+
|
204 |
+
|
205 |
+
plt.xticks(fontsize=12)
|
206 |
+
plt.yticks(fontsize=12)
|
207 |
+
|
208 |
+
plt.xlabel(r'Redshift', fontsize=14)
|
209 |
+
plt.ylabel('Counts', fontsize=14)
|
210 |
+
|
211 |
+
plt.show()
|
insight/.ipynb_checkpoints/insight-checkpoint.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.utils.data import DataLoader, dataset, TensorDataset
|
3 |
+
from torch import nn, optim
|
4 |
+
from torch.optim import lr_scheduler
|
5 |
+
import numpy as np
|
6 |
+
import pandas as pd
|
7 |
+
from astropy.io import fits
|
8 |
+
import os
|
9 |
+
from astropy.table import Table
|
10 |
+
from scipy.spatial import KDTree
|
11 |
+
|
12 |
+
class Insight_module():
|
13 |
+
""" Define class"""
|
14 |
+
|
15 |
+
def __init__(self, model):
|
16 |
+
self.model=model
|
17 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
18 |
+
|
19 |
+
def _get_dataloaders(self, input_data, target_data, target_weights, val_fraction=0.1):
|
20 |
+
input_data = torch.Tensor(input_data)
|
21 |
+
target_data = torch.Tensor(target_data)
|
22 |
+
target_weights = torch.Tensor(target_weights)
|
23 |
+
|
24 |
+
dataset = TensorDataset(input_data, target_data, target_weights)
|
25 |
+
|
26 |
+
trainig_dataset, val_dataset = torch.utils.data.random_split(dataset, [int(len(dataset)*(1-val_fraction)), int(len(dataset)*val_fraction)+1])
|
27 |
+
loader_train = DataLoader(trainig_dataset, batch_size=64, shuffle = True)
|
28 |
+
loader_val = DataLoader(val_dataset, batch_size=64, shuffle = True)
|
29 |
+
|
30 |
+
return loader_train, loader_val
|
31 |
+
|
32 |
+
|
33 |
+
def _loss_function(self,mean, std, logmix, true, target_weights):
|
34 |
+
log_prob = logmix - 0.5*(mean - true[:,None]).pow(2) / std.pow(2) - torch.log(std)
|
35 |
+
|
36 |
+
log_prob = torch.logsumexp(log_prob, 1)
|
37 |
+
|
38 |
+
#log_prob = log_prob * target_weights
|
39 |
+
loss = -log_prob.mean()
|
40 |
+
|
41 |
+
return loss
|
42 |
+
|
43 |
+
def _to_numpy(self,x):
|
44 |
+
return x.detach().cpu().numpy()
|
45 |
+
|
46 |
+
def train(self,input_data, target_data, target_weights, nepochs=10, val_fraction=0.1, lr=1e-3 ):
|
47 |
+
self.model = self.model.train()
|
48 |
+
loader_train, loader_val = self._get_dataloaders(input_data, target_data, target_weights, val_fraction=0.1)
|
49 |
+
optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-4)
|
50 |
+
|
51 |
+
self.model = self.model.to(self.device)
|
52 |
+
|
53 |
+
loss_train, loss_validation = [],[]
|
54 |
+
|
55 |
+
for epoch in range(nepochs):
|
56 |
+
for input_data, target_data, target_weights in loader_train:
|
57 |
+
|
58 |
+
input_data = input_data.to(self.device)
|
59 |
+
target_data = target_data.to(self.device)
|
60 |
+
target_weights = target_weights.to(self.device)
|
61 |
+
|
62 |
+
|
63 |
+
optimizer.zero_grad()
|
64 |
+
|
65 |
+
mu, logsig, logmix_coeff = self.model(input_data)
|
66 |
+
logsig = torch.clamp(logsig,-6,2)
|
67 |
+
sig = torch.exp(logsig)
|
68 |
+
|
69 |
+
|
70 |
+
#print(mu,sig,target_data,torch.exp(logmix_coeff))
|
71 |
+
|
72 |
+
loss = self._loss_function(mu, sig, logmix_coeff, target_data,target_weights)
|
73 |
+
|
74 |
+
loss.backward()
|
75 |
+
optimizer.step()
|
76 |
+
|
77 |
+
loss_train.append(loss.item())
|
78 |
+
|
79 |
+
for input_data, target_data, target_weights in loader_val:
|
80 |
+
|
81 |
+
|
82 |
+
input_data = input_data.to(self.device)
|
83 |
+
target_data = target_data.to(self.device)
|
84 |
+
target_weights = target_weights.to(self.device)
|
85 |
+
|
86 |
+
|
87 |
+
mu, logsig, logmix_coeff = self.model(input_data)
|
88 |
+
logsig = torch.clamp(logsig,-6,2)
|
89 |
+
sig = torch.exp(logsig)
|
90 |
+
|
91 |
+
loss_val = self._loss_function(mu, sig, logmix_coeff, target_data, target_weights)
|
92 |
+
loss_validation.append(loss_val.item())
|
93 |
+
|
94 |
+
print(f'training_loss:{loss}',f'testing_loss:{loss_val}')
|
95 |
+
|
96 |
+
self.loss_train=loss_train
|
97 |
+
self.loss_validation=loss_validation
|
98 |
+
|
99 |
+
|
100 |
+
def get_photoz(self,input_data, target_data):
|
101 |
+
self.model = self.model.eval()
|
102 |
+
self.model = self.model.to(self.device)
|
103 |
+
|
104 |
+
input_data = input_data.to(self.device)
|
105 |
+
target_data = target_data.to(self.device)
|
106 |
+
|
107 |
+
for ii in range(len(input_data)):
|
108 |
+
|
109 |
+
mu, logsig, logmix_coeff = self.model(input_data)
|
110 |
+
logsig = torch.clamp(logsig,-6,2)
|
111 |
+
sig = torch.exp(logsig)
|
112 |
+
|
113 |
+
mix_coeff = torch.exp(logmix_coeff)
|
114 |
+
|
115 |
+
z = (mix_coeff * mu).sum(1)
|
116 |
+
zerr = torch.sqrt( (mix_coeff * sig**2).sum(1) + (mix_coeff * (mu - target_data[:,None])**2).sum(1))
|
117 |
+
|
118 |
+
|
119 |
+
return self._to_numpy(z),self._to_numpy(zerr)
|
120 |
+
|
121 |
+
|
122 |
+
#return model
|
123 |
+
|
124 |
+
def plot_photoz(self, df, nbins,xvariable,metric, type_bin='bin'):
|
125 |
+
bin_edges = stats.mstats.mquantiles(df[xvariable].values, np.linspace(0.1,1,nbins))
|
126 |
+
ydata,xlab = [],[]
|
127 |
+
|
128 |
+
|
129 |
+
for k in range(len(bin_edges)-1):
|
130 |
+
edge_min = bin_edges[k]
|
131 |
+
edge_max = bin_edges[k+1]
|
132 |
+
|
133 |
+
mean_mag = (edge_max + edge_min) / 2
|
134 |
+
|
135 |
+
if type_bin=='bin':
|
136 |
+
df_plot = df_test[(df_test.imag > edge_min) & (df_test.imag < edge_max)]
|
137 |
+
elif type_bin=='cum':
|
138 |
+
df_plot = df_test[(df_test.imag < edge_max)]
|
139 |
+
else:
|
140 |
+
raise ValueError("Only type_bin=='bin' for binned and 'cum' for cumulative are supported")
|
141 |
+
|
142 |
+
|
143 |
+
xlab.append(mean_mag)
|
144 |
+
if metric=='sig68':
|
145 |
+
ydata.append(sigma68(df_plot.zwerr))
|
146 |
+
elif metric=='bias':
|
147 |
+
ydata.append(np.mean(df_plot.zwerr))
|
148 |
+
elif metric=='nmad':
|
149 |
+
ydata.append(nmad(df_plot.zwerr))
|
150 |
+
elif metric=='outliers':
|
151 |
+
ydata.append(len(df_plot[np.abs(df_plot.zwerr)>0.15])/len(df_plot))
|
152 |
+
|
153 |
+
plt.plot(xlab,ydata, ls = '-', marker = '.', color = 'navy',lw = 1, label = '')
|
154 |
+
plt.ylabel(f'{metric}$[\Delta z]$', fontsize = 18)
|
155 |
+
plt.xlabel(f'{xvariable}', fontsize = 16)
|
156 |
+
|
157 |
+
plt.xticks(fontsize = 14)
|
158 |
+
plt.yticks(fontsize = 14)
|
159 |
+
|
160 |
+
plt.grid(False)
|
161 |
+
|
162 |
+
plt.show()
|
163 |
+
|
164 |
+
|
165 |
+
|
166 |
+
|
insight/.ipynb_checkpoints/insight_arch-checkpoint.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn, optim
|
2 |
+
import torch
|
3 |
+
class Photoz_network(nn.Module):
|
4 |
+
def __init__(self, num_gauss=10, dropout_prob=0):
|
5 |
+
super(Photoz_network, self).__init__()
|
6 |
+
|
7 |
+
self.features = nn.Sequential(
|
8 |
+
nn.Linear(6, 10),
|
9 |
+
nn.Dropout(dropout_prob),
|
10 |
+
nn.ReLU(),
|
11 |
+
nn.Linear(10, 30),
|
12 |
+
nn.Dropout(dropout_prob),
|
13 |
+
nn.ReLU(),
|
14 |
+
nn.Linear(30, 50),
|
15 |
+
nn.Dropout(dropout_prob),
|
16 |
+
nn.ReLU(),
|
17 |
+
nn.Linear(50, 70),
|
18 |
+
nn.Dropout(dropout_prob),
|
19 |
+
nn.ReLU(),
|
20 |
+
nn.Linear(70, 100)
|
21 |
+
)
|
22 |
+
|
23 |
+
self.measure_mu = nn.Sequential(
|
24 |
+
nn.Linear(100, 80),
|
25 |
+
nn.Dropout(dropout_prob),
|
26 |
+
nn.ReLU(),
|
27 |
+
nn.Linear(80, 70),
|
28 |
+
nn.Dropout(dropout_prob),
|
29 |
+
nn.ReLU(),
|
30 |
+
nn.Linear(70, 60),
|
31 |
+
nn.Dropout(dropout_prob),
|
32 |
+
nn.ReLU(),
|
33 |
+
nn.Linear(60, 50),
|
34 |
+
nn.Dropout(dropout_prob),
|
35 |
+
nn.ReLU(),
|
36 |
+
nn.Linear(50, num_gauss)
|
37 |
+
)
|
38 |
+
|
39 |
+
self.measure_coeffs = nn.Sequential(
|
40 |
+
nn.Linear(100, 80),
|
41 |
+
nn.Dropout(dropout_prob),
|
42 |
+
nn.ReLU(),
|
43 |
+
nn.Linear(80, 70),
|
44 |
+
nn.Dropout(dropout_prob),
|
45 |
+
nn.ReLU(),
|
46 |
+
nn.Linear(70, 60),
|
47 |
+
nn.Dropout(dropout_prob),
|
48 |
+
nn.ReLU(),
|
49 |
+
nn.Linear(60, 50),
|
50 |
+
nn.Dropout(dropout_prob),
|
51 |
+
nn.ReLU(),
|
52 |
+
nn.Linear(50, num_gauss)
|
53 |
+
)
|
54 |
+
|
55 |
+
self.measure_sigma = nn.Sequential(
|
56 |
+
nn.Linear(100, 80),
|
57 |
+
nn.Dropout(dropout_prob),
|
58 |
+
nn.ReLU(),
|
59 |
+
nn.Linear(80, 70),
|
60 |
+
nn.Dropout(dropout_prob),
|
61 |
+
nn.ReLU(),
|
62 |
+
nn.Linear(70, 60),
|
63 |
+
nn.Dropout(dropout_prob),
|
64 |
+
nn.ReLU(),
|
65 |
+
nn.Linear(60, 50),
|
66 |
+
nn.Dropout(dropout_prob),
|
67 |
+
nn.ReLU(),
|
68 |
+
nn.Linear(50, num_gauss)
|
69 |
+
)
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
f = self.features(x)
|
73 |
+
mu = self.measure_mu(f)
|
74 |
+
sigma = self.measure_sigma(f)
|
75 |
+
logmix_coeff = self.measure_coeffs(f)
|
76 |
+
|
77 |
+
logmix_coeff = logmix_coeff - torch.logsumexp(logmix_coeff, 1)[:,None]
|
78 |
+
|
79 |
+
return mu, sigma, logmix_coeff
|
80 |
+
|
81 |
+
|
insight/.ipynb_checkpoints/utils-checkpoint.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import pandas as pd
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
from scipy import stats
|
5 |
+
|
6 |
+
def nmad(data):
|
7 |
+
return 1.4826 * np.median(np.abs(data - np.median(data)))
|
8 |
+
|
9 |
+
def sigma68(data): return 0.5*(pd.Series(data).quantile(q = 0.84) - pd.Series(data).quantile(q = 0.16))
|
10 |
+
|
11 |
+
|
12 |
+
def plot_photoz(df, nbins,xvariable,metric, type_bin='bin'):
|
13 |
+
bin_edges = stats.mstats.mquantiles(df[xvariable].values, np.linspace(0.1,1,nbins))
|
14 |
+
ydata,xlab = [],[]
|
15 |
+
|
16 |
+
|
17 |
+
for k in range(len(bin_edges)-1):
|
18 |
+
edge_min = bin_edges[k]
|
19 |
+
edge_max = bin_edges[k+1]
|
20 |
+
|
21 |
+
mean_mag = (edge_max + edge_min) / 2
|
22 |
+
|
23 |
+
if type_bin=='bin':
|
24 |
+
df_plot = df_test[(df_test.imag > edge_min) & (df_test.imag < edge_max)]
|
25 |
+
elif type_bin=='cum':
|
26 |
+
df_plot = df_test[(df_test.imag < edge_max)]
|
27 |
+
else:
|
28 |
+
raise ValueError("Only type_bin=='bin' for binned and 'cum' for cumulative are supported")
|
29 |
+
|
30 |
+
|
31 |
+
xlab.append(mean_mag)
|
32 |
+
if metric=='sig68':
|
33 |
+
ydata.append(sigma68(df_plot.zwerr))
|
34 |
+
elif metric=='bias':
|
35 |
+
ydata.append(np.mean(df_plot.zwerr))
|
36 |
+
elif metric=='nmad':
|
37 |
+
ydata.append(nmad(df_plot.zwerr))
|
38 |
+
elif metric=='outliers':
|
39 |
+
ydata.append(len(df_plot[np.abs(df_plot.zwerr)>0.15])/len(df_plot))
|
40 |
+
|
41 |
+
plt.plot(xlab,ydata, ls = '-', marker = '.', color = 'navy',lw = 1, label = '')
|
42 |
+
plt.ylabel(f'{metric}$[\Delta z]$', fontsize = 18)
|
43 |
+
plt.xlabel(f'{xvariable}', fontsize = 16)
|
44 |
+
|
45 |
+
plt.xticks(fontsize = 14)
|
46 |
+
plt.yticks(fontsize = 14)
|
47 |
+
|
48 |
+
plt.grid(False)
|
49 |
+
|
50 |
+
plt.show()
|
51 |
+
|
insight/__pycache__/archive.cpython-310.pyc
ADDED
Binary file (6.92 kB). View file
|
|
insight/__pycache__/archive.cpython-39.pyc
ADDED
Binary file (6.2 kB). View file
|
|
insight/__pycache__/insight.cpython-310.pyc
ADDED
Binary file (4.48 kB). View file
|
|
insight/__pycache__/insight.cpython-39.pyc
ADDED
Binary file (4.4 kB). View file
|
|
insight/__pycache__/insight_arch.cpython-310.pyc
ADDED
Binary file (1.63 kB). View file
|
|
insight/__pycache__/insight_arch.cpython-39.pyc
ADDED
Binary file (1.58 kB). View file
|
|
insight/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (1.71 kB). View file
|
|
insight/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (1.67 kB). View file
|
|
insight/archive.py
ADDED
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import pandas as pd
|
3 |
+
from astropy.io import fits
|
4 |
+
import os
|
5 |
+
from astropy.table import Table
|
6 |
+
from scipy.spatial import KDTree
|
7 |
+
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
|
10 |
+
from matplotlib import rcParams
|
11 |
+
rcParams["mathtext.fontset"] = "stix"
|
12 |
+
rcParams["font.family"] = "STIXGeneral"
|
13 |
+
|
14 |
+
|
15 |
+
class archive():
|
16 |
+
def __init__(self, path, aperture=2, drop_stars=True, clean_photometry=True, convert_colors=True, extinction_corr=True, only_zspec=True, reliable_zspec=True):
|
17 |
+
|
18 |
+
self.aperture = aperture
|
19 |
+
|
20 |
+
self.weight_dict={(-99,0.99):0,
|
21 |
+
(1,1.99):0.5,
|
22 |
+
(2,2.99):0.75,
|
23 |
+
(3,4):1,
|
24 |
+
(9,9.99):0.25,
|
25 |
+
(10,10.99):0,
|
26 |
+
(11,11.99):0.5,
|
27 |
+
(12,12.99):0.75,
|
28 |
+
(13,14):1,
|
29 |
+
(14.01,40):0
|
30 |
+
}
|
31 |
+
|
32 |
+
filename_calib='euclid_cosmos_DC2_S1_v2.1_calib_clean.fits'
|
33 |
+
filename_valid='euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
|
34 |
+
filename_gold='Export_Gold_2023_07_03.csv'
|
35 |
+
|
36 |
+
hdu_list = fits.open(os.path.join(path,filename_calib))
|
37 |
+
cat = Table(hdu_list[1].data).to_pandas()
|
38 |
+
|
39 |
+
hdu_list = fits.open(os.path.join(path,filename_valid))
|
40 |
+
cat_test = Table(hdu_list[1].data).to_pandas()
|
41 |
+
|
42 |
+
gold_sample = pd.read_csv(os.path.join(path,filename_gold))
|
43 |
+
|
44 |
+
#cat_test = self._match_gold_sample(cat_test,gold_sample)
|
45 |
+
|
46 |
+
if drop_stars==True:
|
47 |
+
cat = cat[cat.mu_class_L07==1]
|
48 |
+
|
49 |
+
if clean_photometry==True:
|
50 |
+
cat = self._clean_photometry(cat)
|
51 |
+
cat_test = self._clean_photometry(cat_test)
|
52 |
+
|
53 |
+
self._get_loss_weights(cat)
|
54 |
+
|
55 |
+
cat = cat[cat.w_Q_f_S15>0]
|
56 |
+
|
57 |
+
self._set_training_data(cat, only_zspec=only_zspec, reliable_zspec=reliable_zspec, extinction_corr=extinction_corr, convert_colors=convert_colors)
|
58 |
+
|
59 |
+
|
60 |
+
self._set_testing_data(cat_test, only_zspec=only_zspec, reliable_zspec='Total', extinction_corr=extinction_corr, convert_colors=convert_colors)
|
61 |
+
|
62 |
+
self._get_loss_weights(cat)
|
63 |
+
|
64 |
+
#self.cat_test=cat_test
|
65 |
+
#self.cat_train=cat
|
66 |
+
|
67 |
+
def _extract_fluxes(self,catalogue):
|
68 |
+
columns_f = [f'FLUX_{x}_{self.aperture}' for x in ['G','R','I','Z','Y','J','H']]
|
69 |
+
columns_ferr = [f'FLUXERR_{x}_{self.aperture}' for x in ['G','R','I','Z','Y','J','H']]
|
70 |
+
|
71 |
+
f = catalogue[columns_f].values
|
72 |
+
ferr = catalogue[columns_ferr].values
|
73 |
+
return f, ferr
|
74 |
+
|
75 |
+
def _to_colors(self, flux, fluxerr):
|
76 |
+
""" Convert fluxes to colors"""
|
77 |
+
color = flux[:,:-1] / flux[:,1:]
|
78 |
+
color_err = fluxerr[:,:-1]**2 / flux[:,1:]**2 + flux[:,:-1]**2 / flux[:,1:]**4 * fluxerr[:,:-1]**2
|
79 |
+
return color,color_err
|
80 |
+
|
81 |
+
def _clean_photometry(self,catalogue):
|
82 |
+
""" Drops all object with FLAG_PHOT!=0"""
|
83 |
+
catalogue = catalogue[catalogue['FLAG_PHOT']==0]
|
84 |
+
|
85 |
+
return catalogue
|
86 |
+
|
87 |
+
def _correct_extinction(self,catalogue, f):
|
88 |
+
"""Corrects for extinction"""
|
89 |
+
ext_correction_cols = [f'EB_V_corr_FLUX_{x}' for x in ['G','R','I','Z','Y','J','H']]
|
90 |
+
ext_correction = catalogue[ext_correction_cols].values
|
91 |
+
|
92 |
+
f = f * ext_correction
|
93 |
+
return f
|
94 |
+
|
95 |
+
def _take_only_zspec(self,catalogue,cat_flag=None):
|
96 |
+
"""Selects only galaxies with spectroscopic redshift"""
|
97 |
+
if cat_flag=='Calib':
|
98 |
+
catalogue = catalogue[catalogue.z_spec_S15>0]
|
99 |
+
elif cat_flag=='Valid':
|
100 |
+
catalogue = catalogue[catalogue.z_spec_S15>0]
|
101 |
+
return catalogue
|
102 |
+
|
103 |
+
def _clean_zspec_sample(self,catalogue ,kind=None):
|
104 |
+
if kind==None:
|
105 |
+
return catalogue
|
106 |
+
elif kind=='Total':
|
107 |
+
return catalogue[catalogue['reliable_S15']>0]
|
108 |
+
elif kind=='Partial':
|
109 |
+
return catalogue[(catalogue['w_Q_f_S15']>0.5)]
|
110 |
+
|
111 |
+
def _map_weight(self,Qz):
|
112 |
+
for key, value in self.weight_dict.items():
|
113 |
+
if key[0] <= Qz <= key[1]:
|
114 |
+
return value
|
115 |
+
|
116 |
+
def _get_loss_weights(self,catalogue):
|
117 |
+
catalogue['w_Q_f_S15'] = catalogue['Q_f_S15'].apply(self._map_weight)
|
118 |
+
|
119 |
+
def _match_gold_sample(self,catalogue_valid, catalogue_gold, max_distance_arcsec=2):
|
120 |
+
max_distance_deg = max_distance_arcsec / 3600.0
|
121 |
+
|
122 |
+
gold_sample_radec = np.c_[catalogue_gold.RIGHT_ASCENSION,catalogue_gold.DECLINATION]
|
123 |
+
valid_sample_radec = np.c_[catalogue_valid['RA'],catalogue_valid['DEC']]
|
124 |
+
|
125 |
+
kdtree = KDTree(gold_sample_radec)
|
126 |
+
distances, indices = kdtree.query(valid_sample_radec, k=1)
|
127 |
+
|
128 |
+
specz_match_gold = catalogue_gold.FINAL_SPEC_Z.values[indices]
|
129 |
+
|
130 |
+
zs = [specz_match_gold[i] if distance < max_distance_deg else -99 for i, distance in enumerate(distances)]
|
131 |
+
|
132 |
+
catalogue_valid['z_spec_gold'] = zs
|
133 |
+
|
134 |
+
return catalogue_valid
|
135 |
+
|
136 |
+
|
137 |
+
def _set_training_data(self,catalogue, only_zspec=True, reliable_zspec=True, extinction_corr=True, convert_colors=True):
|
138 |
+
|
139 |
+
if only_zspec:
|
140 |
+
catalogue = self._take_only_zspec(catalogue, cat_flag='Calib')
|
141 |
+
catalogue = self._clean_zspec_sample(catalogue, kind=reliable_zspec)
|
142 |
+
|
143 |
+
self.cat_train=catalogue
|
144 |
+
f, ferr = self._extract_fluxes(catalogue)
|
145 |
+
|
146 |
+
|
147 |
+
if extinction_corr==True:
|
148 |
+
f = self._correct_extinction(catalogue,f)
|
149 |
+
|
150 |
+
if convert_colors==True:
|
151 |
+
col, colerr = self._to_colors(f, ferr)
|
152 |
+
|
153 |
+
self.phot_train = col
|
154 |
+
self.photerr_train = colerr
|
155 |
+
else:
|
156 |
+
self.phot_train = f
|
157 |
+
self.photerr_train = ferr
|
158 |
+
|
159 |
+
self.target_z_train = catalogue['z_spec_S15'].values
|
160 |
+
self.target_qz_train = catalogue['w_Q_f_S15'].values
|
161 |
+
|
162 |
+
def _set_testing_data(self,catalogue, only_zspec=True, reliable_zspec=True, extinction_corr=True, convert_colors=True):
|
163 |
+
|
164 |
+
if only_zspec:
|
165 |
+
catalogue = self._take_only_zspec(catalogue, cat_flag='Valid')
|
166 |
+
catalogue = self._clean_zspec_sample(catalogue, kind=reliable_zspec)
|
167 |
+
|
168 |
+
self.cat_test=catalogue
|
169 |
+
|
170 |
+
f, ferr = self._extract_fluxes(catalogue)
|
171 |
+
|
172 |
+
|
173 |
+
if extinction_corr==True:
|
174 |
+
f = self._correct_extinction(catalogue,f)
|
175 |
+
|
176 |
+
if convert_colors==True:
|
177 |
+
col, colerr = self._to_colors(f, ferr)
|
178 |
+
self.phot_test = col
|
179 |
+
self.photerr_test = colerr
|
180 |
+
else:
|
181 |
+
self.phot_test = f
|
182 |
+
self.photerr_test = ferr
|
183 |
+
|
184 |
+
self.target_z_test = catalogue['z_spec_S15'].values
|
185 |
+
|
186 |
+
|
187 |
+
def get_training_data(self):
|
188 |
+
return self.phot_train, self.photerr_train, self.target_z_train, self.target_qz_train
|
189 |
+
|
190 |
+
def get_testing_data(self):
|
191 |
+
return self.phot_test, self.photerr_test, self.target_z_test
|
192 |
+
|
193 |
+
def get_VIS_mag(self, catalogue):
|
194 |
+
return catalogue[['MAG_VIS']].values
|
195 |
+
|
196 |
+
def plot_zdistribution(self, plot_test=False, bins=50):
|
197 |
+
_,_,specz = photoz_archive.get_training_data()
|
198 |
+
plt.hist(specz, bins = bins, hisstype='step', color='navy', label=r'Training sample')
|
199 |
+
|
200 |
+
if plot_test:
|
201 |
+
_,_,specz_test = photoz_archive.get_training_data()
|
202 |
+
plt.hist(specz, bins = bins, hisstype='step', color='goldenrod', label=r'Test sample',ls='--')
|
203 |
+
|
204 |
+
|
205 |
+
plt.xticks(fontsize=12)
|
206 |
+
plt.yticks(fontsize=12)
|
207 |
+
|
208 |
+
plt.xlabel(r'Redshift', fontsize=14)
|
209 |
+
plt.ylabel('Counts', fontsize=14)
|
210 |
+
|
211 |
+
plt.show()
|
insight/insight.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.utils.data import DataLoader, dataset, TensorDataset
|
3 |
+
from torch import nn, optim
|
4 |
+
from torch.optim import lr_scheduler
|
5 |
+
import numpy as np
|
6 |
+
import pandas as pd
|
7 |
+
from astropy.io import fits
|
8 |
+
import os
|
9 |
+
from astropy.table import Table
|
10 |
+
from scipy.spatial import KDTree
|
11 |
+
|
12 |
+
class Insight_module():
|
13 |
+
""" Define class"""
|
14 |
+
|
15 |
+
def __init__(self, model):
|
16 |
+
self.model=model
|
17 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
18 |
+
|
19 |
+
def _get_dataloaders(self, input_data, target_data, target_weights, val_fraction=0.1):
|
20 |
+
input_data = torch.Tensor(input_data)
|
21 |
+
target_data = torch.Tensor(target_data)
|
22 |
+
target_weights = torch.Tensor(target_weights)
|
23 |
+
|
24 |
+
dataset = TensorDataset(input_data, target_data, target_weights)
|
25 |
+
|
26 |
+
trainig_dataset, val_dataset = torch.utils.data.random_split(dataset, [int(len(dataset)*(1-val_fraction)), int(len(dataset)*val_fraction)+1])
|
27 |
+
loader_train = DataLoader(trainig_dataset, batch_size=64, shuffle = True)
|
28 |
+
loader_val = DataLoader(val_dataset, batch_size=64, shuffle = True)
|
29 |
+
|
30 |
+
return loader_train, loader_val
|
31 |
+
|
32 |
+
|
33 |
+
def _loss_function(self,mean, std, logmix, true, target_weights):
|
34 |
+
log_prob = logmix - 0.5*(mean - true[:,None]).pow(2) / std.pow(2) - torch.log(std)
|
35 |
+
|
36 |
+
log_prob = torch.logsumexp(log_prob, 1)
|
37 |
+
|
38 |
+
#log_prob = log_prob * target_weights
|
39 |
+
loss = -log_prob.mean()
|
40 |
+
|
41 |
+
return loss
|
42 |
+
|
43 |
+
def _to_numpy(self,x):
|
44 |
+
return x.detach().cpu().numpy()
|
45 |
+
|
46 |
+
def train(self,input_data, target_data, target_weights, nepochs=10, val_fraction=0.1, lr=1e-3 ):
|
47 |
+
self.model = self.model.train()
|
48 |
+
loader_train, loader_val = self._get_dataloaders(input_data, target_data, target_weights, val_fraction=0.1)
|
49 |
+
optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-4)
|
50 |
+
|
51 |
+
self.model = self.model.to(self.device)
|
52 |
+
|
53 |
+
loss_train, loss_validation = [],[]
|
54 |
+
|
55 |
+
for epoch in range(nepochs):
|
56 |
+
for input_data, target_data, target_weights in loader_train:
|
57 |
+
|
58 |
+
input_data = input_data.to(self.device)
|
59 |
+
target_data = target_data.to(self.device)
|
60 |
+
target_weights = target_weights.to(self.device)
|
61 |
+
|
62 |
+
|
63 |
+
optimizer.zero_grad()
|
64 |
+
|
65 |
+
mu, logsig, logmix_coeff = self.model(input_data)
|
66 |
+
logsig = torch.clamp(logsig,-6,2)
|
67 |
+
sig = torch.exp(logsig)
|
68 |
+
|
69 |
+
|
70 |
+
#print(mu,sig,target_data,torch.exp(logmix_coeff))
|
71 |
+
|
72 |
+
loss = self._loss_function(mu, sig, logmix_coeff, target_data,target_weights)
|
73 |
+
|
74 |
+
loss.backward()
|
75 |
+
optimizer.step()
|
76 |
+
|
77 |
+
loss_train.append(loss.item())
|
78 |
+
|
79 |
+
for input_data, target_data, target_weights in loader_val:
|
80 |
+
|
81 |
+
|
82 |
+
input_data = input_data.to(self.device)
|
83 |
+
target_data = target_data.to(self.device)
|
84 |
+
target_weights = target_weights.to(self.device)
|
85 |
+
|
86 |
+
|
87 |
+
mu, logsig, logmix_coeff = self.model(input_data)
|
88 |
+
logsig = torch.clamp(logsig,-6,2)
|
89 |
+
sig = torch.exp(logsig)
|
90 |
+
|
91 |
+
loss_val = self._loss_function(mu, sig, logmix_coeff, target_data, target_weights)
|
92 |
+
loss_validation.append(loss_val.item())
|
93 |
+
|
94 |
+
print(f'training_loss:{loss}',f'testing_loss:{loss_val}')
|
95 |
+
|
96 |
+
self.loss_train=loss_train
|
97 |
+
self.loss_validation=loss_validation
|
98 |
+
|
99 |
+
|
100 |
+
def get_photoz(self,input_data, target_data):
|
101 |
+
self.model = self.model.eval()
|
102 |
+
self.model = self.model.to(self.device)
|
103 |
+
|
104 |
+
input_data = input_data.to(self.device)
|
105 |
+
target_data = target_data.to(self.device)
|
106 |
+
|
107 |
+
for ii in range(len(input_data)):
|
108 |
+
|
109 |
+
mu, logsig, logmix_coeff = self.model(input_data)
|
110 |
+
logsig = torch.clamp(logsig,-6,2)
|
111 |
+
sig = torch.exp(logsig)
|
112 |
+
|
113 |
+
mix_coeff = torch.exp(logmix_coeff)
|
114 |
+
|
115 |
+
z = (mix_coeff * mu).sum(1)
|
116 |
+
zerr = torch.sqrt( (mix_coeff * sig**2).sum(1) + (mix_coeff * (mu - target_data[:,None])**2).sum(1))
|
117 |
+
|
118 |
+
|
119 |
+
return self._to_numpy(z),self._to_numpy(zerr)
|
120 |
+
|
121 |
+
|
122 |
+
#return model
|
123 |
+
|
124 |
+
def plot_photoz(self, df, nbins,xvariable,metric, type_bin='bin'):
|
125 |
+
bin_edges = stats.mstats.mquantiles(df[xvariable].values, np.linspace(0.1,1,nbins))
|
126 |
+
ydata,xlab = [],[]
|
127 |
+
|
128 |
+
|
129 |
+
for k in range(len(bin_edges)-1):
|
130 |
+
edge_min = bin_edges[k]
|
131 |
+
edge_max = bin_edges[k+1]
|
132 |
+
|
133 |
+
mean_mag = (edge_max + edge_min) / 2
|
134 |
+
|
135 |
+
if type_bin=='bin':
|
136 |
+
df_plot = df_test[(df_test.imag > edge_min) & (df_test.imag < edge_max)]
|
137 |
+
elif type_bin=='cum':
|
138 |
+
df_plot = df_test[(df_test.imag < edge_max)]
|
139 |
+
else:
|
140 |
+
raise ValueError("Only type_bin=='bin' for binned and 'cum' for cumulative are supported")
|
141 |
+
|
142 |
+
|
143 |
+
xlab.append(mean_mag)
|
144 |
+
if metric=='sig68':
|
145 |
+
ydata.append(sigma68(df_plot.zwerr))
|
146 |
+
elif metric=='bias':
|
147 |
+
ydata.append(np.mean(df_plot.zwerr))
|
148 |
+
elif metric=='nmad':
|
149 |
+
ydata.append(nmad(df_plot.zwerr))
|
150 |
+
elif metric=='outliers':
|
151 |
+
ydata.append(len(df_plot[np.abs(df_plot.zwerr)>0.15])/len(df_plot))
|
152 |
+
|
153 |
+
plt.plot(xlab,ydata, ls = '-', marker = '.', color = 'navy',lw = 1, label = '')
|
154 |
+
plt.ylabel(f'{metric}$[\Delta z]$', fontsize = 18)
|
155 |
+
plt.xlabel(f'{xvariable}', fontsize = 16)
|
156 |
+
|
157 |
+
plt.xticks(fontsize = 14)
|
158 |
+
plt.yticks(fontsize = 14)
|
159 |
+
|
160 |
+
plt.grid(False)
|
161 |
+
|
162 |
+
plt.show()
|
163 |
+
|
164 |
+
|
165 |
+
|
166 |
+
|
insight/insight_arch.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn, optim
|
2 |
+
import torch
|
3 |
+
class Photoz_network(nn.Module):
|
4 |
+
def __init__(self, num_gauss=10, dropout_prob=0):
|
5 |
+
super(Photoz_network, self).__init__()
|
6 |
+
|
7 |
+
self.features = nn.Sequential(
|
8 |
+
nn.Linear(6, 10),
|
9 |
+
nn.Dropout(dropout_prob),
|
10 |
+
nn.ReLU(),
|
11 |
+
nn.Linear(10, 30),
|
12 |
+
nn.Dropout(dropout_prob),
|
13 |
+
nn.ReLU(),
|
14 |
+
nn.Linear(30, 50),
|
15 |
+
nn.Dropout(dropout_prob),
|
16 |
+
nn.ReLU(),
|
17 |
+
nn.Linear(50, 70),
|
18 |
+
nn.Dropout(dropout_prob),
|
19 |
+
nn.ReLU(),
|
20 |
+
nn.Linear(70, 100)
|
21 |
+
)
|
22 |
+
|
23 |
+
self.measure_mu = nn.Sequential(
|
24 |
+
nn.Linear(100, 80),
|
25 |
+
nn.Dropout(dropout_prob),
|
26 |
+
nn.ReLU(),
|
27 |
+
nn.Linear(80, 70),
|
28 |
+
nn.Dropout(dropout_prob),
|
29 |
+
nn.ReLU(),
|
30 |
+
nn.Linear(70, 60),
|
31 |
+
nn.Dropout(dropout_prob),
|
32 |
+
nn.ReLU(),
|
33 |
+
nn.Linear(60, 50),
|
34 |
+
nn.Dropout(dropout_prob),
|
35 |
+
nn.ReLU(),
|
36 |
+
nn.Linear(50, num_gauss)
|
37 |
+
)
|
38 |
+
|
39 |
+
self.measure_coeffs = nn.Sequential(
|
40 |
+
nn.Linear(100, 80),
|
41 |
+
nn.Dropout(dropout_prob),
|
42 |
+
nn.ReLU(),
|
43 |
+
nn.Linear(80, 70),
|
44 |
+
nn.Dropout(dropout_prob),
|
45 |
+
nn.ReLU(),
|
46 |
+
nn.Linear(70, 60),
|
47 |
+
nn.Dropout(dropout_prob),
|
48 |
+
nn.ReLU(),
|
49 |
+
nn.Linear(60, 50),
|
50 |
+
nn.Dropout(dropout_prob),
|
51 |
+
nn.ReLU(),
|
52 |
+
nn.Linear(50, num_gauss)
|
53 |
+
)
|
54 |
+
|
55 |
+
self.measure_sigma = nn.Sequential(
|
56 |
+
nn.Linear(100, 80),
|
57 |
+
nn.Dropout(dropout_prob),
|
58 |
+
nn.ReLU(),
|
59 |
+
nn.Linear(80, 70),
|
60 |
+
nn.Dropout(dropout_prob),
|
61 |
+
nn.ReLU(),
|
62 |
+
nn.Linear(70, 60),
|
63 |
+
nn.Dropout(dropout_prob),
|
64 |
+
nn.ReLU(),
|
65 |
+
nn.Linear(60, 50),
|
66 |
+
nn.Dropout(dropout_prob),
|
67 |
+
nn.ReLU(),
|
68 |
+
nn.Linear(50, num_gauss)
|
69 |
+
)
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
f = self.features(x)
|
73 |
+
mu = self.measure_mu(f)
|
74 |
+
sigma = self.measure_sigma(f)
|
75 |
+
logmix_coeff = self.measure_coeffs(f)
|
76 |
+
|
77 |
+
logmix_coeff = logmix_coeff - torch.logsumexp(logmix_coeff, 1)[:,None]
|
78 |
+
|
79 |
+
return mu, sigma, logmix_coeff
|
80 |
+
|
81 |
+
|
insight/utils.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import pandas as pd
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
from scipy import stats
|
5 |
+
|
6 |
+
def nmad(data):
|
7 |
+
return 1.4826 * np.median(np.abs(data - np.median(data)))
|
8 |
+
|
9 |
+
def sigma68(data): return 0.5*(pd.Series(data).quantile(q = 0.84) - pd.Series(data).quantile(q = 0.16))
|
10 |
+
|
11 |
+
|
12 |
+
def plot_photoz(df, nbins,xvariable,metric, type_bin='bin'):
|
13 |
+
bin_edges = stats.mstats.mquantiles(df[xvariable].values, np.linspace(0.1,1,nbins))
|
14 |
+
ydata,xlab = [],[]
|
15 |
+
|
16 |
+
|
17 |
+
for k in range(len(bin_edges)-1):
|
18 |
+
edge_min = bin_edges[k]
|
19 |
+
edge_max = bin_edges[k+1]
|
20 |
+
|
21 |
+
mean_mag = (edge_max + edge_min) / 2
|
22 |
+
|
23 |
+
if type_bin=='bin':
|
24 |
+
df_plot = df_test[(df_test.imag > edge_min) & (df_test.imag < edge_max)]
|
25 |
+
elif type_bin=='cum':
|
26 |
+
df_plot = df_test[(df_test.imag < edge_max)]
|
27 |
+
else:
|
28 |
+
raise ValueError("Only type_bin=='bin' for binned and 'cum' for cumulative are supported")
|
29 |
+
|
30 |
+
|
31 |
+
xlab.append(mean_mag)
|
32 |
+
if metric=='sig68':
|
33 |
+
ydata.append(sigma68(df_plot.zwerr))
|
34 |
+
elif metric=='bias':
|
35 |
+
ydata.append(np.mean(df_plot.zwerr))
|
36 |
+
elif metric=='nmad':
|
37 |
+
ydata.append(nmad(df_plot.zwerr))
|
38 |
+
elif metric=='outliers':
|
39 |
+
ydata.append(len(df_plot[np.abs(df_plot.zwerr)>0.15])/len(df_plot))
|
40 |
+
|
41 |
+
plt.plot(xlab,ydata, ls = '-', marker = '.', color = 'navy',lw = 1, label = '')
|
42 |
+
plt.ylabel(f'{metric}$[\Delta z]$', fontsize = 18)
|
43 |
+
plt.xlabel(f'{xvariable}', fontsize = 16)
|
44 |
+
|
45 |
+
plt.xticks(fontsize = 14)
|
46 |
+
plt.yticks(fontsize = 14)
|
47 |
+
|
48 |
+
plt.grid(False)
|
49 |
+
|
50 |
+
plt.show()
|
51 |
+
|