Commit 57fa8fc · Parent: 692f707
clear code and notebooks
Files changed:
- notebooks/Feature_space.py  +0 -494
- notebooks/Fig2_NMAD.py  +0 -170
- notebooks/Fig3_PIT_CRPS.py  +0 -120
- notebooks/Fig4_pz_examples.py  +0 -128
- notebooks/Fig6_qualitycut.py  +0 -164
- notebooks/Fig7_colourspace.py  +0 -261
- notebooks/Table_metrics.py  +0 -148
- temps/archive.py  +31 -16
- temps/plots.py  +41 -62
- temps/temps.py  +224 -207
- temps/temps_arch.py  +11 -17
- temps/utils.py  +58 -144
notebooks/Feature_space.py
DELETED
@@ -1,494 +0,0 @@
# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#     jupytext_version: 1.14.5
#   kernelspec:
#     display_name: insight
#     language: python
#     name: insight
# ---

# # DOMAIN ADAPTATION INTUITION

# %load_ext autoreload
# %autoreload 2

import pandas as pd
import numpy as np
import os
from astropy.io import fits
from astropy.table import Table
import torch

# matplotlib settings
from matplotlib import rcParams
import matplotlib.pyplot as plt
rcParams["mathtext.fontset"] = "stix"
rcParams["font.family"] = "STIXGeneral"

# +
# insight modules
import sys
sys.path.append('../temps')

from archive import archive
from utils import nmad
from temps_arch import EncoderPhotometry, MeasureZ
from temps import Temps_module
from plots import plot_nz
# -

# ## LOAD DATA

# define here the directory containing the photometric catalogues
parent_dir = '/data/astro/scratch2/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5'
modules_dir = '../data/models/'

# +
filename_valid = 'euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'

hdu_list = fits.open(os.path.join(parent_dir, filename_valid))
cat = Table(hdu_list[1].data).to_pandas()
cat = cat[cat['FLAG_PHOT'] == 0]
cat = cat[cat['mu_class_L07'] == 1]

cat['SNR_VIS'] = cat.FLUX_VIS / cat.FLUXERR_VIS
# cat = cat[cat.SNR_VIS > 10]
# -

ztarget = [cat['z_spec_S15'].values[ii] if cat['z_spec_S15'].values[ii] > 0 else cat['photo_z_L15'].values[ii] for ii in range(len(cat))]
specz_or_photo = [0 if cat['z_spec_S15'].values[ii] > 0 else 1 for ii in range(len(cat))]
ID = cat['ID']
VISmag = cat['MAG_VIS']
zsflag = cat['reliable_S15']
cat['ztarget'] = ztarget
cat['specz_or_photo'] = specz_or_photo

# ### EXTRACT PHOTOMETRY

photoz_archive = archive(path=parent_dir, only_zspec=False)
f, ferr = photoz_archive._extract_fluxes(catalogue=cat)
col, colerr = photoz_archive._to_colors(f, ferr)

# ### MEASURE FEATURES

features_all = np.zeros((3, len(cat), 10))
for il, lab in enumerate(['z', 'L15', 'DA']):

    nn_features = EncoderPhotometry()
    nn_features.load_state_dict(torch.load(os.path.join(modules_dir, f'modelF_{lab}.pt')))

    features = nn_features(torch.Tensor(col))
    features = features.detach().cpu().numpy()

    features_all[il] = features

# ### TRAIN AUTOENCODER TO REDUCE TO 2 DIMENSIONS

import torch
from torch import nn

class Autoencoder(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super(Autoencoder, self).__init__()
        # Encoder layers
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 100),
            nn.ReLU(),
            nn.Linear(100, 50),
            nn.ReLU(),
            nn.Linear(50, latent_dim)
        )
        # Decoder layers
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 50),
            nn.ReLU(),
            nn.Linear(50, 100),
            nn.ReLU(),
            nn.Linear(100, input_dim),
        )

    def forward(self, x):
        x = self.encoder(x)
        y = self.decoder(x)
        return y, x

# +
from torch.utils.data import DataLoader, dataset, TensorDataset

ds = TensorDataset(torch.Tensor(features_all[0]))
train_loader = DataLoader(ds, batch_size=100, shuffle=True, drop_last=False)
# -

import torch.optim as optim
autoencoder = Autoencoder(input_dim=10,
                          latent_dim=2)
criterion = nn.L1Loss()
optimizer = optim.Adam(autoencoder.parameters(), lr=0.0001)

# +
# Define the number of epochs
num_epochs = 100
for epoch in range(num_epochs):
    running_loss = 0.0
    for data in train_loader:  # 'train_loader' is the DataLoader defined above
        # Forward pass
        outputs, f1 = autoencoder(data[0])

        loss_autoencoder = criterion(outputs, data[0])
        optimizer.zero_grad()

        # Backward pass
        loss_autoencoder.backward()

        # Update the weights
        optimizer.step()

        # Accumulate the loss
        running_loss += loss_autoencoder.item()

    # Print the average loss for the epoch
    print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, running_loss / len(train_loader)))

print('Training finished')
# -

# #### EVALUATE AUTOENCODER

# + [markdown] jupyter={"source_hidden": true}
# cat.to_csv('features_cat.csv', header=True, sep=',')
# -

indexes_specz = cat[(cat.specz_or_photo == 0) & (cat.reliable_S15 > 0)].reset_index().index

features_all_reduced = np.zeros(shape=(3, len(cat), 2))
for i in range(3):
    _, features = autoencoder(torch.Tensor(features_all[i]))
    features_all_reduced[i] = features.detach().cpu().numpy()

# ### Plot the features

start = 0
end = len(cat)
all_values = set(range(start, end))
values_not_in_indexes_specz = all_values - set(indexes_specz)
indexes_nospecz = sorted(values_not_in_indexes_specz)

# +
import seaborn as sns

# Create subplots with three panels
fig, axs = plt.subplots(1, 3, figsize=(15, 5))

# Set style for all subplots
sns.set_style("white")

# First subplot
sns.kdeplot(x=features_all_reduced[0, indexes_nospecz, 0],
            y=features_all_reduced[0, indexes_nospecz, 1],
            clip=(-150, 150),
            ax=axs[0],
            color='salmon')
sns.kdeplot(x=features_all_reduced[0, indexes_specz, 0],
            y=features_all_reduced[0, indexes_specz, 1],
            clip=(-150, 150),
            ax=axs[0],
            color='lightskyblue')

axs[0].set_xlim(-150, 150)
axs[0].set_ylim(-150, 150)
axs[0].set_title(r'Trained on $z_{\rm s}$')

# Second subplot
sns.kdeplot(x=features_all_reduced[1, indexes_nospecz, 0],
            y=features_all_reduced[1, indexes_nospecz, 1],
            clip=(-50, 50),
            ax=axs[1],
            color='salmon')
sns.kdeplot(x=features_all_reduced[1, indexes_specz, 0],
            y=features_all_reduced[1, indexes_specz, 1],
            clip=(-50, 50),
            ax=axs[1],
            color='lightskyblue')
axs[1].set_xlim(-50, 50)
axs[1].set_ylim(-50, 50)
axs[1].set_title('Trained on L15')

# Third subplot
features_all_reduced_nospecz = pd.DataFrame(features_all_reduced[2, indexes_nospecz, :]).drop_duplicates().values
sns.kdeplot(x=features_all_reduced_nospecz[:, 0],
            y=features_all_reduced_nospecz[:, 1],
            clip=(-1, 5),
            ax=axs[2],
            color='salmon',
            label='Wide-field sample')
sns.kdeplot(x=features_all_reduced_specz[:, 0],
            y=features_all_reduced_specz[:, 1],
            clip=(-1, 5),
            ax=axs[2],
            color='lightskyblue',
            label=r'$z_{\rm s}$ sample')
axs[2].set_xlim(-2, 5)
axs[2].set_ylim(-2, 5)
axs[2].set_title('TEMPS')

axs[0].set_xlabel('Feature 1')
axs[1].set_xlabel('Feature 1')
axs[2].set_xlabel('Feature 1')
axs[0].set_ylabel('Feature 2')

# Create custom legend with desired colors
legend_labels = ['Wide-field sample', r'$z_{\rm s}$ sample']
legend_handles = [plt.Line2D([0], [0], color='salmon', lw=2),
                  plt.Line2D([0], [0], color='lightskyblue', lw=2)]
axs[2].legend(legend_handles, legend_labels, loc='upper right', fontsize=16)
# Adjust layout
plt.tight_layout()

plt.savefig('Contourplot.pdf', bbox_inches='tight')
plt.show()
# -

np.savetxt('features.txt', features_all_reduced.reshape(3 * 164816, 2))

# +
photoz_archive = archive(path=parent_dir, only_zspec=False)

fig, ax = plt.subplots(ncols=3, figsize=(15, 4), sharex=True, sharey=True)
colors = ['navy', 'goldenrod']
titles = [r'Training: $z_s$', r'Training: L15', r'Training: $z_s$ + DA']
x_min, x_max = -5, 5
y_min, y_max = -5, 5
x_grid, y_grid = np.meshgrid(np.linspace(x_min, x_max, 10), np.linspace(y_min, y_max, 10))
xy_grid = np.vstack([x_grid.ravel(), y_grid.ravel()])
density_grid = density_estimation(xy_grid).reshape(x_grid.shape)
for il, lab in enumerate(['z', 'L15', 'DA']):

    nn_features = EncoderPhotometry()
    nn_features.load_state_dict(torch.load(os.path.join(modules_dir, f'modelF_{lab}.pt')))

    for it, target_type in enumerate(['L15', 'zs']):
        if target_type == 'zs':
            cat_sub = photoz_archive._select_only_zspec(cat)
            cat_sub = photoz_archive._clean_zspec_sample(cat_sub)
        elif target_type == 'L15':
            cat_sub = photoz_archive._exclude_only_zspec(cat)
        else:
            assert False

        cat_sub = photoz_archive._clean_photometry(cat_sub)
        print(cat_sub.shape)

        f, ferr = photoz_archive._extract_fluxes(cat_sub)
        col, colerr = photoz_archive._to_colors(f, ferr)

        features = nn_features(torch.Tensor(col))
        features = features.detach().cpu().numpy()

        # xy = np.vstack([features[:1000,0], features[:1000,1]])
        # zd = gaussian_kde(xy)(xy)
        # ax[il].scatter(features[:1000,0], features[:1000,1], c=zd, s=3)

        xy = np.vstack([features[:, 0], features[:, 1]])
        density_estimation = gaussian_kde(xy)

        # Define grid for plotting density lines
        xy_grid = np.vstack([x_grid.ravel(), y_grid.ravel()])
        density_grid = density_estimation(xy_grid).reshape(x_grid.shape)

        # Plot contour lines representing density
        ax[il].contour(x_grid, y_grid, density_grid, colors=colors[it], label=f'{target_type}')

    ax[il].set_title(titles[il])
    ax[il].set_xlim(-5, 5)
    ax[il].set_ylim(-5, 5)

ax[0].set_ylabel('Feature 1', fontsize=14)
# plt.ylabel('Feature 2', fontsize=14)

# assert False
# -

H

H

xedges

yedges

# +
import matplotlib.colors as colors
from matplotlib import path
import numpy as np
from matplotlib import pyplot as plt
try:
    from astropy.convolution import Gaussian2DKernel, convolve
    astro_smooth = True
except ImportError as IE:
    astro_smooth = False

np.random.seed(123)
# t = np.linspace(-5, 1.2, 1000)
x = features[:1000, 0]
y = features[:1000, 1]

H, xedges, yedges = np.histogram2d(x, y, bins=(10, 10))
xmesh, ymesh = np.meshgrid(xedges[:-1], yedges[:-1])

# Smooth the contours (if astropy is installed)
if astro_smooth:
    kernel = Gaussian2DKernel(x_stddev=1.)
    H = convolve(H, kernel)

fig, ax = plt.subplots(1, figsize=(7, 6))
clevels = ax.contour(xmesh, ymesh, H.T, lw=.9, cmap='winter')  # , zorder=90
ax.scatter(x, y, s=3)
# ax.set_xlim(-20, 5)
# ax.set_ylim(-20, 5)

# Identify points within contours
# p = clevels.collections[0].get_paths()
# inside = np.full_like(x, False, dtype=bool)
# for level in p:
#     inside |= level.contains_points(zip(*(x, y)))

# ax.plot(x[~inside], y[~inside], 'kx')
# plt.show(block=False)
# -

density_grid

features.shape, zd.shape

# + jupyter={"outputs_hidden": true}
xy = np.vstack([features[:, 0], features[:, 1]])
zd = gaussian_kde(xy)(xy)
plt.scatter(features[:, 0], features[:, 1], c=zd)
# -

# +
# Make the base corner plot
figure = corner.corner(features[:, :2], quantiles=[0.16, 0.84], show_titles=False, color='crimson')
corner.corner(samples2, fig=fig)
ndim = 2
# Extract the axes
axes = np.array(figure.axes).reshape((ndim, ndim))

for a in axes[np.triu_indices(ndim)]:
    a.remove()
# -

# +
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde

# Assuming 'features' is the data array with shape (n_samples, 2)

# Calculate the density estimate
xy = np.vstack([features[:, 0], features[:, 1]])
density_estimation = gaussian_kde(xy)

# Define grid for plotting density lines
xy_grid = np.vstack([x_grid.ravel(), y_grid.ravel()])
density_grid = density_estimation(xy_grid).reshape(x_grid.shape)

# Plot contour lines representing density
plt.contour(x_grid, y_grid, density_grid, colors='black')

# Optionally, add a scatter plot on top of the density lines for better visualization
# plt.scatter(features[:, 0], features[:, 1], color='blue', alpha=0.5)

# Set labels and title
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.title('Density Lines Plot')

# Show plot
plt.show()
# -

corner_plot = corner.corner(Arinyo_preds,
                            labels=[r'$b$', r'$\beta$', '$q_1$', '$k_{vav}$', '$a_v$', '$b_v$', '$k_p$', '$q_2$'],
                            truths=Arinyo_coeffs_central[test_snap],
                            truth_color='crimson')

import corner
figure = corner.corner(features, quantiles=[0.16, 0.5, 0.84], show_titles=False)
axes = np.array(fig.axes).reshape((ndim, ndim))
for a in axes[np.triu_indices(ndim)]:
    a.remove()

# +
# My data
x = features[:, 0]
y = features[:, 1]

# Perform the kernel density estimate
k = stats.gaussian_kde(np.vstack([x, y]))
xi, yi = np.mgrid[-5:5, -5:5]
zi = k(np.vstack([xi.flatten(), yi.flatten()]))

fig = plt.figure()
ax = fig.gca()

CS = ax.contour(xi, yi, zi.reshape(xi.shape), colors='crimson')

ax.set_xlim(-5, 5)
ax.set_ylim(-5, 5)

plt.show()
# -
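For orientation: the deleted notebook's Autoencoder.forward returns the pair (reconstruction, latent), so the 2-D embedding used for the contour plots is the second element of the tuple. A minimal usage sketch (random inputs stand in for the encoder features; nothing here is part of the TEMPS API):

import torch

ae = Autoencoder(input_dim=10, latent_dim=2)
with torch.no_grad():                        # inference only, no gradients
    recon, latent = ae(torch.randn(5, 10))   # forward returns (reconstruction, latent)
print(latent.shape)                          # torch.Size([5, 2])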
notebooks/Fig2_NMAD.py
DELETED
@@ -1,170 +0,0 @@
# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#     jupytext_version: 1.14.5
#   kernelspec:
#     display_name: insight
#     language: python
#     name: insight
# ---

# %% [markdown]
# # FIGURE 2 IN THE PAPER

# %% [markdown]
# ## METRICS FOR THE DIFFERENT METHODS ON THE WIDE FIELD SAMPLE

# %% [markdown]
# ### LOAD PYTHON MODULES

# %%
# %load_ext autoreload
# %autoreload 2

# %%
import pandas as pd
import numpy as np
import os
from astropy.io import fits
from astropy.table import Table
import torch

# %%
# matplotlib settings
from matplotlib import rcParams
import matplotlib.pyplot as plt
rcParams["mathtext.fontset"] = "stix"
rcParams["font.family"] = "STIXGeneral"

# %%
# insight modules
import sys
sys.path.append('../temps')

from archive import archive
from utils import nmad
from temps_arch import EncoderPhotometry, MeasureZ
from temps import Temps_module

# %%
eval_methods = True

# %% [markdown]
# ### LOAD DATA

# %%
# define here the directory containing the photometric catalogues
parent_dir = '/data/astro/scratch2/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5'
modules_dir = '../data/models/'

# %%
# load catalogue and apply cuts

filename_valid = 'euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'

hdu_list = fits.open(os.path.join(parent_dir, filename_valid))
cat = Table(hdu_list[1].data).to_pandas()
cat = cat[cat['FLAG_PHOT'] == 0]
cat = cat[cat['mu_class_L07'] == 1]
cat = cat[(cat['z_spec_S15'] > 0) | (cat['photo_z_L15'] > 0)]
cat = cat[cat['MAG_VIS'] < 25]

# %%
ztarget = [cat['z_spec_S15'].values[ii] if cat['z_spec_S15'].values[ii] > 0 else cat['photo_z_L15'].values[ii] for ii in range(len(cat))]
specz_or_photo = [0 if cat['z_spec_S15'].values[ii] > 0 else 1 for ii in range(len(cat))]
ID = cat['ID']
VISmag = cat['MAG_VIS']
zsflag = cat['reliable_S15']

# %%
photoz_archive = archive(path=parent_dir, only_zspec=False)
f, ferr = photoz_archive._extract_fluxes(catalogue=cat)
col, colerr = photoz_archive._to_colors(f, ferr)

# %% [markdown]
# ### EVALUATE USING TRAINED MODELS

# %%
if eval_methods:

    dfs = {}
    for il, lab in enumerate(['z', 'L15', 'DA']):

        nn_features = EncoderPhotometry()
        nn_features.load_state_dict(torch.load(os.path.join(modules_dir, f'modelF_{lab}.pt')))
        nn_z = MeasureZ(num_gauss=6)
        nn_z.load_state_dict(torch.load(os.path.join(modules_dir, f'modelZ_{lab}.pt')))

        temps = Temps_module(nn_features, nn_z)

        z, zerr, zmode, pz, flag, odds = temps.get_pz(input_data=torch.Tensor(col),
                                                      return_pz=True)
        # Create a DataFrame with the desired columns
        df = pd.DataFrame(np.c_[ID, VISmag, z, zmode, flag, ztarget, zsflag, zerr, specz_or_photo],
                          columns=['ID', 'VISmag', 'z', 'zmode', 'zflag', 'ztarget', 'zsflag', 'zuncert', 'S15_L15_flag'])

        # Calculate additional columns or operations if needed
        df['zwerr'] = (df.zmode - df.ztarget) / (1 + df.ztarget)

        # Drop any rows with NaN values
        df = df.dropna()

        # Assign the DataFrame to a key in the dictionary
        dfs[lab] = df

# %%
dfs['z']['zwerr'] = (dfs['z'].z - dfs['z'].ztarget) / (1 + dfs['z'].ztarget)
dfs['L15']['zwerr'] = (dfs['L15'].z - dfs['L15'].ztarget) / (1 + dfs['L15'].ztarget)
dfs['DA']['zwerr'] = (dfs['DA'].z - dfs['DA'].ztarget) / (1 + dfs['DA'].ztarget)

# %% [markdown]
# ### LOAD CATALOGUES FROM PREVIOUS TRAINING

# %%
if not eval_methods:
    dfs = {}
    dfs['z'] = pd.read_csv(os.path.join(parent_dir, 'predictions_specztraining.csv'), header=0)
    dfs['L15'] = pd.read_csv(os.path.join(parent_dir, 'predictions_speczL15training.csv'), header=0)
    dfs['DA'] = pd.read_csv(os.path.join(parent_dir, 'predictions_speczDAtraining.csv'), header=0)

# %% [markdown]
# ### MAKE PLOT

# %%
plot_photoz(df_list,
            nbins=8,
            xvariable='VISmag',
            metric='nmad',
            type_bin='bin',
            label_list=['zs', 'zs+L15', r'TEMPS'],
            save=False,
            samp='L15'
            )

# %%
plot_photoz(df_list,
            nbins=8,
            xvariable='VISmag',
            metric='outliers',
            type_bin='bin',
            label_list=['zs', 'zs+L15', r'TEMPS'],
            save=False,
            samp='L15'
            )

# %%

# %%
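The figure bins the weighted error zwerr = (z - ztarget) / (1 + ztarget) and plots its NMAD per VIS-magnitude bin. As a reference, a minimal sketch of the statistic, assuming temps/utils.py follows the standard 1.4826 x MAD convention (the project's own nmad may differ in detail):

import numpy as np

def nmad(zwerr):
    """Normalized median absolute deviation of dz / (1 + z)."""
    zwerr = np.asarray(zwerr)
    return 1.4826 * np.median(np.abs(zwerr - np.median(zwerr)))

# A single catastrophic outlier barely moves the estimate:
print(nmad([0.01, -0.02, 0.015, -0.005, 0.9]))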
notebooks/Fig3_PIT_CRPS.py
DELETED
@@ -1,120 +0,0 @@
# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#     jupytext_version: 1.14.5
#   kernelspec:
#     display_name: insight
#     language: python
#     name: insight
# ---

# %% [markdown]
# # FIGURE 3 IN THE PAPER

# %% [markdown]
# ## PIT AND CRPS FOR THE THREE METHODS

# %% [markdown]
# ### LOAD PYTHON MODULES

# %%
# %load_ext autoreload
# %autoreload 2

# %%
import pandas as pd
import numpy as np
import os
from astropy.io import fits
from astropy.table import Table
import torch

# %%
# matplotlib settings
from matplotlib import rcParams
import matplotlib.pyplot as plt
rcParams["mathtext.fontset"] = "stix"
rcParams["font.family"] = "STIXGeneral"

# %%
# insight modules
import sys
sys.path.append('../temps')
# from insight_arch import EncoderPhotometry, MeasureZ
# from insight import Insight_module
from archive import archive
from utils import nmad
from plots import plot_PIT, plot_crps
from temps_arch import EncoderPhotometry, MeasureZ
from temps import Temps_module

# %% [markdown]
# ### LOAD DATA

# %%
photoz_archive = archive(path=parent_dir,
                         only_zspec=False,
                         flags_kept=[1., 1.1, 1.4, 1.5, 2, 2.1, 2.4, 2.5, 3., 3.1, 3.4, 3.5, 4., 9., 9.1, 9.3, 9.4, 9.5, 11.1, 11.5, 12.1, 12.5, 13., 13.1, 13.5, 14, ],
                         target_test='L15')
f_test, ferr_test, specz_test, VIS_mag_test = photoz_archive.get_testing_data()

# %% [markdown]
# ## CREATE PIT; CRPS; SPECTROSCOPIC SAMPLE

# %% [markdown]
# This loads pre-trained models (for the sake of time). You can learn how to train the models in the Tutorial notebook.

# %%
# Initialize empty dictionaries to store the results
crps_dict = {}
pit_dict = {}
for il, lab in enumerate(['z', 'L15', 'DA']):

    nn_features = EncoderPhotometry()
    nn_features.load_state_dict(torch.load(os.path.join(modules_dir, f'modelF_{lab}.pt')))
    nn_z = MeasureZ(num_gauss=6)
    nn_z.load_state_dict(torch.load(os.path.join(modules_dir, f'modelZ_{lab}.pt')))

    temps = Temps_module(nn_features, nn_z)

    pit_list = temps.pit(input_data=torch.Tensor(f_test), target_data=torch.Tensor(specz_test))
    crps_list = temps.crps(input_data=torch.Tensor(f_test), target_data=specz_test)

    # Assign the lists to keys in the dictionaries
    crps_dict[lab] = crps_list
    pit_dict[lab] = pit_list

# %%
plot_PIT(pit_dict['z'],
         pit_dict['L15'],
         pit_dict['DA'],
         labels=[r'$z_{\rm s}$', 'L15', 'TEMPS'],
         sample='L15',
         save=True)

# %%
plot_crps(crps_dict['z'],
          crps_dict['L15'],
          crps_dict['DA'],
          labels=[r'$z_{\rm s}$', 'L15', 'TEMPS'],
          sample='L15',
          save=True)

# %%
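PIT and CRPS are computed here by Temps_module.pit and Temps_module.crps. For what the two diagnostics mean, a self-contained sketch on a gridded p(z) (the grid and toy PDF are illustrative, not the project's internals): the PIT value is the predictive CDF evaluated at the true redshift, and the CRPS integrates the squared distance between that CDF and the step function at the truth.

import numpy as np

def pit(pz, z_grid, z_true):
    cdf = np.cumsum(pz)
    cdf /= cdf[-1]                        # normalize to a proper CDF
    return np.interp(z_true, z_grid, cdf)

def crps(pz, z_grid, z_true):
    cdf = np.cumsum(pz)
    cdf /= cdf[-1]
    step = (z_grid >= z_true).astype(float)
    return np.trapz((cdf - step) ** 2, z_grid)

z_grid = np.linspace(0, 5, 1000)
pz = np.exp(-0.5 * ((z_grid - 1.2) / 0.05) ** 2)   # toy Gaussian p(z)
print(pit(pz, z_grid, 1.2))    # ~0.5: the truth sits at the median
print(crps(pz, z_grid, 1.2))   # small: sharp and well-centred PDF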
notebooks/Fig4_pz_examples.py
DELETED
@@ -1,128 +0,0 @@
# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#     jupytext_version: 1.14.5
#   kernelspec:
#     display_name: insight
#     language: python
#     name: insight
# ---

# %% [markdown]
# # FIGURE 4 IN THE PAPER

# %% [markdown]
# ## IMPACT OF TEMPS ON CONCRETE P(Z) EXAMPLES

# %% [markdown]
# ### LOAD PYTHON MODULES

# %%
# %load_ext autoreload
# %autoreload 2

# %%
import pandas as pd
import numpy as np
import os
from astropy.io import fits
from astropy.table import Table
import torch

# %%
# matplotlib settings
from matplotlib import rcParams
import matplotlib.pyplot as plt
rcParams["mathtext.fontset"] = "stix"
rcParams["font.family"] = "STIXGeneral"

# %%
# insight modules
import sys
sys.path.append('../temps')
# from insight_arch import EncoderPhotometry, MeasureZ
# from insight import Insight_module
from archive import archive
from utils import nmad
from temps_arch import EncoderPhotometry, MeasureZ
from temps import Temps_module

# %% [markdown]
# ### LOAD DATA

# %%
# define here the directory containing the photometric catalogues
parent_dir = '/data/astro/scratch2/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5'
modules_dir = '../data/models/'

# %%
filename_valid = 'euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'

hdu_list = fits.open(os.path.join(parent_dir, filename_valid))
cat = Table(hdu_list[1].data).to_pandas()
cat = cat[cat['FLAG_PHOT'] == 0]
cat = cat[cat['mu_class_L07'] == 1]
cat = cat[(cat['z_spec_S15'] > 0) | (cat['photo_z_L15'] > 0)]
cat = cat[cat['MAG_VIS'] < 25]

# %%
ztarget = [cat['z_spec_S15'].values[ii] if cat['z_spec_S15'].values[ii] > 0 else cat['photo_z_L15'].values[ii] for ii in range(len(cat))]
specz_or_photo = [0 if cat['z_spec_S15'].values[ii] > 0 else 1 for ii in range(len(cat))]
ID = cat['ID']
VISmag = cat['MAG_VIS']
zsflag = cat['reliable_S15']

# %%
photoz_archive = archive(path=parent_dir, only_zspec=False)
f, ferr = photoz_archive._extract_fluxes(catalogue=cat)
col, colerr = photoz_archive._to_colors(f, ferr)

# %% [markdown]
# ### LOAD TRAINED MODELS AND EVALUATE PDF OF RANDOM EXAMPLES

# %% [markdown]
# The notebook 'Tutorial_temps' gives an example of how to train and save models.

# %%
# Draw a random galaxy and evaluate its p(z) with each model
ii = np.random.randint(0, len(col), 1)
pz_dict = {}
for il, lab in enumerate(['z', 'L15', 'DA']):

    nn_features = EncoderPhotometry()
    nn_features.load_state_dict(torch.load(os.path.join(modules_dir, f'modelF_{lab}.pt')))
    nn_z = MeasureZ(num_gauss=6)
    nn_z.load_state_dict(torch.load(os.path.join(modules_dir, f'modelZ_{lab}.pt')))

    temps = Temps_module(nn_features, nn_z)

    z, zerr, pz, flag, _ = temps.get_pz(input_data=torch.Tensor(col[ii]), return_pz=True)

    # Assign the PDF to a key in the dictionary
    pz_dict[lab] = pz

# %%
cmap = plt.get_cmap('Dark2')

plt.plot(np.linspace(0, 5, 1000), pz_dict['z'][0], label='z', color=cmap(0), ls='--')
plt.plot(np.linspace(0, 5, 1000), pz_dict['L15'][0], label='L15', color=cmap(1), ls=':')
plt.plot(np.linspace(0, 5, 1000), pz_dict['DA'][0], label='TEMPS', color=cmap(2), ls='-')
plt.axvline(x=np.array(ztarget)[ii][0], ls='-.', color='black')
# plt.xlim(0, 2)
plt.legend()

plt.xlabel(r'$z$', fontsize=14)
plt.ylabel('Probability', fontsize=14)
# plt.savefig(f'pz_{ii[0]}.pdf', bbox_inches='tight')

# %%
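The plotting cell evaluates each stored PDF against np.linspace(0, 5, 1000), which implies that get_pz returns p(z) sampled on a 1000-point grid over 0 < z < 5. Under that assumption (inferred from the plot, not verified against the module), a point estimate can be read off the curve directly:

import numpy as np

z_grid = np.linspace(0, 5, 1000)     # grid the plotting cell assumes
pz = pz_dict['DA'][0]                # PDF of the randomly drawn galaxy
z_mode = z_grid[np.argmax(pz)]       # peak of the PDF as a point estimate
print(z_mode)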
notebooks/Fig6_qualitycut.py
DELETED
@@ -1,164 +0,0 @@
# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#     jupytext_version: 1.14.5
#   kernelspec:
#     display_name: insight
#     language: python
#     name: insight
# ---

# # FIGURE 6 IN THE PAPER

# ## QUALITY CUTS

# %load_ext autoreload
# %autoreload 2

import pandas as pd
import numpy as np
import os
import torch
from scipy import stats

# matplotlib settings
from matplotlib import rcParams
import matplotlib.pyplot as plt
rcParams["mathtext.fontset"] = "stix"
rcParams["font.family"] = "STIXGeneral"

# insight modules
import sys
sys.path.append('../temps')
# from insight_arch import EncoderPhotometry, MeasureZ
# from insight import Insight_module
from archive import archive
from utils import nmad
from temps_arch import EncoderPhotometry, MeasureZ
from temps import Temps_module

# ### LOAD DATA (ONLY SPECZ)

# define here the directory containing the photometric catalogues
parent_dir = '/data/astro/scratch2/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5'
modules_dir = '../data/models/'

photoz_archive = archive(path=parent_dir, only_zspec=True, flags_kept=[1., 1.1, 1.4, 1.5, 2, 2.1, 2.4, 2.5, 3., 3.1, 3.4, 3.5, 4., 9., 9.1, 9.3, 9.4, 9.5, 11.1, 11.5, 12.1, 12.5, 13., 13.1, 13.5, 14, ])
f_test_specz, ferr_test_specz, specz_test, VIS_mag_test = photoz_archive.get_testing_data()

# ### LOAD TRAINED MODELS AND EVALUATE PDF OF RANDOM EXAMPLES

# +
# Initialize an empty dictionary to store DataFrames
dfs = {}

for il, lab in enumerate(['z', 'L15', 'DA']):

    nn_features = EncoderPhotometry()
    nn_features.load_state_dict(torch.load(os.path.join(modules_dir, f'modelF_{lab}.pt')))
    nn_z = MeasureZ(num_gauss=6)
    nn_z.load_state_dict(torch.load(os.path.join(modules_dir, f'modelZ_{lab}.pt')))

    temps = Temps_module(nn_features, nn_z)

    z, zerr, pz, flag, odds = temps.get_pz(input_data=torch.Tensor(f_test_specz),
                                           return_pz=True)

    # Create a DataFrame with the desired columns
    df = pd.DataFrame(np.c_[z, flag, odds, specz_test],
                      columns=['z', 'zflag', 'odds', 'ztarget'])

    # Calculate additional columns or operations if needed
    df['zwerr'] = (df.z - df.ztarget) / (1 + df.ztarget)

    # Drop any rows with NaN values
    df = df.dropna()

    # Assign the DataFrame to a key in the dictionary
    dfs[lab] = df

# -

# ### STATISTICS BASED ON OUR QUALITY CUT

# +
bin_edges = stats.mstats.mquantiles(df.zflag, np.arange(0, 1.01, 0.05))
scatter, eta, xlab, xmag, xzs, flagmean = [], [], [], [], [], []

for k in range(len(bin_edges) - 1):
    edge_min = bin_edges[k]
    edge_max = bin_edges[k + 1]

    df_bin = df[(df.zflag > edge_min)]

    xlab.append(np.round(len(df_bin) / len(df), 2) * 100)
    xzs.append(0.5 * (df_bin.ztarget.min() + df_bin.ztarget.max()))
    flagmean.append(np.mean(df_bin.zflag))
    scatter.append(nmad(df_bin.zwerr))
    eta.append(len(df_bin[np.abs(df_bin.zwerr) > 0.15]) / len(df) * 100)

# -

# ### STATISTICS BASED ON ODDS

# +
bin_edges = stats.mstats.mquantiles(df.odds, np.arange(0, 1.01, 0.05))
scatter_odds, eta_odds, xlab_odds, oddsmean = [], [], [], []

for k in range(len(bin_edges) - 1):
    edge_min = bin_edges[k]
    edge_max = bin_edges[k + 1]

    df_bin = df[(df.odds > edge_min)]

    xlab_odds.append(np.round(len(df_bin) / len(df), 2) * 100)
    oddsmean.append(np.mean(df_bin.zflag))
    scatter_odds.append(nmad(df_bin.zwerr))
    eta_odds.append(len(df_bin[np.abs(df_bin.zwerr) > 0.15]) / len(df) * 100)

# -

# ### PLOTS

# +
plt.plot(xlab_odds, scatter_odds, marker='.', color='crimson', label=r'$\theta(\Delta z)$', ls='--', alpha=0.5)
plt.plot(xlab, scatter, marker='.', color='navy', label=r'$\xi = \theta(\Delta z)$')

plt.ylabel(r'NMAD [$\Delta z\ /\ (1 + z_{\rm s})$]', fontsize=16)
plt.xlabel('Completeness', fontsize=16)

plt.yticks(fontsize=12)
plt.xticks(np.arange(5, 101, 10), fontsize=12)
plt.legend(fontsize=14)

plt.savefig('Flag_nmad_zspec.pdf', bbox_inches='tight')
plt.show()

# +
plt.plot(xlab_odds, eta_odds, marker='.', color='crimson', label=r'$\theta(\Delta z)$', ls='--', alpha=0.5)
plt.plot(xlab, eta, marker='.', color='navy', label=r'$\xi = \theta(\Delta z)$')

plt.yticks(fontsize=12)
plt.xticks(np.arange(5, 101, 10), fontsize=12)
plt.ylabel(r'$\eta$ [%]', fontsize=16)
plt.xlabel('Completeness', fontsize=16)
plt.legend()

plt.savefig('Flag_eta_zspec.pdf', bbox_inches='tight')

plt.show()
# -
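Note that each quantile edge defines a cumulative cut in the loops above (df.zflag > edge_min keeps everything above the edge; edge_max is computed but unused), so the curves trace completeness against NMAD and eta as the cut tightens. A toy illustration of that trade-off with hypothetical flag values:

import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
flags = rng.random(1000)                                   # hypothetical zflag values
edges = stats.mstats.mquantiles(flags, np.arange(0, 1.01, 0.05))
for edge in edges[:-1:4]:                                  # every fourth edge, for brevity
    kept = flags > edge                                    # cumulative cut, as in the notebook
    print(f"cut > {edge:.2f}: completeness {kept.mean() * 100:.0f}%")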
notebooks/Fig7_colourspace.py
DELETED
@@ -1,261 +0,0 @@
# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#     jupytext_version: 1.14.5
#   kernelspec:
#     display_name: insight
#     language: python
#     name: insight
# ---

# # FIGURE COLOURSPACE IN THE PAPER

# %load_ext autoreload
# %autoreload 2

import pandas as pd
import numpy as np
import os
from astropy.io import fits
from astropy.table import Table
import torch

# matplotlib settings
from matplotlib import rcParams
import matplotlib.pyplot as plt
rcParams["mathtext.fontset"] = "stix"
rcParams["font.family"] = "STIXGeneral"

# +
# insight modules
import sys
sys.path.append('../temps')

from archive import archive
from utils import nmad
from temps_arch import EncoderPhotometry, MeasureZ
from temps import Temps_module
from plots import plot_nz

# -

def estimate_som_map(df, plot_arg='z', nx=40, ny=40):
    """
    Estimate a Self-Organizing Map (SOM) visualization from a DataFrame.

    Parameters:
    - df (pandas.DataFrame): Input DataFrame containing data for SOM estimation.
    - plot_arg (str, optional): Column name to be used for plotting. Default is 'z'.
    - nx (int, optional): Number of cells along the X-axis. Default is 40.
    - ny (int, optional): Number of cells along the Y-axis. Default is 40.

    Returns:
    - som_data (numpy.ndarray): Estimated SOM visualization data.
    """
    x_cells = np.arange(0, nx)
    y_cells = np.arange(0, ny)
    index_cell = np.arange(nx * ny)
    cells = np.array(np.meshgrid(x_cells, y_cells)).T.reshape(-1, 2)
    cells = pd.DataFrame(np.c_[cells[:, 0], cells[:, 1], index_cell], columns=['x_cell', 'y_cell', 'cell'])

    if plot_arg == 'count':
        som_vis = df.groupby('cell')['z'].count().reset_index().rename(columns={f'z': 'plot_som'})
    else:
        som_vis = df.groupby('cell')[f'{plot_arg}'].mean().reset_index().rename(columns={f'{plot_arg}': 'plot_som'})

    som_data = som_vis.merge(cells, on='cell')
    som_data = som_data.pivot(index='x_cell', columns='y_cell', values='plot_som')

    return som_data


def plot_som_map(som_data, plot_arg='z', vmin=0, vmax=1):
    """
    Plot the Self-Organizing Map (SOM) data.

    Parameters:
    - som_data (numpy.ndarray): The SOM data to be visualized.
    - plot_arg (str, optional): The column name to be plotted. Default is 'z'.
    - vmin (float, optional): Minimum value for color scaling. Default is 0.
    - vmax (float, optional): Maximum value for color scaling. Default is 1.

    Returns:
    None
    """
    plt.imshow(som_data, vmin=vmin, vmax=vmax, cmap='viridis')  # Choose an appropriate colormap
    plt.colorbar(label=f'{plot_arg}')  # Add a colorbar with a label
    plt.xlabel(r'$x$ [pixel]', fontsize=14)  # Add an appropriate X-axis label
    plt.ylabel(r'$y$ [pixel]', fontsize=14)  # Add an appropriate Y-axis label
    plt.show()


# ### LOAD DATA

# +
filename_valid = 'euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'

hdu_list = fits.open(os.path.join(parent_dir, filename_valid))
cat = Table(hdu_list[1].data).to_pandas()
cat = cat[cat['FLAG_PHOT'] == 0]
cat = cat[cat['mu_class_L07'] == 1]
cat = cat[(cat['z_spec_S15'] > 0) | (cat['photo_z_L15'] > 0)]
cat = cat[cat['MAG_VIS'] < 25]

# -

ztarget = [cat['z_spec_S15'].values[ii] if cat['z_spec_S15'].values[ii] > 0 else cat['photo_z_L15'].values[ii] for ii in range(len(cat))]
specz_or_photo = [0 if cat['z_spec_S15'].values[ii] > 0 else 1 for ii in range(len(cat))]
ID = cat['ID']
VISmag = cat['MAG_VIS']
zsflag = cat['reliable_S15']

photoz_archive = archive(path=parent_dir, only_zspec=False)
f, ferr = photoz_archive._extract_fluxes(catalogue=cat)
col, colerr = photoz_archive._to_colors(f, ferr)

# +
dfs = {}

for il, lab in enumerate(['z', 'L15', 'DA']):

    nn_features = EncoderPhotometry()
    nn_features.load_state_dict(torch.load(os.path.join(modules_dir, f'modelF_{lab}.pt')))
    nn_z = MeasureZ(num_gauss=6)
    nn_z.load_state_dict(torch.load(os.path.join(modules_dir, f'modelZ_{lab}.pt')))

    temps = Temps_module(nn_features, nn_z)

    z, zerr, pz, flag, odds = temps.get_pz(input_data=torch.Tensor(col),
                                           return_pz=True)
    # Create a DataFrame with the desired columns
    df = pd.DataFrame(np.c_[ID, VISmag, z, flag, ztarget, zsflag, zerr, specz_or_photo],
                      columns=['ID', 'VISmag', 'z', 'zflag', 'ztarget', 'zsflag', 'zuncert', 'S15_L15_flag'])

    # Calculate additional columns or operations if needed
    df['zwerr'] = (df.z - df.ztarget) / (1 + df.ztarget)

    # Drop any rows with NaN values
    df = df.dropna()

    # Assign the DataFrame to a key in the dictionary
    dfs[lab] = df

# -

# ### LOAD TRAINED MODELS AND EVALUATE PDFs AND REDSHIFT

# define here the directory containing the photometric catalogues
parent_dir = '/data/astro/scratch2/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5'
modules_dir = '../data/models/'

df_z = dfs['z']
df_z_DA = dfs['DA']

# ##### LOAD TRAIN SOM ON TRAINING DATA

df_som = pd.read_csv(os.path.join(parent_dir, 'som_dataframe.csv'), header=0, sep=',')
df_z = df_z.merge(df_som, on='ID')
df_z_DA = df_z_DA.merge(df_som, on='ID')

# ##### APPLY CUTS FOR DIFFERENT SAMPLES

df_zspec = df_z[(df_z.S15_L15_flag == 0) & (df_z.zsflag == 1)]
df_l15 = df_z[(df_z.ztarget > 0)]
df_l15_DA = df_z_DA[(df_z_DA.ztarget > 0)]

df_l15_euclid = df_z[(df_z.VISmag < 24.5) & (df_z.z > 0.2) & (df_z.z < 2.6)]
df_l15_euclid_cut = df_l15_euclid[df_l15_euclid.zflag > 0.033]

df_l15_euclid_da = df_z_DA[(df_z_DA.VISmag < 24.5) & (df_z_DA.z > 0.2) & (df_z_DA.z < 2.6)]
df_l15_euclid_cut_da = df_l15_euclid_da[df_l15_euclid_da.zflag > 0.018]

# ## MAKE SOM PLOT

from mpl_toolkits.axes_grid1 import make_axes_locatable

# +
fig, axs = plt.subplots(6, 4, figsize=(13, 15), sharex=True, sharey=True, gridspec_kw={'hspace': 0.05, 'wspace': 0.06})

# Plot in the top row (axs[0, i])
# top row, spectroscopic sample
columns = ['ztarget', 'z', 'zwerr', 'count']
titles = [r'$z_{true}$', r'$z$', r'$z_{\rm error}$', 'Counts']
limits = [[0, 4], [0, 4], [-0.5, 0.5], [0, 50]]
for ii in range(4):
    som_data = estimate_som_map(df_zspec, plot_arg=columns[ii], nx=40, ny=40)
    im = axs[0, ii].imshow(som_data, vmin=limits[ii][0], vmax=limits[ii][1], cmap='viridis')  # Choose an appropriate colormap
    axs[0, ii].set_title(f'{titles[ii]}', fontsize=18)

    if ii == 0:
        axs[0, 0].set_ylabel(r'$y$', fontsize=14)
    elif ii == 1:
        cbar_ax = fig.add_axes([0.49, 0.11, 0.01, 0.77])
        fig.colorbar(im, cax=cbar_ax)
    elif ii == 2:
        cbar_ax = fig.add_axes([0.685, 0.11, 0.01, 0.77])
        fig.colorbar(im, cax=cbar_ax)
    elif ii == 3:
        cbar_ax = fig.add_axes([0.885, 0.11, 0.01, 0.77])
        fig.colorbar(im, cax=cbar_ax)

for jj in range(4):
    som_data = estimate_som_map(df_l15, plot_arg=columns[jj], nx=40, ny=40)
    im = axs[1, jj].imshow(som_data, vmin=limits[jj][0], vmax=limits[jj][1], cmap='viridis')
    # axs[1, jj].set_title(f'{titles[jj]}', fontsize=14)
    # axs[1, jj].set_xlabel(r'$x$', fontsize=14)

for kk in range(4):
    som_data = estimate_som_map(df_l15_DA, plot_arg=columns[kk], nx=40, ny=40)
    im = axs[2, kk].imshow(som_data, vmin=limits[kk][0], vmax=limits[kk][1], cmap='viridis')
    # axs[2, kk].set_title(f'{titles[kk]}', fontsize=14)
    # axs[2, kk].set_xlabel(r'$x$', fontsize=14)

for rr in range(4):
    som_data = estimate_som_map(df_l15_euclid_da, plot_arg=columns[rr], nx=40, ny=40)
    im = axs[3, rr].imshow(som_data, vmin=limits[rr][0], vmax=limits[rr][1], cmap='viridis')
    # axs[3, rr].set_title(f'{titles[rr]}', fontsize=14)
    # axs[3, rr].set_xlabel(r'$x$', fontsize=14)

for ll in range(4):
    som_data = estimate_som_map(df_l15_euclid_cut, plot_arg=columns[ll], nx=40, ny=40)
    im = axs[4, ll].imshow(som_data, vmin=limits[ll][0], vmax=limits[ll][1], cmap='viridis')
    axs[4, ll].set_xlabel(r'$x$', fontsize=14)

for ll in range(4):
    som_data = estimate_som_map(df_l15_euclid_cut_da, plot_arg=columns[ll], nx=40, ny=40)
    im = axs[5, ll].imshow(som_data, vmin=limits[ll][0], vmax=limits[ll][1], cmap='viridis')
    axs[5, ll].set_xlabel(r'$x$', fontsize=14)

axs[0, 0].set_ylabel(r'$y$', fontsize=14)
axs[1, 0].set_ylabel(r'$y$', fontsize=14)
axs[2, 0].set_ylabel(r'$y$', fontsize=14)
axs[3, 0].set_ylabel(r'$y$', fontsize=14)
axs[4, 0].set_ylabel(r'$y$', fontsize=14)
axs[5, 0].set_ylabel(r'$y$', fontsize=14)

fig.text(0.09, 0.815, r'$z_{\rm s}$ sample', va='center', rotation='vertical', fontsize=16)
fig.text(0.09, 0.69, r'L15 sample', va='center', rotation='vertical', fontsize=16)
fig.text(0.09, 0.56, r'L15 sample + DA', va='center', rotation='vertical', fontsize=14)
fig.text(0.09, 0.44, r'$Euclid$ sample + DA', va='center', rotation='vertical', fontsize=14)
fig.text(0.09, 0.3, r'$Euclid$ sample + QC', va='center', rotation='vertical', fontsize=14)
fig.text(0.09, 0.17, r'$Euclid$ sample + DA + QC', va='center', rotation='vertical', fontsize=13)

plt.savefig('SOM_colourspace.pdf', format='pdf', bbox_inches='tight', dpi=300)

# -
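The key step in estimate_som_map is the per-cell aggregation: each galaxy carries the index of its best-matching SOM cell (the 'cell' column merged in from som_dataframe.csv), and each panel shows the per-cell mean (or count) of the chosen quantity. A toy version of that groupby-and-pivot on a hypothetical 2x2 SOM:

import numpy as np
import pandas as pd

df = pd.DataFrame({'cell': [0, 0, 1, 3], 'z': [0.5, 0.7, 1.2, 2.0]})
per_cell = df.groupby('cell')['z'].mean()   # mean redshift per SOM cell
grid = np.full(4, np.nan)                   # 2x2 map flattened to 4 cells
grid[per_cell.index] = per_cell.values      # unpopulated cells stay NaN
print(grid.reshape(2, 2))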
notebooks/Table_metrics.py
DELETED
@@ -1,148 +0,0 @@
|
|
1 |
-
# ---
|
2 |
-
# jupyter:
|
3 |
-
# jupytext:
|
4 |
-
# text_representation:
|
5 |
-
# extension: .py
|
6 |
-
# format_name: light
|
7 |
-
# format_version: '1.5'
|
8 |
-
# jupytext_version: 1.14.5
|
9 |
-
# kernelspec:
|
10 |
-
# display_name: insight
|
11 |
-
# language: python
|
12 |
-
# name: insight
|
13 |
-
# ---

# # TABLE METRICS

# %load_ext autoreload
# %autoreload 2

import pandas as pd
import numpy as np
import os
import torch
from scipy import stats
from astropy.io import fits
from astropy.table import Table

#matplotlib settings
from matplotlib import rcParams
import matplotlib.pyplot as plt
rcParams["mathtext.fontset"] = "stix"
rcParams["font.family"] = "STIXGeneral"

#insight modules
import sys
sys.path.append('../temps')
#from insight_arch import EncoderPhotometry, MeasureZ
#from insight import Insight_module
from archive import archive
from utils import nmad, select_cut
from temps_arch import EncoderPhotometry, MeasureZ
from temps import Temps_module


# ## LOAD DATA

#define here the directory containing the photometric catalogues
parent_dir = '/data/astro/scratch2/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5'
modules_dir = '../data/models/'

# +
filename_valid='euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'

hdu_list = fits.open(os.path.join(parent_dir,filename_valid))
cat = Table(hdu_list[1].data).to_pandas()
cat = cat[cat['FLAG_PHOT']==0]
cat = cat[cat['mu_class_L07']==1]

cat['SNR_VIS'] = cat.FLUX_VIS / cat.FLUXERR_VIS
# -

cat = cat[cat.SNR_VIS>10]

ztarget = [cat['z_spec_S15'].values[ii] if cat['z_spec_S15'].values[ii]> 0 else cat['photo_z_L15'].values[ii] for ii in range(len(cat))]
specz_or_photo = [0 if cat['z_spec_S15'].values[ii]> 0 else 1 for ii in range(len(cat))]
ID = cat['ID']
VISmag = cat['MAG_VIS']
zsflag = cat['reliable_S15']

cat['ztarget']=ztarget
cat['specz_or_photo']=specz_or_photo

cat = cat[cat.ztarget>0]

# ### EXTRACT PHOTOMETRY

photoz_archive = archive(path = parent_dir, only_zspec=False)
f, ferr = photoz_archive._extract_fluxes(catalogue=cat)
col, colerr = photoz_archive._to_colors(f, ferr)

# ### MEASURE CATALOGUE

# +
lab = 'DA'
nn_features = EncoderPhotometry()
nn_features.load_state_dict(torch.load(os.path.join(modules_dir,f'modelF_{lab}.pt')))
nn_z = MeasureZ(num_gauss=6)
nn_z.load_state_dict(torch.load(os.path.join(modules_dir,f'modelZ_{lab}.pt')))

temps = Temps_module(nn_features, nn_z)

z, zerr, pz, flag, odds = temps.get_pz(input_data=torch.Tensor(col),
                                       return_pz=True)

# Create a DataFrame with the desired columns
df = pd.DataFrame(np.c_[z, flag, odds, cat.ztarget, cat.reliable_S15, cat.specz_or_photo],
                  columns=['z', 'zflag', 'odds', 'ztarget', 'reliable_S15', 'specz_or_photo'])

# Add the weighted redshift error
df['zwerr'] = (df.z - df.ztarget) / (1 + df.ztarget)

# Drop any rows with NaN values
df = df.dropna()
# -

# ### SPECZ SAMPLE

df_specz = df[(df.reliable_S15==1)&(df.specz_or_photo==0)]

# +
df_selected, cut, dfcuts = select_cut(df_specz,
                                      completenss_lim=None,
                                      nmad_lim=0.055,
                                      outliers_lim=None,
                                      return_df=True)
# -

print(dfcuts.to_latex(float_format="%.3f",
                      columns=['Nobj','completeness', 'nmad', 'eta'],
                      index=False
                      ))

# ### EUCLID SAMPLE

df_euclid = df[(df.z > 0.2)&(df.z < 2.6)]

# +
df_selected, cut, dfcuts = select_cut(df_euclid,
                                      completenss_lim=None,
                                      nmad_lim=0.055,
                                      outliers_lim=None,
                                      return_df=True)
# -

print(dfcuts.to_latex(float_format="%.3f",
                      columns=['Nobj','completeness', 'nmad', 'eta'],
                      index=False
                      ))
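As a quick reference for the table this notebook prints (Nobj, completeness, nmad, eta), here is a minimal, self-contained sketch of the three statistics using the same conventions as `utils.nmad` and `select_cut`; the toy catalogue and noise levels are invented purely for illustration.

import numpy as np

rng = np.random.default_rng(42)
ztarget = rng.uniform(0.2, 2.6, 10_000)
z = ztarget + rng.normal(0, 0.03 * (1 + ztarget))   # well-behaved photo-z core
z[:500] = rng.uniform(0, 5, 500)                    # ~5% catastrophic outliers

zwerr = (z - ztarget) / (1 + ztarget)               # weighted redshift error
nmad = 1.4826 * np.median(np.abs(zwerr - np.median(zwerr)))
eta = 100 * np.mean(np.abs(zwerr) > 0.15)           # outlier fraction in percent

print(f"NMAD = {nmad:.3f}, eta = {eta:.2f}%")       # roughly 0.03 and 5%

Completeness is then simply the fraction of objects surviving a given ODDS cut, which is what `select_cut` tabulates bin by bin.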
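For reference, the quantities that `get_pz` returns and that the cuts above threshold follow the Gaussian-mixture convention implemented in `temps/temps.py` below. With mixture weights $\pi_k = e^{\text{logmix\_coeff}_k}$,

$$
p(z\mid x)=\sum_{k=1}^{K}\pi_k\,\mathcal{N}\!\big(z;\,\mu_k,\sigma_k\big),\qquad
z_{\rm point}=\sum_{k=1}^{K}\pi_k\,\mu_k,\qquad
\mathrm{ODDS}=\int_{z_{\rm peak}-0.05}^{z_{\rm peak}+0.05}p(z\mid x)\,\mathrm{d}z,
$$

where $p(z\mid x)$ is normalised on the redshift grid, so ODDS lies in $(0,1]$ and larger values flag narrower, more trustworthy posteriors.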
temps/archive.py
CHANGED

@@ -1,18 +1,18 @@
 import numpy as np
 import pandas as pd
 from astropy.io import fits
-import os
 from astropy.table import Table
 from scipy.spatial import KDTree
-
-import matplotlib.pyplot as plt
-
-from matplotlib import rcParams
+from matplotlib import pyplot as plt
+from matplotlib import rcParams
+from pathlib import Path
+from loguru import logger
+
+
 rcParams["mathtext.fontset"] = "stix"
 rcParams["font.family"] = "STIXGeneral"
 
-
-class archive():
+
+class Archive:
     def __init__(self, path,
                  aperture=2,
                  drop_stars=True,
@@ -21,31 +21,42 @@ class archive():
                  extinction_corr=True,
                  only_zspec=True,
                  all_apertures=False,
-                 target_test='specz', flags_kept=[3,3.1,3.4,3.5,4]):
+                 target_test='specz', flags_kept=[3, 3.1, 3.4, 3.5, 4]):
 
+        logger.info("Starting archive")
         self.aperture = aperture
-        self.all_apertures=all_apertures
-        self.flags_kept=flags_kept
+        self.all_apertures = all_apertures
+        self.flags_kept = flags_kept
 
+        filename_calib = 'euclid_cosmos_DC2_S1_v2.1_calib_clean.fits'
+        filename_valid = 'euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
+
+        # Use Path for file handling
+        path_calib = Path(path) / filename_calib
+        path_valid = Path(path) / filename_valid
+
+        # Open the calibration FITS file
+        with fits.open(path_calib) as hdu_list:
+            cat = Table(hdu_list[1].data).to_pandas()
+            cat = cat[(cat['z_spec_S15'] > 0) | (cat['photo_z_L15'] > 0)]
+
+        # Open the validation FITS file
+        with fits.open(path_valid) as hdu_list:
+            cat_test = Table(hdu_list[1].data).to_pandas()
+
+        # Store the catalogs for later use
+        self.cat = cat
+        self.cat_test = cat_test
 
         if drop_stars==True:
+            logger.info("dropping stars...")
            cat = cat[cat.mu_class_L07==1]
            cat_test = cat_test[cat_test.mu_class_L07==1]
 
         if clean_photometry==True:
+            logger.info("cleaning stars...")
            cat = self._clean_photometry(cat)
            cat_test = self._clean_photometry(cat_test)
 
@@ -216,9 +227,11 @@ class archive():
 
 
         if only_zspec:
+            logger.info("Selecting only galaxies with spectroscopic redshift")
            catalogue = self._select_only_zspec(catalogue, cat_flag='Calib')
            catalogue = self._clean_zspec_sample(catalogue, flags_kept=self.flags_kept)
         else:
+            logger.info("Selecting galaxies with spectroscopic redshift and high-precision photo-z")
            catalogue = self._take_zspec_and_photoz(catalogue, cat_flag='Calib')
 
 
@@ -233,9 +246,11 @@ class archive():
 
 
         if extinction_corr==True:
+            logger.info("Correcting MW extinction")
            f = self._correct_extinction(catalogue, f)
 
         if convert_colors==True:
+            logger.info("Converting to colors")
            col, colerr = self._to_colors(f, ferr)
            col_DA, colerr_DA = self._to_colors(f_DA, ferr_DA)
 
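A minimal construction sketch for the refactored class, assuming the two DC2 catalogue FITS files named in `__init__` sit under `path` (a placeholder here) and that the module is importable as `temps.archive`; only behaviour visible in the hunks above is exercised.

from temps.archive import Archive

photoz_archive = Archive(
    path="/path/to/Euclid_EXT_MER_PHZ_DC2_v1.5",   # placeholder directory
    only_zspec=True,                                # keep only reliable spec-z galaxies
    flags_kept=[3, 3.1, 3.4, 3.5, 4],
)

# The constructor now loads and stores both catalogues on the instance:
print(len(photoz_archive.cat), len(photoz_archive.cat_test))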
temps/plots.py
CHANGED

@@ -1,7 +1,7 @@
 import numpy as np
 import pandas as pd
 import matplotlib.pyplot as plt
-from utils import nmad
+from temps.utils import nmad
 
 import numpy as np
 import matplotlib.pyplot as plt
@@ -181,68 +181,7 @@ def plot_PIT(pit_list_1, pit_list_2 = None, pit_list_3=None, sample='specz', lab
     # Show the plot
     plt.show()
 
-
-import numpy as np
-import matplotlib.pyplot as plt
-from scipy import stats
-
-def plot_photoz(df_list, nbins, xvariable, metric, type_bin='bin',label_list=None, samp='zs', save=False):
-    #plot properties
-    plt.rcParams['font.family'] = 'serif'
-    plt.rcParams['font.size'] = 12
-
-    bin_edges = stats.mstats.mquantiles(df_list[0][xvariable].values, np.linspace(0.05, 1, nbins))
-    print(bin_edges)
-    cmap = plt.get_cmap('Dark2')  # Choose a colormap for coloring lines
-    plt.figure(figsize=(6, 5))
-    ls = ['--',':','-']
-
-    for i, df in enumerate(df_list):
-        ydata, xlab = [], []
-
-        for k in range(len(bin_edges)-1):
-            edge_min = bin_edges[k]
-            edge_max = bin_edges[k+1]
-
-            mean_mag = (edge_max + edge_min) / 2
-
-            if type_bin == 'bin':
-                df_plot = df[(df[xvariable] > edge_min) & (df[xvariable] < edge_max)]
-            elif type_bin == 'cum':
-                df_plot = df[(df[xvariable] < edge_max)]
-            else:
-                raise ValueError("Only type_bin=='bin' for binned and 'cum' for cumulative are supported")
-
-            xlab.append(mean_mag)
-            if metric == 'sig68':
-                ydata.append(sigma68(df_plot.zwerr))
-            elif metric == 'bias':
-                ydata.append(np.mean(df_plot.zwerr))
-            elif metric == 'nmad':
-                ydata.append(nmad(df_plot.zwerr))
-            elif metric == 'outliers':
-                ydata.append(len(df_plot[np.abs(df_plot.zwerr) > 0.15]) / len(df_plot)*100)
-
-        print(ydata)
-        color = cmap(i)  # Get a different color for each dataframe
-        plt.plot(xlab, ydata, marker='.', lw=1, label=f'{label_list[i]}', color=color, ls=ls[i])
-
-    if xvariable == 'VISmag':
-        xvariable_lab = 'VIS'
-
-    plt.ylabel(f'{metric} $[\\Delta z]$', fontsize=18)
-    plt.xlabel(f'{xvariable_lab}', fontsize=16)
-    plt.grid(False)
-    plt.legend()
-
-    if save==True:
-        plt.savefig(f'{metric}_{xvariable}_{samp}.pdf', dpi=300, bbox_inches='tight')
-    plt.show()
 
 def plot_nz(df_list,
@@ -336,3 +275,43 @@ def plot_crps(crps_list_1, crps_list_2 = None, crps_list_3=None, labels=None, s
     # Show the plot
     plt.show()
 
+
+def plot_nz(df, bins=np.arange(0, 5, 0.2)):
+    kwargs = dict(bins=bins, alpha=0.5)
+    plt.hist(df.zs.values, color='grey', ls='-', **kwargs)
+    counts, _ = np.histogram(df.z.values, bins=bins)
+
+    plt.plot((bins[:-1] + bins[1:]) * 0.5, counts, color='purple')
+
+    #plt.legend(fontsize=14)
+    plt.xlabel(r'Redshift', fontsize=14)
+    plt.ylabel(r'Counts', fontsize=14)
+    plt.yscale('log')
+
+    plt.show()
+
+    return
+
+
+def plot_scatter(df, sample='specz', save=True):
+    # Calculate the point density
+    xy = np.vstack([df.zs.values, df.z.values])
+    zd = gaussian_kde(xy)(xy)
+
+    fig, ax = plt.subplots()
+    plt.scatter(df.zs.values, df.z.values, c=zd, s=1)
+    plt.xlim(0, 5)
+    plt.ylim(0, 5)
+
+    plt.xlabel(r'$z_{\rm s}$', fontsize=14)
+    plt.ylabel('$z$', fontsize=14)
+
+    plt.xticks(fontsize=12)
+    plt.yticks(fontsize=12)
+
+    if save==True:
+        plt.savefig(f'{sample}_scatter.pdf', dpi=300, bbox_inches='tight')
+
+    plt.show()
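A smoke-test sketch for the two new helpers, with a toy catalogue following the repo's column conventions (`zs` = spec-z, `z` = photo-z). Note that `plot_scatter` references `gaussian_kde`, which does not appear to be imported in the hunks shown, so `from scipy.stats import gaussian_kde` would need to be added to `plots.py` for the second call to run.

import numpy as np
import pandas as pd
from temps.plots import plot_nz, plot_scatter

rng = np.random.default_rng(0)
zs = rng.uniform(0, 3, 1000)
df = pd.DataFrame({"zs": zs, "z": zs + rng.normal(0, 0.05 * (1 + zs))})

plot_nz(df)                   # spec-z histogram with the photo-z counts overlaid
plot_scatter(df, save=False)  # density-coloured zs vs. z scatter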
temps/temps.py
CHANGED

@@ -1,249 +1,267 @@
+import numpy as np
+import pandas as pd
 import torch
-from torch.utils.data import DataLoader, dataset, TensorDataset
 from torch import nn, optim
+from torch.utils.data import DataLoader, TensorDataset
 from torch.optim import lr_scheduler
-import numpy as np
-import pandas as pd
-from astropy.io import fits
-import os
-from astropy.table import Table
-from scipy.spatial import KDTree
-from scipy.special import erf
 from scipy.stats import norm
+from loguru import logger
+from tqdm import tqdm  # Import tqdm for progress bars
+
+# Local imports
+from temps.utils import maximum_mean_discrepancy
+
+
-class Temps_module():
+class TempsModule:
+    """Class for managing temperature-related models and training."""
+
+    def __init__(
+        self,
+        model_f,
+        model_z,
+        batch_size=100,
+        rejection_param=1,
+        da=True,
+        verbose=False,
+    ):
+        self.model_z = model_z
+        self.model_f = model_f
+        self.da = da
+        self.verbose = verbose
+        self.ngaussians = model_z.ngaussians
+
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.batch_size = batch_size
+        self.rejection_parameter = rejection_param
+
+    def _get_dataloaders(
+        self, input_data, target_data, input_data_da=None, val_fraction=0.1
+    ):
+        """Create training and validation dataloaders."""
         input_data = torch.Tensor(input_data)
         target_data = torch.Tensor(target_data)
+        input_data_da = (
+            torch.Tensor(input_data_da)
+            if input_data_da is not None
+            else input_data.clone()
+        )
+
+        dataset = TensorDataset(input_data, input_data_da, target_data)
+        train_dataset, val_dataset = torch.utils.data.random_split(
+            dataset,
+            [int(len(dataset) * (1 - val_fraction)), int(len(dataset) * val_fraction)],
+        )
+        loader_train = DataLoader(
+            train_dataset, batch_size=self.batch_size, shuffle=True
+        )
+        loader_val = DataLoader(val_dataset, batch_size=64, shuffle=True)
 
         return loader_train, loader_val
 
+    def _loss_function(self, mean, std, logmix, true):
+        """Compute the loss function."""
+        log_prob = (
+            logmix - 0.5 * (mean - true[:, None]).pow(2) / std.pow(2) - torch.log(std)
+        )
-        log_prob = torch.logsumexp(log_prob, 1)
+        log_prob = torch.logsumexp(log_prob, dim=1)
         loss = -log_prob.mean()
+        return loss
 
+    def _loss_function_da(self, f1, f2):
+        """Compute the KL divergence loss for domain adaptation."""
-        kl_loss = nn.KLDivLoss(reduction="batchmean",log_target=True)
+        kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=True)
         loss = kl_loss(f1, f2)
-        #print('f1',f1)
-        #print('f2',f2)
-        return loss
+        return torch.log(loss)
 
-    def _to_numpy(self,x):
+    def _to_numpy(self, x):
+        """Convert a tensor to a NumPy array."""
         return x.detach().cpu().numpy()
 
-    def train(self,input_data,
-              input_data_DA,
-              target_data,
-              nepochs=10,
-              step_size = 100,
-              val_fraction=0.1,
-              lr=1e-3,
-              weight_decay=0):
-        self.modelZ = self.modelZ.train()
-        self.modelF = self.modelF.train()
-
-        loader_train, loader_val = self._get_dataloaders(input_data, target_data, input_data_DA, val_fraction=0.1)
-        optimizerZ = optim.Adam(self.modelZ.parameters(), lr=lr, weight_decay=weight_decay)
-        optimizerF = optim.Adam(self.modelF.parameters(), lr=lr, weight_decay=weight_decay)
-
-        schedulerZ = torch.optim.lr_scheduler.StepLR(optimizerZ, step_size=step_size, gamma =0.1)
-        schedulerF = torch.optim.lr_scheduler.StepLR(optimizerF, step_size=step_size, gamma =0.1)
-
-        self.modelZ = self.modelZ.to(self.device)
-        self.modelF = self.modelF.to(self.device)
+    def train(
+        self,
+        input_data,
+        input_data_da,
+        target_data,
+        nepochs=10,
+        step_size=100,
+        val_fraction=0.1,
+        lr=1e-3,
+        weight_decay=0,
+    ):
+        """Train the models using provided data."""
+        self.model_z.train()
+        self.model_f.train()
+
+        loader_train, loader_val = self._get_dataloaders(
+            input_data, target_data, input_data_da, val_fraction
+        )
+        optimizer_z = optim.Adam(
+            self.model_z.parameters(), lr=lr, weight_decay=weight_decay
+        )
+        optimizer_f = optim.Adam(
+            self.model_f.parameters(), lr=lr, weight_decay=weight_decay
+        )
+
+        scheduler_z = lr_scheduler.StepLR(optimizer_z, step_size=step_size, gamma=0.1)
+        scheduler_f = lr_scheduler.StepLR(optimizer_f, step_size=step_size, gamma=0.1)
+
+        self.model_z.to(self.device)
+        self.model_f.to(self.device)
+
+        loss_train, loss_validation = [], []
+
+        for epoch in range(nepochs):
+            _loss_train, _loss_validation = [], []
+            logger.info(f"Epoch {epoch + 1}/{nepochs} starting...")
+            for input_data, input_data_da, target_data in tqdm(
+                loader_train, desc="Training", unit="batch"
+            ):
+                input_data, target_data = input_data.to(self.device), target_data.to(
+                    self.device
+                )
                 if self.da:
                     input_data_da = input_data_da.to(self.device)
 
-                optimizerZ.zero_grad()
+                optimizer_f.zero_grad()
+                optimizer_z.zero_grad()
 
-                features = self.modelF(input_data)
-                if self.da:
-                    features_DA = self.modelF(input_data_da)
+                features = self.model_f(input_data)
+                features_da = self.model_f(input_data_da) if self.da else None
 
-                mu, logsig, logmix_coeff = self.modelZ(features)
-                logsig = torch.clamp(logsig,-6,2)
+                mu, logsig, logmix_coeff = self.model_z(features)
+                logsig = torch.clamp(logsig, -6, 2)
                 sig = torch.exp(logsig)
 
+                loss_z = self._loss_function(mu, sig, logmix_coeff, target_data)
+                loss = loss_z + (
+                    1e3
+                    * maximum_mean_discrepancy(
+                        features, features_da, kernel_type="rbf"
+                    ).sum()
+                    if self.da
+                    else 0
+                )
+
-                _loss_train.append(lossZ.item())
+                _loss_train.append(loss_z.item())
                 loss.backward()
-                optimizerF.step()
-                optimizerZ.step()
+                optimizer_f.step()
+                optimizer_z.step()
 
-            schedulerF.step()
-            schedulerZ.step()
+            scheduler_f.step()
+            scheduler_z.step()
 
-            self.loss_train.append(np.mean(_loss_train))
+            loss_train.append(np.mean(_loss_train))
+            _loss_validation = self._validate(loader_val, target_data)
+
+            logger.info(
+                f"Epoch {epoch + 1}: Training Loss: {np.mean(_loss_train):.4f}, Validation Loss: {np.mean(_loss_validation):.4f}"
+            )
 
+    def _validate(self, loader_val, target_data):
+        """Validate the model on the validation dataset."""
+        self.model_z.eval()
+        self.model_f.eval()
+        _loss_validation = []
+
+        with torch.no_grad():
+            for input_data, _, target_data in tqdm(
+                loader_val, desc="Validating", unit="batch"
+            ):
                 input_data = input_data.to(self.device)
                 target_data = target_data.to(self.device)
 
+                features = self.model_f(input_data)
+                mu, logsig, logmix_coeff = self.model_z(features)
+                logsig = torch.clamp(logsig, -6, 2)
                 sig = torch.exp(logsig)
 
                 loss_val = self._loss_function(mu, sig, logmix_coeff, target_data)
                 _loss_validation.append(loss_val.item())
 
-        if self.verbose:
-            print(f'training_loss:{loss}',f'testing_loss:{loss_val}')
+        return _loss_validation
 
     def get_features(self, input_data):
+        """Get features from the model."""
+        self.model_f.eval()
         input_data = input_data.to(self.device)
+        features = self.model_f(input_data)
-        return features.detach().cpu().numpy()
+        return self._to_numpy(features)
 
-    def get_pz(self,input_data, return_pz=True, return_flag=True,
+    def get_pz(self, input_data, return_pz=True, return_flag=True, return_odds=False):
+        """Get the predicted z values and their uncertainties."""
+        logger.info("Predicting photo-z for the input galaxies...")
+        self.model_z.eval()
+        self.model_f.eval()
 
         input_data = input_data.to(self.device)
-        mu, logsig, logmix_coeff = self.modelZ(features)
-        logsig = torch.clamp(logsig,-6,2)
+        features = self.model_f(input_data)
+        mu, logsig, logmix_coeff = self.model_z(features)
+        logsig = torch.clamp(logsig, -6, 2)
         sig = torch.exp(logsig)
 
         mix_coeff = torch.exp(logmix_coeff)
+        z = (mix_coeff * mu).sum(dim=1)
-        zerr = torch.sqrt( (mix_coeff * sig**2).sum(1) + (mix_coeff * (mu - mu.mean(1)[:,None])**2).sum(1))
+        zerr = torch.sqrt(
+            (mix_coeff * sig**2).sum(dim=1)
+            + (mix_coeff * (mu - mu.mean(dim=1, keepdim=True)) ** 2).sum(dim=1)
+        )
 
-        mu, mix_coeff, sig = mu.detach().cpu().numpy(), mix_coeff.detach().cpu().numpy(), sig.detach().cpu().numpy()
+        mu, mix_coeff, sig = map(self._to_numpy, (mu, mix_coeff, sig))
 
-        if return_pz==True:
-            zgrid = np.linspace(0, 5, 1000)
-            pdf_mixture = np.zeros(shape=(len(input_data), len(zgrid)))
-            for ii in range(len(input_data)):
-                for i in range(self.ngaussians):
-                    pdf_mixture[ii] += mix_coeff[ii,i] * norm.pdf(zgrid, mu[ii,i], sig[ii,i])
-            if return_flag==True:
-                #narrow peak
-                pdf_mixture = pdf_mixture / pdf_mixture.sum(1)[:,None]
-                diff_matrix = np.abs(self._to_numpy(z)[:,None] - zgrid[None,:])
-                #odds
-                idx_peak = np.argmax(pdf_mixture,1)
-                zpeak = zgrid[idx_peak]
-                diff_matrix_upper = np.abs((zpeak+0.05)[:,None] - zgrid[None,:])
-                diff_matrix_lower = np.abs((zpeak-0.05)[:,None] - zgrid[None,:])
-
-                idx = np.argmin(diff_matrix,1)
-                idx_upper = np.argmin(diff_matrix_upper,1)
-                idx_lower = np.argmin(diff_matrix_lower,1)
-
-                p_z_x = np.zeros(shape=(len(z)))
-                odds = np.zeros(shape=(len(z)))
-
-                for ii in range(len(z)):
-                    p_z_x[ii] = pdf_mixture[ii,idx[ii]]
-                    odds[ii] = pdf_mixture[ii,:idx_upper[ii]].sum() - pdf_mixture[ii,:idx_lower[ii]].sum()
-
-                return self._to_numpy(z),self._to_numpy(zerr), pdf_mixture, p_z_x, odds
-            else:
+        if return_pz:
+            logger.info("Returning p(z)")
+            return self._calculate_pdf(z, mu, sig, mix_coeff, return_flag)
         else:
-            return self._to_numpy(z),self._to_numpy(zerr)
+            return self._to_numpy(z), self._to_numpy(zerr)
 
+    def _calculate_pdf(self, z, mu, sig, mix_coeff, return_flag):
+        """Calculate the probability density function."""
+        zgrid = np.linspace(0, 5, 1000)
+        pz = np.zeros((len(z), len(zgrid)))
+
+        for ii in range(len(z)):
+            for i in range(self.ngaussians):
+                pz[ii] += mix_coeff[ii, i] * norm.pdf(zgrid, mu[ii, i], sig[ii, i])
+
+        if return_flag:
+            logger.info("Calculating and returning ODDS")
+            pz /= pz.sum(axis=1, keepdims=True)
+            return self._calculate_odds(z, pz, zgrid)
+        return self._to_numpy(z), pz
+
+    def _calculate_odds(self, z, pz, zgrid):
+        """Calculate odds based on the PDF."""
+        logger.info('Calculating ODDS values')
+        diff_matrix = np.abs(self._to_numpy(z)[:, None] - zgrid[None, :])
+        idx_peak = np.argmax(pz, axis=1)
+        zpeak = zgrid[idx_peak]
+        idx_upper = np.argmin(np.abs((zpeak + 0.05)[:, None] - zgrid[None, :]), axis=1)
+        idx_lower = np.argmin(np.abs((zpeak - 0.05)[:, None] - zgrid[None, :]), axis=1)
+
+        odds = []
+        for jj in range(len(pz)):
+            odds.append(pz[jj, idx_lower[jj]:(idx_upper[jj] + 1)].sum())
+
+        odds = np.array(odds)
+        return self._to_numpy(z), pz, odds
+
+    def calculate_pit(self, input_data, target_data):
+        logger.info('Calculating PIT values')
 
         pit_list = []
 
-        self.modelF = self.modelF.eval()
-        self.modelF = self.modelF.to(self.device)
-        self.modelZ = self.modelZ.eval()
-        self.modelZ = self.modelZ.to(self.device)
+        self.model_f = self.model_f.eval()
+        self.model_f = self.model_f.to(self.device)
+        self.model_z = self.model_z.eval()
+        self.model_z = self.model_z.to(self.device)
 
         input_data = input_data.to(self.device)
 
-        features = self.modelF(input_data)
-        mu, logsig, logmix_coeff = self.modelZ(features)
+        features = self.model_f(input_data)
+        mu, logsig, logmix_coeff = self.model_z(features)
 
         logsig = torch.clamp(logsig,-6,2)
         sig = torch.exp(logsig)
@@ -259,7 +277,8 @@ class Temps_module():
 
         return pit_list
 
+    def calculate_crps(self, input_data, target_data):
+        logger.info('Calculating CRPS values')
 
         def measure_crps(cdf, t):
             zgrid = np.linspace(0,4,1000)
@@ -273,16 +292,16 @@ class Temps_module():
 
         crps_list = []
 
-        self.modelF = self.modelF.eval()
-        self.modelF = self.modelF.to(self.device)
-        self.modelZ = self.modelZ.eval()
-        self.modelZ = self.modelZ.to(self.device)
+        self.model_f = self.model_f.eval()
+        self.model_f = self.model_f.to(self.device)
+        self.model_z = self.model_z.eval()
+        self.model_z = self.model_z.to(self.device)
 
         input_data = input_data.to(self.device)
 
-        features = self.modelF(input_data)
-        mu, logsig, logmix_coeff = self.modelZ(features)
+        features = self.model_f(input_data)
+        mu, logsig, logmix_coeff = self.model_z(features)
         logsig = torch.clamp(logsig,-6,2)
         sig = torch.exp(logsig)
 
@@ -294,21 +313,19 @@ class Temps_module():
         z = (mix_coeff * mu).sum(1)
 
         x = np.linspace(0, 4, 1000)
+        pz = np.zeros(shape=(len(target_data), len(x)))
        for ii in range(len(input_data)):
            for i in range(6):
+                pz[ii] += mix_coeff[ii,i] * norm.pdf(x, mu[ii,i], sig[ii,i])
 
+        pz = pz / pz.sum(1)[:,None]
 
+        cdf_z = np.cumsum(pz,1)
 
-        crps_value = measure_crps(
+        crps_value = measure_crps(cdf_z, target_data)
 
 
        return crps_value
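To make the refactored interface concrete, here is a minimal end-to-end training sketch. The random tensors stand in for real source and target photometry, and the number of colours (6) is a guess for illustration only: the encoder's input width is not visible in the hunks above and must match `EncoderPhotometry`'s first layer. The sample size is chosen so the 90/10 `random_split` in `_get_dataloaders` sums exactly.

import numpy as np
import torch
from temps.temps_arch import EncoderPhotometry, MeasureZ
from temps.temps import TempsModule

n_gal, n_colors = 500, 6                     # n_colors is an assumption
col = np.random.rand(n_gal, n_colors).astype(np.float32)      # source colours
col_da = np.random.rand(n_gal, n_colors).astype(np.float32)   # unlabeled target colours
ztrue = np.random.uniform(0, 3, n_gal).astype(np.float32)

temps = TempsModule(EncoderPhotometry(), MeasureZ(num_gauss=6),
                    batch_size=100, da=True)
temps.train(col, col_da, ztrue, nepochs=2, lr=1e-3)

# Point estimate and width only (skip the full PDF / ODDS machinery):
z, zerr = temps.get_pz(torch.Tensor(col), return_pz=False)

With `da=True`, each batch's loss adds the MMD term between source and target features, which is what pulls the two feature distributions together during training.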
temps/temps_arch.py
CHANGED

@@ -20,52 +20,46 @@ class EncoderPhotometry(nn.Module):
             nn.Linear(50, 20),
             nn.Dropout(dropout_prob),
             nn.ReLU(),
-            nn.Linear(20, 10)
+            nn.Linear(20, 10),
         )
 
     def forward(self, x):
         f = self.features(x)
+        f = F.log_softmax(f, dim=1)
         return f
 
 
 class MeasureZ(nn.Module):
     def __init__(self, num_gauss=10, dropout_prob=0):
         super(MeasureZ, self).__init__()
 
-        self.ngaussians=num_gauss
+        self.ngaussians = num_gauss
         self.measure_mu = nn.Sequential(
             nn.Linear(10, 20),
             nn.Dropout(dropout_prob),
             nn.ReLU(),
-            nn.Linear(20, num_gauss)
+            nn.Linear(20, num_gauss),
         )
 
         self.measure_coeffs = nn.Sequential(
             nn.Linear(10, 20),
             nn.Dropout(dropout_prob),
             nn.ReLU(),
-            nn.Linear(20, num_gauss)
+            nn.Linear(20, num_gauss),
        )
 
         self.measure_sigma = nn.Sequential(
             nn.Linear(10, 20),
             nn.Dropout(dropout_prob),
             nn.ReLU(),
-            nn.Linear(20, num_gauss)
+            nn.Linear(20, num_gauss),
         )
 
     def forward(self, f):
         mu = self.measure_mu(f)
         sigma = self.measure_sigma(f)
         logmix_coeff = self.measure_coeffs(f)
 
-        logmix_coeff = logmix_coeff - torch.logsumexp(logmix_coeff, 1)[:,None]
+        logmix_coeff = logmix_coeff - torch.logsumexp(logmix_coeff, 1)[:, None]
 
         return mu, sigma, logmix_coeff
temps/utils.py
CHANGED

@@ -3,113 +3,22 @@ import pandas as pd
 import matplotlib.pyplot as plt
 from scipy import stats
 import torch
+from loguru import logger
 
-def nmad(data):
-    return 1.4826 * np.median(np.abs(data - np.median(data)))
-
-def sigma68(data): return 0.5*(pd.Series(data).quantile(q = 0.84) - pd.Series(data).quantile(q = 0.16))
-
-def plot_photoz(df_list, nbins, xvariable, metric, type_bin='bin',label_list=None, samp='zs', save=False):
-    #plot properties
-    plt.rcParams['font.family'] = 'serif'
-    plt.rcParams['font.size'] = 12
-
-    bin_edges = stats.mstats.mquantiles(df_list[0][xvariable].values, np.linspace(0.05, 1, nbins))
-    print(bin_edges)
-    cmap = plt.get_cmap('Dark2')  # Choose a colormap for coloring lines
-    plt.figure(figsize=(6, 5))
-
-    for i, df in enumerate(df_list):
-        ydata, xlab = [], []
-
-        for k in range(len(bin_edges)-1):
-            edge_min = bin_edges[k]
-            edge_max = bin_edges[k+1]
-
-            mean_mag = (edge_max + edge_min) / 2
-
-            if type_bin == 'bin':
-                df_plot = df[(df[xvariable] > edge_min) & (df[xvariable] < edge_max)]
-            elif type_bin == 'cum':
-                df_plot = df[(df[xvariable] < edge_max)]
-            else:
-                raise ValueError("Only type_bin=='bin' for binned and 'cum' for cumulative are supported")
-
-            xlab.append(mean_mag)
-            if metric == 'sig68':
-                ydata.append(sigma68(df_plot.zwerr))
-            elif metric == 'bias':
-                ydata.append(np.mean(df_plot.zwerr))
-            elif metric == 'nmad':
-                ydata.append(nmad(df_plot.zwerr))
-            elif metric == 'outliers':
-                ydata.append(len(df_plot[np.abs(df_plot.zwerr) > 0.15]) / len(df_plot)*100)
-
-        print(ydata)
-        color = cmap(i)  # Get a different color for each dataframe
-        plt.plot(xlab, ydata, ls='-', marker='.', lw=1, label=f'{label_list[i]}', color=color)
-
-    if xvariable == 'VISmag':
-        xvariable_lab = 'VIS'
-
-    plt.ylabel(f'{metric} $[\\Delta z]$', fontsize=18)
-    plt.xlabel(f'{xvariable_lab}', fontsize=16)
-    plt.grid(False)
-    plt.legend()
-
-    if save==True:
-        plt.savefig(f'{metric}_{xvariable}_{samp}.pdf', dpi=300, bbox_inches='tight')
-    plt.show()
-
-def plot_nz(df, bins=np.arange(0,5,0.2)):
-    kwargs=dict( bins=bins,alpha=0.5)
-    plt.hist(df.zs.values, color='grey', ls='-' ,**kwargs)
-    counts, _, =np.histogram(df.z.values, bins=bins)
-
-    plt.plot((bins[:-1]+bins[1:])*0.5,counts, color ='purple')
-
-    #plt.legend(fontsize=14)
-    plt.xlabel(r'Redshift', fontsize=14)
-    plt.ylabel(r'Counts', fontsize=14)
-    plt.yscale('log')
-
-    plt.show()
-
-    return
-
-def plot_scatter(df, sample='specz', save=True):
-    # Calculate the point density
-    xy = np.vstack([df.zs.values,df.z.values])
-    zd = gaussian_kde(xy)(xy)
-
-    fig, ax = plt.subplots()
-    plt.scatter(df.zs.values, df.z.values,c=zd, s=1)
-    plt.xlim(0,5)
-    plt.ylim(0,5)
-
-    plt.xticks(fontsize = 12)
-    plt.yticks(fontsize = 12)
-
-    plt.show()
+
+def caluclate_eta(df):
+    return len(df[np.abs(df.zwerr)>0.15])/len(df) *100
+
+
+def nmad(data):
+    return 1.4826 * np.median(np.abs(data - np.median(data)))
+
+
+def sigma68(data):
+    return 0.5 * (pd.Series(data).quantile(q=0.84) - pd.Series(data).quantile(q=0.16))
 
 
+def maximum_mean_discrepancy(x, y, kernel_type="rbf", kernel_mul=2.0, kernel_num=5):
     """
     Compute the Maximum Mean Discrepancy (MMD) between two sets of samples.
 
@@ -130,7 +39,8 @@ def maximum_mean_discrepancy(x, y, kernel_type='rbf', kernel_mul=2.0, kernel_num
     mmd_loss = torch.mean(x_kernel) + torch.mean(y_kernel) - 2 * torch.mean(xy_kernel)
     return mmd_loss
 
+
+def compute_kernel(x, y, kernel_type="rbf", kernel_mul=2.0, kernel_num=5):
     """
     Compute the kernel matrix based on the chosen kernel type.
 
@@ -151,73 +61,77 @@ def compute_kernel(x, y, kernel_type='rbf', kernel_mul=2.0, kernel_num=5):
     x = x.unsqueeze(1).expand(x_size, y_size, dim)
     y = y.unsqueeze(0).expand(x_size, y_size, dim)
 
     kernel_input = (x - y).pow(2).mean(2)
 
-    if kernel_type == 'linear':
+    if kernel_type == "linear":
         kernel_matrix = kernel_input
-    elif kernel_type == 'poly':
+    elif kernel_type == "poly":
         kernel_matrix = (1 + kernel_input / kernel_mul).pow(kernel_num)
-    elif kernel_type == 'rbf':
+    elif kernel_type == "rbf":
         kernel_matrix = torch.exp(-kernel_input / (2 * kernel_mul**2))
-    elif kernel_type == 'sigmoid':
+    elif kernel_type == "sigmoid":
         kernel_matrix = torch.tanh(kernel_mul * kernel_input)
     else:
-        raise ValueError(
+        raise ValueError(
+            "Invalid kernel type. Supported types are 'linear', 'poly', 'rbf', and 'sigmoid'."
+        )
 
     return kernel_matrix
 
 
+def select_cut(
+    df, completenss_lim=None, nmad_lim=None, outliers_lim=None, return_df=False
+):
 
-    if (completenss_lim is None)&(nmad_lim is None)&(outliers_lim is None):
-        raise(ValueError("Select at least one cut"))
+    if (completenss_lim is None) & (nmad_lim is None) & (outliers_lim is None):
+        raise (ValueError("Select at least one cut"))
     elif sum(c is not None for c in [completenss_lim, nmad_lim, outliers_lim]) > 1:
         raise ValueError("Select only one cut at a time")
+
     else:
+        bin_edges = stats.mstats.mquantiles(df.odds, np.arange(0, 1.01, 0.1))
-        scatter, eta, cmptnss, nobj = [],[],[], []
+        scatter, eta, cmptnss, nobj = [], [], [], []
 
-        for k in range(len(bin_edges)-1):
+        for k in range(len(bin_edges) - 1):
            edge_min = bin_edges[k]
-            edge_max = bin_edges[k+1]
+           edge_max = bin_edges[k + 1]
 
+           df_bin = df[(df.odds > edge_min)]
 
-            cmptnss.append(np.round(len(df_bin)/len(df),2)*100)
+           cmptnss.append(np.round(len(df_bin) / len(df), 2) * 100)
           scatter.append(nmad(df_bin.zwerr))
-            eta.append(len(df_bin[np.abs(df_bin.zwerr)>0.15])/len(df_bin)*100)
+           eta.append(len(df_bin[np.abs(df_bin.zwerr) > 0.15]) / len(df_bin) * 100)
           nobj.append(len(df_bin))
 
+        dfcuts = pd.DataFrame(
+            data=np.c_[
+                np.round(bin_edges[:-1], 5),
+                np.round(nobj, 1),
+                np.round(cmptnss, 1),
+                np.round(scatter, 3),
+                np.round(eta, 2),
+            ],
+            columns=["flagcut", "Nobj", "completeness", "nmad", "eta"],
+        )
 
         if completenss_lim is not None:
+            logger.info("Selecting cut based on completeness")
-            selected_cut = dfcuts[dfcuts['completeness'] <= completenss_lim].iloc[0]
+            selected_cut = dfcuts[dfcuts["completeness"] <= completenss_lim].iloc[0]
 
         elif nmad_lim is not None:
+            logger.info("Selecting cut based on nmad")
-            selected_cut = dfcuts[dfcuts['nmad'] <= nmad_lim].iloc[0]
+            selected_cut = dfcuts[dfcuts["nmad"] <= nmad_lim].iloc[0]
 
         elif outliers_lim is not None:
+            logger.info("Selecting cut based on outliers")
-            selected_cut = dfcuts[dfcuts['eta'] <= outliers_lim].iloc[0]
+            selected_cut = dfcuts[dfcuts["eta"] <= outliers_lim].iloc[0]
 
+        logger.info(
+            f"This cut provides completeness of {selected_cut['completeness']}, nmad={selected_cut['nmad']} and eta={selected_cut['eta']}"
+        )
 
+        df_cut = df[(df.odds > selected_cut["flagcut"])]
-        if return_df==True:
-            return df_cut, selected_cut['flagcut'], dfcuts
+        if return_df == True:
+            return df_cut, selected_cut["flagcut"], dfcuts
         else:
-            return selected_cut[
+            return selected_cut["flagcut"], dfcuts
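Since the MMD term is what couples the two domains during training, a small self-check builds intuition: for two samples drawn from the same distribution the statistic should hover near zero, and it should grow once one sample is shifted. This sketch calls `maximum_mean_discrepancy` exactly as `TempsModule.train` does; the shapes and seed are arbitrary.

import torch
from temps.utils import maximum_mean_discrepancy

torch.manual_seed(0)
x = torch.randn(256, 10)               # "source" features
y_same = torch.randn(256, 10)          # same distribution as x
y_shift = torch.randn(256, 10) + 1.0   # mean-shifted distribution

print(maximum_mean_discrepancy(x, y_same, kernel_type="rbf"))   # close to 0
print(maximum_mean_discrepancy(x, y_shift, kernel_type="rbf"))  # clearly positive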