lauracabayol committed
Commit 57fa8fc · 1 Parent(s): 692f707

clear code and notebooks
notebooks/Feature_space.py DELETED
@@ -1,494 +0,0 @@
- # ---
- # jupyter:
- #   jupytext:
- #     text_representation:
- #       extension: .py
- #       format_name: light
- #       format_version: '1.5'
- #     jupytext_version: 1.14.5
- #   kernelspec:
- #     display_name: insight
- #     language: python
- #     name: insight
- # ---
-
- # # DOMAIN ADAPTATION INTUITION
-
- # %load_ext autoreload
- # %autoreload 2
-
- import pandas as pd
- import numpy as np
- import os
- from astropy.io import fits
- from astropy.table import Table
- import torch
-
- # matplotlib settings
- from matplotlib import rcParams
- import matplotlib.pyplot as plt
- rcParams["mathtext.fontset"] = "stix"
- rcParams["font.family"] = "STIXGeneral"
-
- # +
- # insight modules
- import sys
- sys.path.append('../temps')
-
- from archive import archive
- from utils import nmad
- from temps_arch import EncoderPhotometry, MeasureZ
- from temps import Temps_module
- from plots import plot_nz
- # -
-
- # ## LOAD DATA
-
- # define here the directory containing the photometric catalogues
- parent_dir = '/data/astro/scratch2/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5'
- modules_dir = '../data/models/'
-
- # +
- filename_valid = 'euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
-
- hdu_list = fits.open(os.path.join(parent_dir, filename_valid))
- cat = Table(hdu_list[1].data).to_pandas()
- cat = cat[cat['FLAG_PHOT'] == 0]
- cat = cat[cat['mu_class_L07'] == 1]
-
- cat['SNR_VIS'] = cat.FLUX_VIS / cat.FLUXERR_VIS
- #cat = cat[cat.SNR_VIS>10]
- # -
-
- ztarget = [cat['z_spec_S15'].values[ii] if cat['z_spec_S15'].values[ii] > 0 else cat['photo_z_L15'].values[ii] for ii in range(len(cat))]
- specz_or_photo = [0 if cat['z_spec_S15'].values[ii] > 0 else 1 for ii in range(len(cat))]
- ID = cat['ID']
- VISmag = cat['MAG_VIS']
- zsflag = cat['reliable_S15']
- cat['ztarget'] = ztarget
- cat['specz_or_photo'] = specz_or_photo
-
- # ### EXTRACT PHOTOMETRY
-
- photoz_archive = archive(path=parent_dir, only_zspec=False)
- f, ferr = photoz_archive._extract_fluxes(catalogue=cat)
- col, colerr = photoz_archive._to_colors(f, ferr)
-
- # ### MEASURE FEATURES
-
- features_all = np.zeros((3, len(cat), 10))
- for il, lab in enumerate(['z', 'L15', 'DA']):
-
-     nn_features = EncoderPhotometry()
-     nn_features.load_state_dict(torch.load(os.path.join(modules_dir, f'modelF_{lab}.pt')))
-
-     features = nn_features(torch.Tensor(col))
-     features = features.detach().cpu().numpy()
-
-     features_all[il] = features
-
-
- # ### TRAIN AUTOENCODER TO REDUCE TO 2 DIMENSIONS
-
- import torch
- from torch import nn
- class Autoencoder(nn.Module):
-     def __init__(self, input_dim, latent_dim):
-         super(Autoencoder, self).__init__()
-         # Encoder layers
-         self.encoder = nn.Sequential(
-             nn.Linear(input_dim, 100),
-             nn.ReLU(),
-             nn.Linear(100, 50),
-             nn.ReLU(),
-             nn.Linear(50, latent_dim)
-         )
-         # Decoder layers
-         self.decoder = nn.Sequential(
-             nn.Linear(latent_dim, 50),
-             nn.ReLU(),
-             nn.Linear(50, 100),
-             nn.ReLU(),
-             nn.Linear(100, input_dim),
-         )
-
-     def forward(self, x):
-         x = self.encoder(x)
-         y = self.decoder(x)
-         return y, x
-
-
- # +
- from torch.utils.data import DataLoader, TensorDataset
-
- ds = TensorDataset(torch.Tensor(features_all[0]))
- train_loader = DataLoader(ds, batch_size=100, shuffle=True, drop_last=False)
- # -
-
- import torch.optim as optim
- autoencoder = Autoencoder(input_dim=10,
-                           latent_dim=2)
- criterion = nn.L1Loss()
- optimizer = optim.Adam(autoencoder.parameters(), lr=0.0001)
-
- # +
- # Define the number of epochs
- num_epochs = 100
- for epoch in range(num_epochs):
-     running_loss = 0.0
-     for data in train_loader:
-         # Forward pass
-         outputs, f1 = autoencoder(data[0])
-
-         loss_autoencoder = criterion(outputs, data[0])
-         optimizer.zero_grad()
-
-         # Backward pass
-         loss_autoencoder.backward()
-
-         # Update the weights
-         optimizer.step()
-
-         # Accumulate the loss
-         running_loss += loss_autoencoder.item()
-
-     # Print the average loss for the epoch
-     print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, running_loss / len(train_loader)))
-
- print('Training finished')
-
- # -
-
- # #### EVALUATE AUTOENCODER
-
- # + [markdown] jupyter={"source_hidden": true}
- # cat.to_csv('features_cat.csv', header=True, sep=',')
- # -
-
- indexes_specz = cat[(cat.specz_or_photo == 0) & (cat.reliable_S15 > 0)].reset_index().index
-
- features_all_reduced = np.zeros(shape=(3, len(cat), 2))
- for i in range(3):
-     _, features = autoencoder(torch.Tensor(features_all[i]))
-     features_all_reduced[i] = features.detach().cpu().numpy()
-
- # ### Plot the features
-
- start = 0
- end = len(cat)
- all_values = set(range(start, end))
- values_not_in_indexes_specz = all_values - set(indexes_specz)
- indexes_nospecz = sorted(values_not_in_indexes_specz)
-
- # +
- import seaborn as sns
-
- # Create subplots with three panels
- fig, axs = plt.subplots(1, 3, figsize=(15, 5))
-
- # Set style for all subplots
- sns.set_style("white")
-
- # First subplot
- sns.kdeplot(x=features_all_reduced[0, indexes_nospecz, 0],
-             y=features_all_reduced[0, indexes_nospecz, 1],
-             clip=(-150, 150),
-             ax=axs[0],
-             color='salmon')
- sns.kdeplot(x=features_all_reduced[0, indexes_specz, 0],
-             y=features_all_reduced[0, indexes_specz, 1],
-             clip=(-150, 150),
-             ax=axs[0],
-             color='lightskyblue')
-
- axs[0].set_xlim(-150, 150)
- axs[0].set_ylim(-150, 150)
- axs[0].set_title(r'Trained on $z_{\rm s}$')
-
- # Second subplot
- sns.kdeplot(x=features_all_reduced[1, indexes_nospecz, 0],
-             y=features_all_reduced[1, indexes_nospecz, 1],
-             clip=(-50, 50),
-             ax=axs[1],
-             color='salmon')
- sns.kdeplot(x=features_all_reduced[1, indexes_specz, 0],
-             y=features_all_reduced[1, indexes_specz, 1],
-             clip=(-50, 50),
-             ax=axs[1],
-             color='lightskyblue')
- axs[1].set_xlim(-50, 50)
- axs[1].set_ylim(-50, 50)
- axs[1].set_title('Trained on L15')
-
- # Third subplot; deduplicate the points of both samples before the KDE
- features_all_reduced_nospecz = pd.DataFrame(features_all_reduced[2, indexes_nospecz, :]).drop_duplicates().values
- features_all_reduced_specz = pd.DataFrame(features_all_reduced[2, indexes_specz, :]).drop_duplicates().values
- sns.kdeplot(x=features_all_reduced_nospecz[:, 0],
-             y=features_all_reduced_nospecz[:, 1],
-             clip=(-1, 5),
-             ax=axs[2],
-             color='salmon',
-             label='Wide-field sample')
- sns.kdeplot(x=features_all_reduced_specz[:, 0],
-             y=features_all_reduced_specz[:, 1],
-             clip=(-1, 5),
-             ax=axs[2],
-             color='lightskyblue',
-             label=r'$z_{\rm s}$ sample')
- axs[2].set_xlim(-2, 5)
- axs[2].set_ylim(-2, 5)
- axs[2].set_title('TEMPS')
-
- axs[0].set_xlabel('Feature 1')
- axs[1].set_xlabel('Feature 1')
- axs[2].set_xlabel('Feature 1')
- axs[0].set_ylabel('Feature 2')
-
- # Create custom legend with desired colors
- legend_labels = ['Wide-field sample', r'$z_{\rm s}$ sample']
- legend_handles = [plt.Line2D([0], [0], color='salmon', lw=2),
-                   plt.Line2D([0], [0], color='lightskyblue', lw=2)]
- axs[2].legend(legend_handles, legend_labels, loc='upper right', fontsize=16)
- # Adjust layout
- plt.tight_layout()
-
- plt.savefig('Contourplot.pdf', bbox_inches='tight')
- plt.show()
- # -
-
-
- np.savetxt('features.txt', features_all_reduced.reshape(-1, 2))
-
-
- # +
- from scipy.stats import gaussian_kde
-
- photoz_archive = archive(path=parent_dir, only_zspec=False)
-
- fig, ax = plt.subplots(ncols=3, figsize=(15, 4), sharex=True, sharey=True)
- colors = ['navy', 'goldenrod']
- titles = [r'Training: $z_s$', r'Training: L15', r'Training: $z_s$ + DA']
- x_min, x_max = -5, 5
- y_min, y_max = -5, 5
- x_grid, y_grid = np.meshgrid(np.linspace(x_min, x_max, 10), np.linspace(y_min, y_max, 10))
- for il, lab in enumerate(['z', 'L15', 'DA']):
-
-     nn_features = EncoderPhotometry()
-     nn_features.load_state_dict(torch.load(os.path.join(modules_dir, f'modelF_{lab}.pt')))
-
-     for it, target_type in enumerate(['L15', 'zs']):
-         if target_type == 'zs':
-             cat_sub = photoz_archive._select_only_zspec(cat)
-             cat_sub = photoz_archive._clean_zspec_sample(cat_sub)
-         elif target_type == 'L15':
-             cat_sub = photoz_archive._exclude_only_zspec(cat)
-         else:
-             assert False
-
-         cat_sub = photoz_archive._clean_photometry(cat_sub)
-         print(cat_sub.shape)
-
-         f, ferr = photoz_archive._extract_fluxes(cat_sub)
-         col, colerr = photoz_archive._to_colors(f, ferr)
-
-         features = nn_features(torch.Tensor(col))
-         features = features.detach().cpu().numpy()
-
-         #xy = np.vstack([features[:1000,0], features[:1000,1]])
-         #zd = gaussian_kde(xy)(xy)
-         #ax[il].scatter(features[:1000,0], features[:1000,1],c=zd, s=3)
-
-         xy = np.vstack([features[:, 0], features[:, 1]])
-         density_estimation = gaussian_kde(xy)
-
-         # Define grid for plotting density lines
-         xy_grid = np.vstack([x_grid.ravel(), y_grid.ravel()])
-         density_grid = density_estimation(xy_grid).reshape(x_grid.shape)
-
-         # Plot contour lines representing density
-         ax[il].contour(x_grid, y_grid, density_grid, colors=colors[it])
-
-     ax[il].set_title(titles[il])
-     ax[il].set_xlim(-5, 5)
-     ax[il].set_ylim(-5, 5)
-
- ax[0].set_ylabel('Feature 1', fontsize=14)
- #plt.ylabel('Feature 2', fontsize=14)
-
- #assert False
- # -
-
- H
-
- H
-
- xedges
-
- yedges
-
- # +
- import matplotlib.colors as colors
- from matplotlib import path
- import numpy as np
- from matplotlib import pyplot as plt
- try:
-     from astropy.convolution import Gaussian2DKernel, convolve
-     astro_smooth = True
- except ImportError:
-     astro_smooth = False
-
- np.random.seed(123)
- #t = np.linspace(-5,1.2,1000)
- x = features[:1000, 0]
- y = features[:1000, 1]
-
- H, xedges, yedges = np.histogram2d(x, y, bins=(10, 10))
- xmesh, ymesh = np.meshgrid(xedges[:-1], yedges[:-1])
-
- # Smooth the contours (if astropy is installed)
- if astro_smooth:
-     kernel = Gaussian2DKernel(x_stddev=1.)
-     H = convolve(H, kernel)
-
- fig, ax = plt.subplots(1, figsize=(7, 6))
- clevels = ax.contour(xmesh, ymesh, H.T, linewidths=.9, cmap='winter')  # ,zorder=90)
- ax.scatter(x, y, s=3)
- #ax.set_xlim(-20,5)
- #ax.set_ylim(-20,5)
-
- # Identify points within contours
- #p = clevels.collections[0].get_paths()
- #inside = np.full_like(x,False,dtype=bool)
- #for level in p:
- #    inside |= level.contains_points(zip(*(x,y)))
-
- #ax.plot(x[~inside],y[~inside],'kx')
- #plt.show(block=False)
- # -
-
- density_grid
-
- features.shape, zd.shape
-
- # + jupyter={"outputs_hidden": true}
- xy = np.vstack([features[:, 0], features[:, 1]])
- zd = gaussian_kde(xy)(xy)
- plt.scatter(features[:, 0], features[:, 1], c=zd)
-
- # +
- # Make the base corner plot
- import corner
- figure = corner.corner(features[:, :2], quantiles=[0.16, 0.84], show_titles=False, color='crimson')
- corner.corner(samples2, fig=figure)  # NOTE: samples2 (second sample to overlay) is not defined in this notebook
- ndim = 2
- # Extract the axes
- axes = np.array(figure.axes).reshape((ndim, ndim))
-
- for a in axes[np.triu_indices(ndim)]:
-     a.remove()
-
- # +
- import numpy as np
- import matplotlib.pyplot as plt
- from scipy.stats import gaussian_kde
-
- # Assuming 'features' is your data array with shape (n_samples, 2)
-
- # Calculate the density estimate
- xy = np.vstack([features[:, 0], features[:, 1]])
- density_estimation = gaussian_kde(xy)
-
- # Define grid for plotting density lines
- xy_grid = np.vstack([x_grid.ravel(), y_grid.ravel()])
- density_grid = density_estimation(xy_grid).reshape(x_grid.shape)
-
- # Plot contour lines representing density
- plt.contour(x_grid, y_grid, density_grid, colors='black')
-
- # Optionally, add a scatter plot on top of the density lines
- #plt.scatter(features[:,0], features[:,1], color='blue', alpha=0.5)
-
- # Set labels and title
- plt.xlabel('Feature 1')
- plt.ylabel('Feature 2')
- plt.title('Density Lines Plot')
-
- # Show plot
- plt.show()
- # -
-
- # NOTE: Arinyo_preds, Arinyo_coeffs_central and test_snap are not defined in this notebook
- corner_plot = corner.corner(Arinyo_preds,
-                             labels=[r'$b$', r'$\beta$', '$q_1$', '$k_{vav}$', '$a_v$', '$b_v$', '$k_p$', '$q_2$'],
-                             truths=Arinyo_coeffs_central[test_snap],
-                             truth_color='crimson')
-
- import corner
- figure = corner.corner(features[:, :ndim], quantiles=[0.16, 0.5, 0.84], show_titles=False)
- axes = np.array(figure.axes).reshape((ndim, ndim))
- for a in axes[np.triu_indices(ndim)]:
-     a.remove()
-
- # +
- # My data
- from scipy import stats
- x = features[:, 0]
- y = features[:, 1]
-
- # Perform the kernel density estimate
- k = stats.gaussian_kde(np.vstack([x, y]))
- xi, yi = np.mgrid[-5:5, -5:5]
- zi = k(np.vstack([xi.flatten(), yi.flatten()]))
-
- fig = plt.figure()
- ax = fig.gca()
-
- CS = ax.contour(xi, yi, zi.reshape(xi.shape), colors='crimson')
-
- ax.set_xlim(-5, 5)
- ax.set_ylim(-5, 5)
-
- plt.show()
- # -
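The deleted notebook above builds intuition for the domain adaptation (DA) step, and `temps/temps.py` (diffed at the bottom of this page) imports `maximum_mean_discrepancy` and `compute_kernel` from `utils`, which this commit does not show. For background, here is a minimal sketch of an RBF-kernel maximum mean discrepancy between a source and a target feature batch; the function names mirror the imports, but the bandwidth handling and signatures are assumptions, not the repository's code.

```python
import torch

def compute_kernel(x, y, sigma=1.0):
    # Pairwise RBF kernel between two batches of feature vectors.
    # Sketch only: the bandwidth handling in temps/utils.py may differ.
    d2 = torch.cdist(x, y).pow(2)
    return torch.exp(-d2 / (2 * sigma ** 2))

def maximum_mean_discrepancy(source, target, sigma=1.0):
    # MMD^2 = E[k(s,s')] + E[k(t,t')] - 2 E[k(s,t)]
    k_ss = compute_kernel(source, source, sigma).mean()
    k_tt = compute_kernel(target, target, sigma).mean()
    k_st = compute_kernel(source, target, sigma).mean()
    return k_ss + k_tt - 2 * k_st
```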
notebooks/Fig2_NMAD.py DELETED
@@ -1,170 +0,0 @@
- # ---
- # jupyter:
- #   jupytext:
- #     formats: ipynb,py:percent
- #     text_representation:
- #       extension: .py
- #       format_name: percent
- #       format_version: '1.3'
- #     jupytext_version: 1.14.5
- #   kernelspec:
- #     display_name: insight
- #     language: python
- #     name: insight
- # ---
-
- # %% [markdown]
- # # FIGURE 2 IN THE PAPER
-
- # %% [markdown]
- # ## METRICS FOR THE DIFFERENT METHODS ON THE WIDE-FIELD SAMPLE
-
- # %% [markdown]
- # ### LOAD PYTHON MODULES
-
- # %%
- # %load_ext autoreload
- # %autoreload 2
-
- # %%
- import pandas as pd
- import numpy as np
- import os
- from astropy.io import fits
- from astropy.table import Table
- import torch
-
- # %%
- # matplotlib settings
- from matplotlib import rcParams
- import matplotlib.pyplot as plt
- rcParams["mathtext.fontset"] = "stix"
- rcParams["font.family"] = "STIXGeneral"
-
- # %%
- # insight modules
- import sys
- sys.path.append('../temps')
-
- from archive import archive
- from utils import nmad
- from temps_arch import EncoderPhotometry, MeasureZ
- from temps import Temps_module
- from plots import plot_photoz
-
- # %%
- eval_methods = True
-
- # %% [markdown]
- # ### LOAD DATA
-
- # %%
- # define here the directory containing the photometric catalogues
- parent_dir = '/data/astro/scratch2/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5'
- modules_dir = '../data/models/'
-
- # %%
- # load catalogue and apply cuts
-
- filename_valid = 'euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
-
- hdu_list = fits.open(os.path.join(parent_dir, filename_valid))
- cat = Table(hdu_list[1].data).to_pandas()
- cat = cat[cat['FLAG_PHOT'] == 0]
- cat = cat[cat['mu_class_L07'] == 1]
- cat = cat[(cat['z_spec_S15'] > 0) | (cat['photo_z_L15'] > 0)]
- cat = cat[cat['MAG_VIS'] < 25]
-
- # %%
- ztarget = [cat['z_spec_S15'].values[ii] if cat['z_spec_S15'].values[ii] > 0 else cat['photo_z_L15'].values[ii] for ii in range(len(cat))]
- specz_or_photo = [0 if cat['z_spec_S15'].values[ii] > 0 else 1 for ii in range(len(cat))]
- ID = cat['ID']
- VISmag = cat['MAG_VIS']
- zsflag = cat['reliable_S15']
-
- # %%
- photoz_archive = archive(path=parent_dir, only_zspec=False)
- f, ferr = photoz_archive._extract_fluxes(catalogue=cat)
- col, colerr = photoz_archive._to_colors(f, ferr)
-
- # %% [markdown]
- # ### EVALUATE USING TRAINED MODELS
-
- # %%
- if eval_methods:
-     dfs = {}
-     for il, lab in enumerate(['z', 'L15', 'DA']):
-
-         nn_features = EncoderPhotometry()
-         nn_features.load_state_dict(torch.load(os.path.join(modules_dir, f'modelF_{lab}.pt')))
-         nn_z = MeasureZ(num_gauss=6)
-         nn_z.load_state_dict(torch.load(os.path.join(modules_dir, f'modelZ_{lab}.pt')))
-
-         temps = Temps_module(nn_features, nn_z)
-
-         z, zerr, zmode, pz, flag, odds = temps.get_pz(input_data=torch.Tensor(col),
-                                                       return_pz=True)
-         # Create a DataFrame with the desired columns
-         df = pd.DataFrame(np.c_[ID, VISmag, z, zmode, flag, ztarget, zsflag, zerr, specz_or_photo],
-                           columns=['ID', 'VISmag', 'z', 'zmode', 'zflag', 'ztarget', 'zsflag', 'zuncert', 'S15_L15_flag'])
-
-         # Scaled redshift error, then drop any rows with NaN values
-         df['zwerr'] = (df.zmode - df.ztarget) / (1 + df.ztarget)
-         df = df.dropna()
-
-         # Assign the DataFrame to a key in the dictionary
-         dfs[lab] = df
-
- # %%
- dfs['z']['zwerr'] = (dfs['z'].z - dfs['z'].ztarget) / (1 + dfs['z'].ztarget)
- dfs['L15']['zwerr'] = (dfs['L15'].z - dfs['L15'].ztarget) / (1 + dfs['L15'].ztarget)
- dfs['DA']['zwerr'] = (dfs['DA'].z - dfs['DA'].ztarget) / (1 + dfs['DA'].ztarget)
-
- # %% [markdown]
- # ### LOAD CATALOGUES FROM PREVIOUS TRAINING
-
- # %%
- if not eval_methods:
-     dfs = {}
-     dfs['z'] = pd.read_csv(os.path.join(parent_dir, 'predictions_specztraining.csv'), header=0)
-     dfs['L15'] = pd.read_csv(os.path.join(parent_dir, 'predictions_speczL15training.csv'), header=0)
-     dfs['DA'] = pd.read_csv(os.path.join(parent_dir, 'predictions_speczDAtraining.csv'), header=0)
-
- # %% [markdown]
- # ### MAKE PLOT
-
- # %%
- df_list = [dfs['z'], dfs['L15'], dfs['DA']]  # plot_photoz expects a list of DataFrames
- plot_photoz(df_list,
-             nbins=8,
-             xvariable='VISmag',
-             metric='nmad',
-             type_bin='bin',
-             label_list=['zs', 'zs+L15', r'TEMPS'],
-             save=False,
-             samp='L15'
-             )
-
- # %%
- plot_photoz(df_list,
-             nbins=8,
-             xvariable='VISmag',
-             metric='outliers',
-             type_bin='bin',
-             label_list=['zs', 'zs+L15', r'TEMPS'],
-             save=False,
-             samp='L15'
-             )
-
- # %%
-
- # %%
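`nmad` is imported from `utils` throughout these notebooks, but its definition is not part of this diff. A plausible sketch, assuming the standard photo-z convention (1.4826 times the median absolute deviation, applied here to the scaled residuals `zwerr`):

```python
import numpy as np

def nmad(data):
    # Normalized median absolute deviation; the 1.4826 factor makes it
    # match the standard deviation for Gaussian residuals.
    # Sketch only: the signature of utils.nmad in this repo may differ.
    data = np.asarray(data)
    return 1.4826 * np.median(np.abs(data - np.median(data)))
```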
notebooks/Fig3_PIT_CRPS.py DELETED
@@ -1,120 +0,0 @@
- # ---
- # jupyter:
- #   jupytext:
- #     formats: ipynb,py:percent
- #     text_representation:
- #       extension: .py
- #       format_name: percent
- #       format_version: '1.3'
- #     jupytext_version: 1.14.5
- #   kernelspec:
- #     display_name: insight
- #     language: python
- #     name: insight
- # ---
-
- # %% [markdown]
- # # FIGURE 3 IN THE PAPER
-
- # %% [markdown]
- # ## PIT AND CRPS FOR THE THREE METHODS
-
- # %% [markdown]
- # ### LOAD PYTHON MODULES
-
- # %%
- # %load_ext autoreload
- # %autoreload 2
-
- # %%
- import pandas as pd
- import numpy as np
- import os
- from astropy.io import fits
- from astropy.table import Table
- import torch
-
- # %%
- # matplotlib settings
- from matplotlib import rcParams
- import matplotlib.pyplot as plt
- rcParams["mathtext.fontset"] = "stix"
- rcParams["font.family"] = "STIXGeneral"
-
- # %%
- # insight modules
- import sys
- sys.path.append('../temps')
- from archive import archive
- from utils import nmad
- from plots import plot_PIT, plot_crps
- from temps_arch import EncoderPhotometry, MeasureZ
- from temps import Temps_module
-
- # %% [markdown]
- # ### LOAD DATA
-
- # %%
- # define here the directories containing the photometric catalogues and the trained models
- parent_dir = '/data/astro/scratch2/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5'
- modules_dir = '../data/models/'
-
- photoz_archive = archive(path=parent_dir,
-                          only_zspec=False,
-                          flags_kept=[1., 1.1, 1.4, 1.5, 2, 2.1, 2.4, 2.5, 3., 3.1, 3.4, 3.5, 4., 9., 9.1, 9.3, 9.4, 9.5, 11.1, 11.5, 12.1, 12.5, 13., 13.1, 13.5, 14],
-                          target_test='L15')
- f_test, ferr_test, specz_test, VIS_mag_test = photoz_archive.get_testing_data()
-
- # %% [markdown]
- # ## CREATE PIT; CRPS; SPECTROSCOPIC SAMPLE
-
- # %% [markdown]
- # This loads pre-trained models (for the sake of time). You can learn how to train the models in the Tutorial notebook.
-
- # %%
- # Initialize empty dictionaries to store the PIT and CRPS values
- crps_dict = {}
- pit_dict = {}
- for il, lab in enumerate(['z', 'L15', 'DA']):
-
-     nn_features = EncoderPhotometry()
-     nn_features.load_state_dict(torch.load(os.path.join(modules_dir, f'modelF_{lab}.pt')))
-     nn_z = MeasureZ(num_gauss=6)
-     nn_z.load_state_dict(torch.load(os.path.join(modules_dir, f'modelZ_{lab}.pt')))
-
-     temps = Temps_module(nn_features, nn_z)
-
-     pit_list = temps.pit(input_data=torch.Tensor(f_test), target_data=torch.Tensor(specz_test))
-     crps_list = temps.crps(input_data=torch.Tensor(f_test), target_data=specz_test)
-
-     # Assign the lists to a key in the dictionaries
-     crps_dict[lab] = crps_list
-     pit_dict[lab] = pit_list
-
- # %%
- plot_PIT(pit_dict['z'],
-          pit_dict['L15'],
-          pit_dict['DA'],
-          labels=[r'$z_{\rm s}$', 'L15', 'TEMPS'],
-          sample='L15',
-          save=True)
-
- # %%
- plot_crps(crps_dict['z'],
-           crps_dict['L15'],
-           crps_dict['DA'],
-           labels=[r'$z_{\rm s}$', 'L15', 'TEMPS'],
-           sample='L15',
-           save=True)
-
- # %%
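`Temps_module.pit` and `.crps` are called above but not shown in this commit. For intuition, a sketch of both metrics computed from a redshift PDF sampled on a grid (the grid, names, and signature are assumptions): the PIT is the CDF evaluated at the true redshift, and the CRPS integrates the squared difference between the CDF and a step function at the true redshift.

```python
import numpy as np

def pit_and_crps(pz, z_grid, z_true):
    """PIT and CRPS for one galaxy, given p(z) sampled on z_grid.
    Sketch only; TEMPS computes these internally from its mixture model."""
    pz = pz / np.trapz(pz, z_grid)                # normalize the PDF
    cdf = np.cumsum(pz * np.gradient(z_grid))     # approximate CDF
    pit = np.interp(z_true, z_grid, cdf)          # CDF at the true redshift
    step = (z_grid >= z_true).astype(float)       # Heaviside at z_true
    crps = np.trapz((cdf - step) ** 2, z_grid)
    return pit, crps

# Example on a Gaussian toy PDF:
z_grid = np.linspace(0, 5, 1000)
pz = np.exp(-0.5 * ((z_grid - 1.0) / 0.05) ** 2)
print(pit_and_crps(pz, z_grid, z_true=1.02))
```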
notebooks/Fig4_pz_examples.py DELETED
@@ -1,128 +0,0 @@
- # ---
- # jupyter:
- #   jupytext:
- #     formats: ipynb,py:percent
- #     text_representation:
- #       extension: .py
- #       format_name: percent
- #       format_version: '1.3'
- #     jupytext_version: 1.14.5
- #   kernelspec:
- #     display_name: insight
- #     language: python
- #     name: insight
- # ---
-
- # %% [markdown]
- # # FIGURE 4 IN THE PAPER
-
- # %% [markdown]
- # ## IMPACT OF TEMPS ON CONCRETE P(Z) EXAMPLES
-
- # %% [markdown]
- # ### LOAD PYTHON MODULES
-
- # %%
- # %load_ext autoreload
- # %autoreload 2
-
- # %%
- import pandas as pd
- import numpy as np
- import os
- from astropy.io import fits
- from astropy.table import Table
- import torch
-
- # %%
- # matplotlib settings
- from matplotlib import rcParams
- import matplotlib.pyplot as plt
- rcParams["mathtext.fontset"] = "stix"
- rcParams["font.family"] = "STIXGeneral"
-
- # %%
- # insight modules
- import sys
- sys.path.append('../temps')
- from archive import archive
- from utils import nmad
- from temps_arch import EncoderPhotometry, MeasureZ
- from temps import Temps_module
-
- # %% [markdown]
- # ### LOAD DATA
-
- # %%
- # define here the directory containing the photometric catalogues
- parent_dir = '/data/astro/scratch2/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5'
- modules_dir = '../data/models/'
-
- # %%
- filename_valid = 'euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
-
- hdu_list = fits.open(os.path.join(parent_dir, filename_valid))
- cat = Table(hdu_list[1].data).to_pandas()
- cat = cat[cat['FLAG_PHOT'] == 0]
- cat = cat[cat['mu_class_L07'] == 1]
- cat = cat[(cat['z_spec_S15'] > 0) | (cat['photo_z_L15'] > 0)]
- cat = cat[cat['MAG_VIS'] < 25]
-
- # %%
- ztarget = [cat['z_spec_S15'].values[ii] if cat['z_spec_S15'].values[ii] > 0 else cat['photo_z_L15'].values[ii] for ii in range(len(cat))]
- specz_or_photo = [0 if cat['z_spec_S15'].values[ii] > 0 else 1 for ii in range(len(cat))]
- ID = cat['ID']
- VISmag = cat['MAG_VIS']
- zsflag = cat['reliable_S15']
-
- # %%
- photoz_archive = archive(path=parent_dir, only_zspec=False)
- f, ferr = photoz_archive._extract_fluxes(catalogue=cat)
- col, colerr = photoz_archive._to_colors(f, ferr)
-
- # %% [markdown]
- # ### LOAD TRAINED MODELS AND EVALUATE PDF OF RANDOM EXAMPLES
-
- # %% [markdown]
- # The notebook 'Tutorial_temps' gives an example of how to train and save models.
-
- # %%
- # Initialize an empty dictionary to store the p(z) of one random galaxy per model
- ii = np.random.randint(0, len(col), 1)
- pz_dict = {}
- for il, lab in enumerate(['z', 'L15', 'DA']):
-
-     nn_features = EncoderPhotometry()
-     nn_features.load_state_dict(torch.load(os.path.join(modules_dir, f'modelF_{lab}.pt')))
-     nn_z = MeasureZ(num_gauss=6)
-     nn_z.load_state_dict(torch.load(os.path.join(modules_dir, f'modelZ_{lab}.pt')))
-
-     temps = Temps_module(nn_features, nn_z)
-
-     z, zerr, pz, flag, _ = temps.get_pz(input_data=torch.Tensor(col[ii]), return_pz=True)
-
-     # Assign the PDF to a key in the dictionary
-     pz_dict[lab] = pz
-
- # %%
- cmap = plt.get_cmap('Dark2')
-
- plt.plot(np.linspace(0, 5, 1000), pz_dict['z'][0], label='z', color=cmap(0), ls='--')
- plt.plot(np.linspace(0, 5, 1000), pz_dict['L15'][0], label='L15', color=cmap(1), ls=':')
- plt.plot(np.linspace(0, 5, 1000), pz_dict['DA'][0], label='TEMPS', color=cmap(2), ls='-')
- plt.axvline(x=np.array(ztarget)[ii][0], ls='-.', color='black')
- #plt.xlim(0,2)
- plt.legend()
-
- plt.xlabel(r'$z$', fontsize=14)
- plt.ylabel('Probability', fontsize=14)
- #plt.savefig(f'pz_{ii[0]}.pdf', bbox_inches='tight')
-
- # %%
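For readers reproducing this figure without the trained models: given a p(z) on the 0 to 5 grid used by the plotting cell above, the point estimates that other notebooks call `z` (mean) and `zmode` can be sketched as below. This is a plausible reading of the outputs of `get_pz`, not the repository's exact code.

```python
import numpy as np

z_grid = np.linspace(0, 5, 1000)  # grid used by the plotting cell above

def point_estimates(pz):
    # Posterior mean and mode of a gridded redshift PDF (sketch).
    pz = pz / np.trapz(pz, z_grid)
    z_mean = np.trapz(z_grid * pz, z_grid)
    z_mode = z_grid[np.argmax(pz)]
    return z_mean, z_mode
```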
notebooks/Fig6_qualitycut.py DELETED
@@ -1,164 +0,0 @@
- # ---
- # jupyter:
- #   jupytext:
- #     text_representation:
- #       extension: .py
- #       format_name: light
- #       format_version: '1.5'
- #     jupytext_version: 1.14.5
- #   kernelspec:
- #     display_name: insight
- #     language: python
- #     name: insight
- # ---
-
- # # FIGURE 6 IN THE PAPER
-
- # ## QUALITY CUTS
-
- # %load_ext autoreload
- # %autoreload 2
-
- import pandas as pd
- import numpy as np
- import os
- import torch
- from scipy import stats
-
- # matplotlib settings
- from matplotlib import rcParams
- import matplotlib.pyplot as plt
- rcParams["mathtext.fontset"] = "stix"
- rcParams["font.family"] = "STIXGeneral"
-
- # insight modules
- import sys
- sys.path.append('../temps')
- #from insight_arch import EncoderPhotometry, MeasureZ
- #from insight import Insight_module
- from archive import archive
- from utils import nmad
- from temps_arch import EncoderPhotometry, MeasureZ
- from temps import Temps_module
-
- # ### LOAD DATA (ONLY SPECZ)
-
- # define here the directory containing the photometric catalogues
- parent_dir = '/data/astro/scratch2/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5'
- modules_dir = '../data/models/'
-
- photoz_archive = archive(path=parent_dir, only_zspec=True, flags_kept=[1., 1.1, 1.4, 1.5, 2, 2.1, 2.4, 2.5, 3., 3.1, 3.4, 3.5, 4., 9., 9.1, 9.3, 9.4, 9.5, 11.1, 11.5, 12.1, 12.5, 13., 13.1, 13.5, 14])
- f_test_specz, ferr_test_specz, specz_test, VIS_mag_test = photoz_archive.get_testing_data()
-
- # ### LOAD TRAINED MODELS AND EVALUATE PDF OF RANDOM EXAMPLES
-
- # +
- # Initialize an empty dictionary to store DataFrames
- dfs = {}
-
- for il, lab in enumerate(['z', 'L15', 'DA']):
-
-     nn_features = EncoderPhotometry()
-     nn_features.load_state_dict(torch.load(os.path.join(modules_dir, f'modelF_{lab}.pt')))
-     nn_z = MeasureZ(num_gauss=6)
-     nn_z.load_state_dict(torch.load(os.path.join(modules_dir, f'modelZ_{lab}.pt')))
-
-     temps = Temps_module(nn_features, nn_z)
-
-     z, zerr, pz, flag, odds = temps.get_pz(input_data=torch.Tensor(f_test_specz),
-                                            return_pz=True)
-
-     # Create a DataFrame with the desired columns
-     df = pd.DataFrame(np.c_[z, flag, odds, specz_test],
-                       columns=['z', 'zflag', 'odds', 'ztarget'])
-
-     # Calculate additional columns or operations if needed
-     df['zwerr'] = (df.z - df.ztarget) / (1 + df.ztarget)
-
-     # Drop any rows with NaN values
-     df = df.dropna()
-
-     # Assign the DataFrame to a key in the dictionary
-     dfs[lab] = df
- # -
-
- # ### STATISTICS BASED ON OUR QUALITY CUT
-
- # +
- bin_edges = stats.mstats.mquantiles(df.zflag, np.arange(0, 1.01, 0.05))
- scatter, eta, xlab, xmag, xzs, flagmean = [], [], [], [], [], []
-
- for k in range(len(bin_edges) - 1):
-     edge_min = bin_edges[k]
-     edge_max = bin_edges[k + 1]
-
-     df_bin = df[(df.zflag > edge_min)]
-
-     xlab.append(np.round(len(df_bin) / len(df), 2) * 100)
-     xzs.append(0.5 * (df_bin.ztarget.min() + df_bin.ztarget.max()))
-     flagmean.append(np.mean(df_bin.zflag))
-     scatter.append(nmad(df_bin.zwerr))
-     eta.append(len(df_bin[np.abs(df_bin.zwerr) > 0.15]) / len(df) * 100)
- # -
-
- # ### STATISTICS BASED ON ODDS
-
- # +
- bin_edges = stats.mstats.mquantiles(df.odds, np.arange(0, 1.01, 0.05))
- scatter_odds, eta_odds, xlab_odds, oddsmean = [], [], [], []
-
- for k in range(len(bin_edges) - 1):
-     edge_min = bin_edges[k]
-     edge_max = bin_edges[k + 1]
-
-     df_bin = df[(df.odds > edge_min)]
-
-     xlab_odds.append(np.round(len(df_bin) / len(df), 2) * 100)
-     oddsmean.append(np.mean(df_bin.zflag))
-     scatter_odds.append(nmad(df_bin.zwerr))
-     eta_odds.append(len(df_bin[np.abs(df_bin.zwerr) > 0.15]) / len(df) * 100)
- # -
-
- # ### PLOTS
-
- # +
- plt.plot(xlab_odds, scatter_odds, marker='.', color='crimson', label=r'$\theta(\Delta z)$', ls='--', alpha=0.5)
- plt.plot(xlab, scatter, marker='.', color='navy', label=r'$\xi = \theta(\Delta z)$')
-
- plt.ylabel(r'NMAD [$\Delta z\ /\ (1 + z_{\rm s})$]', fontsize=16)
- plt.xlabel('Completeness', fontsize=16)
-
- plt.yticks(fontsize=12)
- plt.xticks(np.arange(5, 101, 10), fontsize=12)
- plt.legend(fontsize=14)
-
- plt.savefig('Flag_nmad_zspec.pdf', bbox_inches='tight')
- plt.show()
- # -
-
- # +
- plt.plot(xlab_odds, eta_odds, marker='.', color='crimson', label=r'$\theta(\Delta z)$', ls='--', alpha=0.5)
- plt.plot(xlab, eta, marker='.', color='navy', label=r'$\xi = \theta(\Delta z)$')
-
- plt.yticks(fontsize=12)
- plt.xticks(np.arange(5, 101, 10), fontsize=12)
- plt.ylabel(r'$\eta$ [%]', fontsize=16)
- plt.xlabel('Completeness', fontsize=16)
- plt.legend()
-
- plt.savefig('Flag_eta_zspec.pdf', bbox_inches='tight')
-
- plt.show()
- # -
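The quality-cut notebook ranks galaxies both by `zflag` and by `odds`, two quantities returned by `get_pz` but not defined anywhere in this diff. One common definition of photo-z odds, shown here as an assumption rather than the repository's formula, is the fraction of p(z) within a window of ±0.05(1 + z_mode) around the mode:

```python
import numpy as np

def odds(pz, z_grid, window=0.05):
    # Fraction of the PDF within +/- window*(1+z_mode) of the mode (sketch).
    pz = pz / np.trapz(pz, z_grid)
    z_mode = z_grid[np.argmax(pz)]
    half_width = window * (1 + z_mode)
    mask = np.abs(z_grid - z_mode) < half_width
    return np.trapz(pz[mask], z_grid[mask])
```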
notebooks/Fig7_colourspace.py DELETED
@@ -1,261 +0,0 @@
- # ---
- # jupyter:
- #   jupytext:
- #     text_representation:
- #       extension: .py
- #       format_name: light
- #       format_version: '1.5'
- #     jupytext_version: 1.14.5
- #   kernelspec:
- #     display_name: insight
- #     language: python
- #     name: insight
- # ---
-
- # # FIGURE COLOURSPACE IN THE PAPER
-
- # %load_ext autoreload
- # %autoreload 2
-
- import pandas as pd
- import numpy as np
- import os
- from astropy.io import fits
- from astropy.table import Table
- import torch
-
- # matplotlib settings
- from matplotlib import rcParams
- import matplotlib.pyplot as plt
- rcParams["mathtext.fontset"] = "stix"
- rcParams["font.family"] = "STIXGeneral"
-
- # +
- # insight modules
- import sys
- sys.path.append('../temps')
-
- from archive import archive
- from utils import nmad
- from temps_arch import EncoderPhotometry, MeasureZ
- from temps import Temps_module
- from plots import plot_nz
- # -
-
- def estimate_som_map(df, plot_arg='z', nx=40, ny=40):
-     """
-     Estimate a Self-Organizing Map (SOM) visualization from a DataFrame.
-
-     Parameters:
-     - df (pandas.DataFrame): Input DataFrame containing data for SOM estimation.
-     - plot_arg (str, optional): Column name to be used for plotting. Default is 'z'.
-     - nx (int, optional): Number of cells along the X-axis. Default is 40.
-     - ny (int, optional): Number of cells along the Y-axis. Default is 40.
-
-     Returns:
-     - som_data (numpy.ndarray): Estimated SOM visualization data.
-     """
-     x_cells = np.arange(0, nx)
-     y_cells = np.arange(0, ny)
-     index_cell = np.arange(nx * ny)
-     cells = np.array(np.meshgrid(x_cells, y_cells)).T.reshape(-1, 2)
-     cells = pd.DataFrame(np.c_[cells[:, 0], cells[:, 1], index_cell], columns=['x_cell', 'y_cell', 'cell'])
-
-     if plot_arg == 'count':
-         som_vis = df.groupby('cell')['z'].count().reset_index().rename(columns={'z': 'plot_som'})
-     else:
-         som_vis = df.groupby('cell')[f'{plot_arg}'].mean().reset_index().rename(columns={f'{plot_arg}': 'plot_som'})
-
-     som_data = som_vis.merge(cells, on='cell')
-     som_data = som_data.pivot(index='x_cell', columns='y_cell', values='plot_som')
-
-     return som_data
-
-
- def plot_som_map(som_data, plot_arg='z', vmin=0, vmax=1):
-     """
-     Plot the Self-Organizing Map (SOM) data.
-
-     Parameters:
-     - som_data (numpy.ndarray): The SOM data to be visualized.
-     - plot_arg (str, optional): The column name to be plotted. Default is 'z'.
-     - vmin (float, optional): Minimum value for color scaling. Default is 0.
-     - vmax (float, optional): Maximum value for color scaling. Default is 1.
-
-     Returns:
-     None
-     """
-     plt.imshow(som_data, vmin=vmin, vmax=vmax, cmap='viridis')
-     plt.colorbar(label=f'{plot_arg}')
-     plt.xlabel(r'$x$ [pixel]', fontsize=14)
-     plt.ylabel(r'$y$ [pixel]', fontsize=14)
-     plt.show()
-
-
- # ### LOAD DATA
-
- # define here the directories containing the photometric catalogues and the trained models
- parent_dir = '/data/astro/scratch2/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5'
- modules_dir = '../data/models/'
-
- # +
- filename_valid = 'euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
-
- hdu_list = fits.open(os.path.join(parent_dir, filename_valid))
- cat = Table(hdu_list[1].data).to_pandas()
- cat = cat[cat['FLAG_PHOT'] == 0]
- cat = cat[cat['mu_class_L07'] == 1]
- cat = cat[(cat['z_spec_S15'] > 0) | (cat['photo_z_L15'] > 0)]
- cat = cat[cat['MAG_VIS'] < 25]
- # -
-
- ztarget = [cat['z_spec_S15'].values[ii] if cat['z_spec_S15'].values[ii] > 0 else cat['photo_z_L15'].values[ii] for ii in range(len(cat))]
- specz_or_photo = [0 if cat['z_spec_S15'].values[ii] > 0 else 1 for ii in range(len(cat))]
- ID = cat['ID']
- VISmag = cat['MAG_VIS']
- zsflag = cat['reliable_S15']
-
- photoz_archive = archive(path=parent_dir, only_zspec=False)
- f, ferr = photoz_archive._extract_fluxes(catalogue=cat)
- col, colerr = photoz_archive._to_colors(f, ferr)
-
- # ### LOAD TRAINED MODELS AND EVALUATE PDFs AND REDSHIFT
-
- # +
- dfs = {}
-
- for il, lab in enumerate(['z', 'L15', 'DA']):
-
-     nn_features = EncoderPhotometry()
-     nn_features.load_state_dict(torch.load(os.path.join(modules_dir, f'modelF_{lab}.pt')))
-     nn_z = MeasureZ(num_gauss=6)
-     nn_z.load_state_dict(torch.load(os.path.join(modules_dir, f'modelZ_{lab}.pt')))
-
-     temps = Temps_module(nn_features, nn_z)
-
-     z, zerr, pz, flag, odds = temps.get_pz(input_data=torch.Tensor(col),
-                                            return_pz=True)
-     # Create a DataFrame with the desired columns
-     df = pd.DataFrame(np.c_[ID, VISmag, z, flag, ztarget, zsflag, zerr, specz_or_photo],
-                       columns=['ID', 'VISmag', 'z', 'zflag', 'ztarget', 'zsflag', 'zuncert', 'S15_L15_flag'])
-
-     # Scaled redshift error, then drop any rows with NaN values
-     df['zwerr'] = (df.z - df.ztarget) / (1 + df.ztarget)
-     df = df.dropna()
-
-     # Assign the DataFrame to a key in the dictionary
-     dfs[lab] = df
- # -
-
- df_z = dfs['z']
- df_z_DA = dfs['DA']
-
- # ##### LOAD SOM TRAINED ON THE TRAINING DATA
-
- df_som = pd.read_csv(os.path.join(parent_dir, 'som_dataframe.csv'), header=0, sep=',')
- df_z = df_z.merge(df_som, on='ID')
- df_z_DA = df_z_DA.merge(df_som, on='ID')
-
- # ##### APPLY CUTS FOR DIFFERENT SAMPLES
-
- df_zspec = df_z[(df_z.S15_L15_flag == 0) & (df_z.zsflag == 1)]
- df_l15 = df_z[(df_z.ztarget > 0)]
- df_l15_DA = df_z_DA[(df_z_DA.ztarget > 0)]
-
- df_l15_euclid = df_z[(df_z.VISmag < 24.5) & (df_z.z > 0.2) & (df_z.z < 2.6)]
- df_l15_euclid_cut = df_l15_euclid[df_l15_euclid.zflag > 0.033]
-
- df_l15_euclid_da = df_z_DA[(df_z_DA.VISmag < 24.5) & (df_z_DA.z > 0.2) & (df_z_DA.z < 2.6)]
- df_l15_euclid_cut_da = df_l15_euclid_da[df_l15_euclid_da.zflag > 0.018]
-
- # ## MAKE SOM PLOT
-
- from mpl_toolkits.axes_grid1 import make_axes_locatable
-
- # +
- fig, axs = plt.subplots(6, 4, figsize=(13, 15), sharex=True, sharey=True, gridspec_kw={'hspace': 0.05, 'wspace': 0.06})
-
- # Top row: spectroscopic sample
- columns = ['ztarget', 'z', 'zwerr', 'count']
- titles = [r'$z_{true}$', r'$z$', r'$z_{\rm error}$', 'Counts']
- limits = [[0, 4], [0, 4], [-0.5, 0.5], [0, 50]]
- for ii in range(4):
-     som_data = estimate_som_map(df_zspec, plot_arg=columns[ii], nx=40, ny=40)
-     im = axs[0, ii].imshow(som_data, vmin=limits[ii][0], vmax=limits[ii][1], cmap='viridis')
-     axs[0, ii].set_title(f'{titles[ii]}', fontsize=18)
-
-     if ii == 0:
-         axs[0, 0].set_ylabel(r'$y$', fontsize=14)
-     elif ii == 1:
-         cbar_ax = fig.add_axes([0.49, 0.11, 0.01, 0.77])
-         fig.colorbar(im, cax=cbar_ax)
-     elif ii == 2:
-         cbar_ax = fig.add_axes([0.685, 0.11, 0.01, 0.77])
-         fig.colorbar(im, cax=cbar_ax)
-     elif ii == 3:
-         cbar_ax = fig.add_axes([0.885, 0.11, 0.01, 0.77])
-         fig.colorbar(im, cax=cbar_ax)
-
- # Second row: L15 sample
- for jj in range(4):
-     som_data = estimate_som_map(df_l15, plot_arg=columns[jj], nx=40, ny=40)
-     im = axs[1, jj].imshow(som_data, vmin=limits[jj][0], vmax=limits[jj][1], cmap='viridis')
-
- # Third row: L15 sample + DA
- for kk in range(4):
-     som_data = estimate_som_map(df_l15_DA, plot_arg=columns[kk], nx=40, ny=40)
-     im = axs[2, kk].imshow(som_data, vmin=limits[kk][0], vmax=limits[kk][1], cmap='viridis')
-
- # Fourth row: Euclid sample + DA
- for rr in range(4):
-     som_data = estimate_som_map(df_l15_euclid_da, plot_arg=columns[rr], nx=40, ny=40)
-     im = axs[3, rr].imshow(som_data, vmin=limits[rr][0], vmax=limits[rr][1], cmap='viridis')
-
- # Fifth row: Euclid sample + quality cut
- for ll in range(4):
-     som_data = estimate_som_map(df_l15_euclid_cut, plot_arg=columns[ll], nx=40, ny=40)
-     im = axs[4, ll].imshow(som_data, vmin=limits[ll][0], vmax=limits[ll][1], cmap='viridis')
-     axs[4, ll].set_xlabel(r'$x$', fontsize=14)
-
- # Sixth row: Euclid sample + DA + quality cut
- for ll in range(4):
-     som_data = estimate_som_map(df_l15_euclid_cut_da, plot_arg=columns[ll], nx=40, ny=40)
-     im = axs[5, ll].imshow(som_data, vmin=limits[ll][0], vmax=limits[ll][1], cmap='viridis')
-     axs[5, ll].set_xlabel(r'$x$', fontsize=14)
-
- axs[0, 0].set_ylabel(r'$y$', fontsize=14)
- axs[1, 0].set_ylabel(r'$y$', fontsize=14)
- axs[2, 0].set_ylabel(r'$y$', fontsize=14)
- axs[3, 0].set_ylabel(r'$y$', fontsize=14)
- axs[4, 0].set_ylabel(r'$y$', fontsize=14)
- axs[5, 0].set_ylabel(r'$y$', fontsize=14)
-
- fig.text(0.09, 0.815, r'$z_{\rm s}$ sample', va='center', rotation='vertical', fontsize=16)
- fig.text(0.09, 0.69, r'L15 sample', va='center', rotation='vertical', fontsize=16)
- fig.text(0.09, 0.56, r'L15 sample + DA', va='center', rotation='vertical', fontsize=14)
- fig.text(0.09, 0.44, r'$Euclid$ sample + DA', va='center', rotation='vertical', fontsize=14)
- fig.text(0.09, 0.3, r'$Euclid$ sample + QC', va='center', rotation='vertical', fontsize=14)
- fig.text(0.09, 0.17, r'$Euclid$ sample + DA + QC', va='center', rotation='vertical', fontsize=13)
-
- plt.savefig('SOM_colourspace.pdf', format='pdf', bbox_inches='tight', dpi=300)
- # -
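The `cell` column consumed by `estimate_som_map` comes from a precomputed `som_dataframe.csv`; the SOM training itself is not part of this commit. A minimal sketch of how such cell assignments could be produced (plain NumPy, square grid, Gaussian neighborhood; all names and hyperparameters here are hypothetical):

```python
import numpy as np

def train_som(data, nx=40, ny=40, n_iter=10_000, lr0=0.5, sigma0=10.0, seed=0):
    """Train a small self-organizing map and return its weight grid (sketch)."""
    rng = np.random.default_rng(seed)
    n_features = data.shape[1]
    weights = rng.normal(size=(nx * ny, n_features))
    # Grid coordinates of each cell, used for the neighborhood function.
    coords = np.stack(np.unravel_index(np.arange(nx * ny), (nx, ny)), axis=1).astype(float)
    for t in range(n_iter):
        lr = lr0 * (1 - t / n_iter)                      # decaying learning rate
        sigma = sigma0 * (1 - t / n_iter) + 1.0          # shrinking neighborhood
        x = data[rng.integers(len(data))]
        bmu = np.argmin(((weights - x) ** 2).sum(axis=1))  # best-matching unit
        d2 = ((coords - coords[bmu]) ** 2).sum(axis=1)     # grid distance to BMU
        h = np.exp(-d2 / (2 * sigma ** 2))                 # Gaussian neighborhood
        weights += lr * h[:, None] * (x - weights)
    return weights

def assign_cells(data, weights):
    """Index of the best-matching cell for each object (the 'cell' column)."""
    d2 = ((data[:, None, :] - weights[None, :, :]) ** 2).sum(axis=-1)
    return d2.argmin(axis=1)
```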
notebooks/Table_metrics.py DELETED
@@ -1,148 +0,0 @@
- # ---
- # jupyter:
- #   jupytext:
- #     text_representation:
- #       extension: .py
- #       format_name: light
- #       format_version: '1.5'
- #     jupytext_version: 1.14.5
- #   kernelspec:
- #     display_name: insight
- #     language: python
- #     name: insight
- # ---
-
- # # TABLE METRICS
-
- # %load_ext autoreload
- # %autoreload 2
-
- import pandas as pd
- import numpy as np
- import os
- import torch
- from scipy import stats
- from astropy.io import fits
- from astropy.table import Table
-
- # matplotlib settings
- from matplotlib import rcParams
- import matplotlib.pyplot as plt
- rcParams["mathtext.fontset"] = "stix"
- rcParams["font.family"] = "STIXGeneral"
-
- # insight modules
- import sys
- sys.path.append('../temps')
- from archive import archive
- from utils import nmad, select_cut
- from temps_arch import EncoderPhotometry, MeasureZ
- from temps import Temps_module
-
- # ## LOAD DATA
-
- # define here the directory containing the photometric catalogues
- parent_dir = '/data/astro/scratch2/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5'
- modules_dir = '../data/models/'
-
- # +
- filename_valid = 'euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
-
- hdu_list = fits.open(os.path.join(parent_dir, filename_valid))
- cat = Table(hdu_list[1].data).to_pandas()
- cat = cat[cat['FLAG_PHOT'] == 0]
- cat = cat[cat['mu_class_L07'] == 1]
-
- cat['SNR_VIS'] = cat.FLUX_VIS / cat.FLUXERR_VIS
- # -
-
- cat = cat[cat.SNR_VIS > 10]
-
- ztarget = [cat['z_spec_S15'].values[ii] if cat['z_spec_S15'].values[ii] > 0 else cat['photo_z_L15'].values[ii] for ii in range(len(cat))]
- specz_or_photo = [0 if cat['z_spec_S15'].values[ii] > 0 else 1 for ii in range(len(cat))]
- ID = cat['ID']
- VISmag = cat['MAG_VIS']
- zsflag = cat['reliable_S15']
-
- cat['ztarget'] = ztarget
- cat['specz_or_photo'] = specz_or_photo
-
- cat = cat[cat.ztarget > 0]
-
- # ### EXTRACT PHOTOMETRY
-
- photoz_archive = archive(path=parent_dir, only_zspec=False)
- f, ferr = photoz_archive._extract_fluxes(catalogue=cat)
- col, colerr = photoz_archive._to_colors(f, ferr)
-
- # ### MEASURE CATALOGUE
-
- # +
- # Evaluate the domain-adapted (DA) model on the full catalogue
- lab = 'DA'
- nn_features = EncoderPhotometry()
- nn_features.load_state_dict(torch.load(os.path.join(modules_dir, f'modelF_{lab}.pt')))
- nn_z = MeasureZ(num_gauss=6)
- nn_z.load_state_dict(torch.load(os.path.join(modules_dir, f'modelZ_{lab}.pt')))
-
- temps = Temps_module(nn_features, nn_z)
-
- z, zerr, pz, flag, odds = temps.get_pz(input_data=torch.Tensor(col),
-                                        return_pz=True)
-
- # Create a DataFrame with the desired columns
- df = pd.DataFrame(np.c_[z, flag, odds, cat.ztarget, cat.reliable_S15, cat.specz_or_photo],
-                   columns=['z', 'zflag', 'odds', 'ztarget', 'reliable_S15', 'specz_or_photo'])
-
- # Scaled redshift error, then drop any rows with NaN values
- df['zwerr'] = (df.z - df.ztarget) / (1 + df.ztarget)
- df = df.dropna()
- # -
-
- # ### SPECZ SAMPLE
-
- df_specz = df[(df.reliable_S15 == 1) & (df.specz_or_photo == 0)]
-
- # +
- df_selected, cut, dfcuts = select_cut(df_specz,
-                                       completenss_lim=None,
-                                       nmad_lim=0.055,
-                                       outliers_lim=None,
-                                       return_df=True)
- # -
-
- print(dfcuts.to_latex(float_format="%.3f",
-                       columns=['Nobj', 'completeness', 'nmad', 'eta'],
-                       index=False
-                       ))
-
- # ### EUCLID SAMPLE
-
- df_euclid = df[(df.z > 0.2) & (df.z < 2.6)]
-
- # +
- df_selected, cut, dfcuts = select_cut(df_euclid,
-                                       completenss_lim=None,
-                                       nmad_lim=0.055,
-                                       outliers_lim=None,
-                                       return_df=True)
- # -
-
- print(dfcuts.to_latex(float_format="%.3f",
-                       columns=['Nobj', 'completeness', 'nmad', 'eta'],
-                       index=False
-                       ))
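`select_cut` (imported from `utils` above) is not shown in this diff. From its call site it appears to scan quality-flag thresholds until the retained sample satisfies a target completeness, NMAD, or outlier rate. The sketch below of that behavior, including the `completenss_lim` spelling taken from the call site, is an assumption, not the repository's code:

```python
import numpy as np
import pandas as pd

def select_cut(df, completenss_lim=None, nmad_lim=None, outliers_lim=None, return_df=False):
    # Sketch: evaluate progressively stricter cuts on the quality flag 'zflag'
    # and stop at the first one satisfying the requested limit.
    rows, chosen = [], None
    for q in np.arange(0.0, 1.0, 0.05):
        cut = df.zflag.quantile(q)
        d = df[df.zflag > cut]
        nmad_val = 1.4826 * np.median(np.abs(d.zwerr - np.median(d.zwerr)))
        eta = 100 * np.mean(np.abs(d.zwerr) > 0.15)          # outlier rate [%]
        completeness = 100 * len(d) / len(df)                # retained fraction [%]
        rows.append(dict(cut=cut, Nobj=len(d), completeness=completeness,
                         nmad=nmad_val, eta=eta))
        if chosen is None and (
            (nmad_lim is not None and nmad_val <= nmad_lim)
            or (outliers_lim is not None and eta <= outliers_lim)
            or (completenss_lim is not None and completeness <= completenss_lim)
        ):
            chosen = cut
    dfcuts = pd.DataFrame(rows)
    df_selected = df[df.zflag > chosen] if chosen is not None else df
    return (df_selected, chosen, dfcuts) if return_df else (chosen, dfcuts)
```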
temps/archive.py CHANGED
@@ -1,18 +1,18 @@
  import numpy as np
  import pandas as pd
  from astropy.io import fits
- import os
  from astropy.table import Table
  from scipy.spatial import KDTree
+ from matplotlib import pyplot as plt
+ from matplotlib import rcParams
+ from pathlib import Path
+ from loguru import logger

- import matplotlib.pyplot as plt
- from matplotlib import rcParams
  rcParams["mathtext.fontset"] = "stix"
  rcParams["font.family"] = "STIXGeneral"

- class archive():
+ class Archive:
      def __init__(self, path,
                   aperture=2,
                   drop_stars=True,
@@ -21,31 +21,42 @@ class archive():
                   extinction_corr=True,
                   only_zspec=True,
                   all_apertures=False,
-                  target_test='specz', flags_kept=[3,3.1,3.4,3.5,4]):
+                  target_test='specz', flags_kept=[3, 3.1, 3.4, 3.5, 4]):
+
+         logger.info("Starting archive")
          self.aperture = aperture
-         self.all_apertures=all_apertures
-         self.flags_kept=flags_kept
+         self.all_apertures = all_apertures
+         self.flags_kept = flags_kept

-         filename_calib='euclid_cosmos_DC2_S1_v2.1_calib_clean.fits'
-         filename_valid='euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
+         filename_calib = 'euclid_cosmos_DC2_S1_v2.1_calib_clean.fits'
+         filename_valid = 'euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'

-         hdu_list = fits.open(os.path.join(path,filename_calib))
-         cat = Table(hdu_list[1].data).to_pandas()
-         cat = cat[(cat['z_spec_S15'] > 0) | (cat['photo_z_L15'] > 0)]
-
-         hdu_list = fits.open(os.path.join(path,filename_valid))
-         cat_test = Table(hdu_list[1].data).to_pandas()
+         # Use Path for file handling
+         path_calib = Path(path) / filename_calib
+         path_valid = Path(path) / filename_valid
+
+         # Open the calibration FITS file
+         with fits.open(path_calib) as hdu_list:
+             cat = Table(hdu_list[1].data).to_pandas()
+             cat = cat[(cat['z_spec_S15'] > 0) | (cat['photo_z_L15'] > 0)]
+
+         # Open the validation FITS file
+         with fits.open(path_valid) as hdu_list:
+             cat_test = Table(hdu_list[1].data).to_pandas()
+
+         # Store the catalogs for later use
+         self.cat = cat
+         self.cat_test = cat_test

          if drop_stars==True:
+             logger.info("dropping stars...")
              cat = cat[cat.mu_class_L07==1]
              cat_test = cat_test[cat_test.mu_class_L07==1]

          if clean_photometry==True:
+             logger.info("cleaning photometry...")
              cat = self._clean_photometry(cat)
              cat_test = self._clean_photometry(cat_test)

@@ -216,9 +227,11 @@ class archive():

          if only_zspec:
+             logger.info("Selecting only galaxies with spectroscopic redshift")
              catalogue = self._select_only_zspec(catalogue, cat_flag='Calib')
              catalogue = self._clean_zspec_sample(catalogue, flags_kept=self.flags_kept)
          else:
+             logger.info("Selecting galaxies with spectroscopic redshift and high-precision photo-z")
              catalogue = self._take_zspec_and_photoz(catalogue, cat_flag='Calib')

@@ -233,9 +246,11 @@ class archive():

          if extinction_corr==True:
+             logger.info("Correcting MW extinction")
              f = self._correct_extinction(catalogue,f)

          if convert_colors==True:
+             logger.info("Converting to colors")
              col, colerr = self._to_colors(f, ferr)
              col_DA, colerr_DA = self._to_colors(f_DA, ferr_DA)
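A hypothetical usage of the renamed `Archive` class after this change. The data path is the site-specific directory hard-coded in the notebooks, and the import path assumes the `temps` package layout implied by `from temps.utils import nmad` in the plots.py diff below:

```python
from temps.archive import Archive  # assumed import path

archive = Archive(
    path='/data/astro/scratch2/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5',
    only_zspec=True,
    flags_kept=[3, 3.1, 3.4, 3.5, 4],
)
# loguru now reports each preprocessing step:
# "Starting archive", "dropping stars...", "Correcting MW extinction", ...
```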
temps/plots.py CHANGED
@@ -1,7 +1,7 @@
1
  import numpy as np
2
  import pandas as pd
3
  import matplotlib.pyplot as plt
4
- from utils import nmad
5
 
6
  import numpy as np
7
  import matplotlib.pyplot as plt
@@ -181,68 +181,7 @@ def plot_PIT(pit_list_1, pit_list_2 = None, pit_list_3=None, sample='specz', lab
181
  # Show the plot
182
  plt.show()
183
 
184
-
185
- import numpy as np
186
- import matplotlib.pyplot as plt
187
- from scipy import stats
188
-
189
- def plot_photoz(df_list, nbins, xvariable, metric, type_bin='bin',label_list=None, samp='zs', save=False):
190
- #plot properties
191
- plt.rcParams['font.family'] = 'serif'
192
- plt.rcParams['font.size'] = 12
193
-
194
-
195
-
196
-
197
- bin_edges = stats.mstats.mquantiles(df_list[0][xvariable].values, np.linspace(0.05, 1, nbins))
198
- print(bin_edges)
199
- cmap = plt.get_cmap('Dark2') # Choose a colormap for coloring lines
200
- plt.figure(figsize=(6, 5))
201
- ls = ['--',':','-']
202
-
203
- for i, df in enumerate(df_list):
204
- ydata, xlab = [], []
205
-
206
- for k in range(len(bin_edges)-1):
207
- edge_min = bin_edges[k]
208
- edge_max = bin_edges[k+1]
209
-
210
- mean_mag = (edge_max + edge_min) / 2
211
-
212
- if type_bin == 'bin':
213
- df_plot = df[(df[xvariable] > edge_min) & (df[xvariable] < edge_max)]
214
- elif type_bin == 'cum':
215
- df_plot = df[(df[xvariable] < edge_max)]
216
- else:
217
- raise ValueError("Only type_bin=='bin' for binned and 'cum' for cumulative are supported")
218
-
219
- xlab.append(mean_mag)
220
- if metric == 'sig68':
221
- ydata.append(sigma68(df_plot.zwerr))
222
- elif metric == 'bias':
223
- ydata.append(np.mean(df_plot.zwerr))
224
- elif metric == 'nmad':
225
- ydata.append(nmad(df_plot.zwerr))
226
- elif metric == 'outliers':
227
- ydata.append(len(df_plot[np.abs(df_plot.zwerr) > 0.15]) / len(df_plot)*100)
228
-
229
- print(ydata)
230
- color = cmap(i) # Get a different color for each dataframe
231
- plt.plot(xlab, ydata,marker='.', lw=1, label=f'{label_list[i]}', color=color, ls=ls[i])
232
-
233
- if xvariable == 'VISmag':
234
- xvariable_lab = 'VIS'
235
-
236
-
237
 
238
- plt.ylabel(f'{metric} $[\\Delta z]$', fontsize=18)
239
- plt.xlabel(f'{xvariable_lab}', fontsize=16)
240
- plt.grid(False)
241
- plt.legend()
242
-
243
- if save==True:
244
- plt.savefig(f'{metric}_{xvariable}_{samp}.pdf', dpi=300, bbox_inches='tight')
245
- plt.show()
246
 
247
 
248
  def plot_nz(df_list,
@@ -336,3 +275,43 @@ def plot_crps(crps_list_1, crps_list_2 = None, crps_list_3=None, labels=None, s
336
  # Show the plot
337
  plt.show()
338
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import numpy as np
2
  import pandas as pd
3
  import matplotlib.pyplot as plt
4
+ from temps.utils import nmad
5
 
6
  import numpy as np
7
  import matplotlib.pyplot as plt
 
181
  # Show the plot
182
  plt.show()
183
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
 
 
 
 
 
 
 
 
185
 
186
 
187
  def plot_nz(df_list,
 
275
  # Show the plot
276
  plt.show()
277
 
+
+
+ def plot_nz(df, bins=np.arange(0,5,0.2)):
+     kwargs = dict(bins=bins, alpha=0.5)
+     plt.hist(df.zs.values, color='grey', ls='-', **kwargs)
+     counts, _ = np.histogram(df.z.values, bins=bins)
+
+     plt.plot((bins[:-1]+bins[1:])*0.5, counts, color='purple')
+
+     #plt.legend(fontsize=14)
+     plt.xlabel(r'Redshift', fontsize=14)
+     plt.ylabel(r'Counts', fontsize=14)
+     plt.yscale('log')
+
+     plt.show()
+
+     return
+
+
+ def plot_scatter(df, sample='specz', save=True):
+     # Calculate the point density
+     xy = np.vstack([df.zs.values, df.z.values])
+     zd = gaussian_kde(xy)(xy)
+
+     fig, ax = plt.subplots()
+     plt.scatter(df.zs.values, df.z.values, c=zd, s=1)
+     plt.xlim(0, 5)
+     plt.ylim(0, 5)
+
+     plt.xlabel(r'$z_{\rm s}$', fontsize=14)
+     plt.ylabel('$z$', fontsize=14)
+
+     plt.xticks(fontsize=12)
+     plt.yticks(fontsize=12)
+
+     if save == True:
+         plt.savefig(f'{sample}_scatter.pdf', dpi=300, bbox_inches='tight')
+
+     plt.show()
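A quick way to exercise the two helpers added above, using a hypothetical toy catalogue with the `zs` (spectroscopic) and `z` (predicted) columns they expect; note `plot_scatter` relies on `gaussian_kde` being imported in this module:

import numpy as np
import pandas as pd
from temps.plots import plot_nz, plot_scatter

rng = np.random.default_rng(0)
zs = rng.uniform(0, 3, 5000)                 # toy "true" redshifts
df = pd.DataFrame({"zs": zs, "z": zs + rng.normal(0, 0.05 * (1 + zs))})

plot_nz(df)                                  # spec-z histogram vs. predicted counts
plot_scatter(df, sample="toy", save=False)   # KDE-coloured zs vs. z scatter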
temps/temps.py CHANGED
@@ -1,249 +1,267 @@
  import torch
- from torch.utils.data import DataLoader, dataset, TensorDataset
  from torch import nn, optim
  from torch.optim import lr_scheduler
- import numpy as np
- import pandas as pd
- from astropy.io import fits
- import os
- from astropy.table import Table
- from scipy.spatial import KDTree
- from scipy.special import erf
  from scipy.stats import norm
- import sys
-
- sys.path.append('/.')
- from utils import maximum_mean_discrepancy, compute_kernel
-
- class Temps_module():
-     """ Define class"""
-
-     def __init__(self, modelF, modelZ, batch_size=100, rejection_param=1, da=True, verbose=False):
-         self.modelZ = modelZ
-         self.modelF = modelF
-         self.da = da
-         self.verbose = verbose
-         self.ngaussians = modelZ.ngaussians
-
-         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-         self.batch_size = batch_size
-         self.rejection_parameter = rejection_param
-
-     def _get_dataloaders(self, input_data, target_data, input_data_DA, val_fraction=0.1):
          input_data = torch.Tensor(input_data)
          target_data = torch.Tensor(target_data)
-         if input_data_DA is not None:
-             input_data_DA = torch.Tensor(input_data_DA)
-         else:
-             input_data_DA = input_data.clone()
-
-         dataset = TensorDataset(input_data, input_data_DA, target_data)
-         trainig_dataset, val_dataset = torch.utils.data.random_split(dataset, [int(len(dataset)*(1-val_fraction)), int(len(dataset)*val_fraction)+1])
-         loader_train = DataLoader(trainig_dataset, batch_size=self.batch_size, shuffle=True)
-         loader_val = DataLoader(val_dataset, batch_size=64, shuffle=True)

          return loader_train, loader_val

-     def _loss_function(self, mean, std, logmix, true):
-         log_prob = logmix - 0.5*(mean - true[:,None]).pow(2) / std.pow(2) - torch.log(std)
-         log_prob = torch.logsumexp(log_prob, 1)
          loss = -log_prob.mean()
-         return loss
-
-     def _loss_function_DA(self, f1, f2):
-         kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=True)
          loss = kl_loss(f1, f2)
-         loss = torch.log(loss)
-         #print('f1',f1)
-         #print('f2',f2)
-
-         return loss

-     def _to_numpy(self, x):
          return x.detach().cpu().numpy()

-     def train(self, input_data,
-               input_data_DA,
-               target_data,
-               nepochs=10,
-               step_size=100,
-               val_fraction=0.1,
-               lr=1e-3,
-               weight_decay=0):
-         self.modelZ = self.modelZ.train()
-         self.modelF = self.modelF.train()
-
-         loader_train, loader_val = self._get_dataloaders(input_data, target_data, input_data_DA, val_fraction=0.1)
-         optimizerZ = optim.Adam(self.modelZ.parameters(), lr=lr, weight_decay=weight_decay)
-         optimizerF = optim.Adam(self.modelF.parameters(), lr=lr, weight_decay=weight_decay)
-
-         schedulerZ = torch.optim.lr_scheduler.StepLR(optimizerZ, step_size=step_size, gamma=0.1)
-         schedulerF = torch.optim.lr_scheduler.StepLR(optimizerF, step_size=step_size, gamma=0.1)
-
-         self.modelZ = self.modelZ.to(self.device)
-         self.modelF = self.modelF.to(self.device)

-         self.loss_train, self.loss_validation = [], []
-
-         for epoch in range(nepochs):
-             for input_data, input_data_da, target_data in loader_train:
-                 _loss_train, _loss_validation = [], []

-                 input_data = input_data.to(self.device)
-                 target_data = target_data.to(self.device)
-
                  if self.da:
                      input_data_da = input_data_da.to(self.device)

-                 optimizerF.zero_grad()
-                 optimizerZ.zero_grad()

-                 features = self.modelF(input_data)
-                 if self.da:
-                     features_DA = self.modelF(input_data_da)

-                 mu, logsig, logmix_coeff = self.modelZ(features)
-                 logsig = torch.clamp(logsig, -6, 2)
                  sig = torch.exp(logsig)

-                 lossZ = self._loss_function(mu, sig, logmix_coeff, target_data)

-                 if self.da:
-                     lossDA = maximum_mean_discrepancy(features, features_DA, kernel_type='rbf')
-                     lossDA = lossDA.sum()
-                     loss = lossZ + 1e3*lossDA
-                 else:
-                     loss = lossZ
-
-                 _loss_train.append(lossZ.item())
-
-                 loss.backward()
-                 optimizerF.step()
-                 optimizerZ.step()
-
-             schedulerF.step()
-             schedulerZ.step()
-
-             self.loss_train.append(np.mean(_loss_train))

-             for input_data, _, target_data in loader_val:

                  input_data = input_data.to(self.device)
                  target_data = target_data.to(self.device)

-                 features = self.modelF(input_data)
-                 mu, logsig, logmix_coeff = self.modelZ(features)
-                 logsig = torch.clamp(logsig, -6, 2)
                  sig = torch.exp(logsig)

                  loss_val = self._loss_function(mu, sig, logmix_coeff, target_data)
                  _loss_validation.append(loss_val.item())

-             self.loss_validation.append(np.mean(_loss_validation))
-
-             if self.verbose:
-                 print(f'training_loss:{loss}', f'testing_loss:{loss_val}')

      def get_features(self, input_data):
-         self.modelF = self.modelF.eval()
-         self.modelF = self.modelF.to(self.device)
-
          input_data = input_data.to(self.device)
-
-         features = self.modelF(input_data)
-
-         return features.detach().cpu().numpy()

-     def get_pz(self, input_data, return_pz=True, return_flag=True, retrun_odds=False):
-         self.modelZ = self.modelZ.eval()
-         self.modelZ = self.modelZ.to(self.device)
-         self.modelF = self.modelF.eval()
-         self.modelF = self.modelF.to(self.device)

          input_data = input_data.to(self.device)
-
-         features = self.modelF(input_data)
-         mu, logsig, logmix_coeff = self.modelZ(features)
-         logsig = torch.clamp(logsig, -6, 2)
          sig = torch.exp(logsig)

          mix_coeff = torch.exp(logmix_coeff)

-         z = (mix_coeff * mu).sum(1)
-         zerr = torch.sqrt((mix_coeff * sig**2).sum(1) + (mix_coeff * (mu - mu.mean(1)[:,None])**2).sum(1))
-
-         mu, mix_coeff, sig = mu.detach().cpu().numpy(), mix_coeff.detach().cpu().numpy(), sig.detach().cpu().numpy()
-
-         if return_pz==True:
-             zgrid = np.linspace(0, 5, 1000)
-             pdf_mixture = np.zeros(shape=(len(input_data), len(zgrid)))
-             for ii in range(len(input_data)):
-                 for i in range(self.ngaussians):
-                     pdf_mixture[ii] += mix_coeff[ii,i] * norm.pdf(zgrid, mu[ii,i], sig[ii,i])
-             if return_flag==True:
-                 #narrow peak
-                 pdf_mixture = pdf_mixture / pdf_mixture.sum(1)[:,None]
-                 diff_matrix = np.abs(self._to_numpy(z)[:,None] - zgrid[None,:])
-                 #odds
-                 idx_peak = np.argmax(pdf_mixture,1)
-                 zpeak = zgrid[idx_peak]
-                 diff_matrix_upper = np.abs((zpeak+0.05)[:,None] - zgrid[None,:])
-                 diff_matrix_lower = np.abs((zpeak-0.05)[:,None] - zgrid[None,:])
-
-                 idx = np.argmin(diff_matrix,1)
-                 idx_upper = np.argmin(diff_matrix_upper,1)
-                 idx_lower = np.argmin(diff_matrix_lower,1)
-
-                 p_z_x = np.zeros(shape=(len(z)))
-                 odds = np.zeros(shape=(len(z)))
-
-                 for ii in range(len(z)):
-                     p_z_x[ii] = pdf_mixture[ii,idx[ii]]
-                     odds[ii] = pdf_mixture[ii,:idx_upper[ii]].sum() - pdf_mixture[ii,:idx_lower[ii]].sum()
-
-                 return self._to_numpy(z), self._to_numpy(zerr), pdf_mixture, p_z_x, odds
-             else:
-                 return self._to_numpy(z), self._to_numpy(zerr), pdf_mixture
          else:
-             return self._to_numpy(z), self._to_numpy(zerr)
-
-     def pit(self, input_data, target_data):

          pit_list = []

-         self.modelF = self.modelF.eval()
-         self.modelF = self.modelF.to(self.device)
-         self.modelZ = self.modelZ.eval()
-         self.modelZ = self.modelZ.to(self.device)

          input_data = input_data.to(self.device)

-         features = self.modelF(input_data)
-         mu, logsig, logmix_coeff = self.modelZ(features)

          logsig = torch.clamp(logsig,-6,2)
          sig = torch.exp(logsig)
@@ -259,7 +277,8 @@ class Temps_module():

          return pit_list

-     def crps(self, input_data, target_data):

          def measure_crps(cdf, t):
              zgrid = np.linspace(0,4,1000)
@@ -273,16 +292,16 @@ class Temps_module():

          crps_list = []

-         self.modelF = self.modelF.eval()
-         self.modelF = self.modelF.to(self.device)
-         self.modelZ = self.modelZ.eval()
-         self.modelZ = self.modelZ.to(self.device)

          input_data = input_data.to(self.device)

-         features = self.modelF(input_data)
-         mu, logsig, logmix_coeff = self.modelZ(features)
          logsig = torch.clamp(logsig,-6,2)
          sig = torch.exp(logsig)
@@ -294,21 +313,19 @@ class Temps_module():
          z = (mix_coeff * mu).sum(1)

          x = np.linspace(0, 4, 1000)
-         pdf_mixture = np.zeros(shape=(len(target_data), len(x)))
          for ii in range(len(input_data)):
              for i in range(6):
-                 pdf_mixture[ii] += mix_coeff[ii,i] * norm.pdf(x, mu[ii,i], sig[ii,i])

-         pdf_mixture = pdf_mixture / pdf_mixture.sum(1)[:,None]

-         cdf_mixture = np.cumsum(pdf_mixture,1)

-         crps_value = measure_crps(cdf_mixture, target_data)

          return crps_value
 
+ import numpy as np
+ import pandas as pd
  import torch
  from torch import nn, optim
+ from torch.utils.data import DataLoader, TensorDataset
  from torch.optim import lr_scheduler
  from scipy.stats import norm
+ from loguru import logger
+ from tqdm import tqdm  # progress bars for the train/validation loops
+
+ # Local imports
+ from temps.utils import maximum_mean_discrepancy
+
+
+ class TempsModule:
+     """Wrapper around the TEMPS feature encoder and redshift-mixture models,
+     handling their joint training and inference."""
+
+     def __init__(
+         self,
+         model_f,
+         model_z,
+         batch_size=100,
+         rejection_param=1,
+         da=True,
+         verbose=False,
+     ):
+         self.model_z = model_z
+         self.model_f = model_f
+         self.da = da
+         self.verbose = verbose
+         self.ngaussians = model_z.ngaussians
+
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.batch_size = batch_size
+         self.rejection_parameter = rejection_param
+
+     def _get_dataloaders(
+         self, input_data, target_data, input_data_da=None, val_fraction=0.1
+     ):
+         """Create training and validation dataloaders."""
          input_data = torch.Tensor(input_data)
          target_data = torch.Tensor(target_data)
+         input_data_da = (
+             torch.Tensor(input_data_da)
+             if input_data_da is not None
+             else input_data.clone()
+         )
+
+         dataset = TensorDataset(input_data, input_data_da, target_data)
+         # split sizes must sum to len(dataset); deriving the train size from
+         # the validation size avoids an off-by-one from truncating twice
+         n_val = int(len(dataset) * val_fraction)
+         train_dataset, val_dataset = torch.utils.data.random_split(
+             dataset, [len(dataset) - n_val, n_val]
+         )
+         loader_train = DataLoader(
+             train_dataset, batch_size=self.batch_size, shuffle=True
+         )
+         loader_val = DataLoader(val_dataset, batch_size=64, shuffle=True)

          return loader_train, loader_val

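A minimal sketch of driving the loader construction above; the 10-column input is a placeholder that must match the encoder's input width, and `_get_dataloaders` is called directly only for illustration:

import numpy as np
from temps.temps import TempsModule
from temps.temps_arch import EncoderPhotometry, MeasureZ

temps = TempsModule(EncoderPhotometry(), MeasureZ(num_gauss=6))
col = np.random.rand(1000, 10).astype("float32")    # toy colours
ztarget = np.random.rand(1000).astype("float32")    # toy redshifts
loader_train, loader_val = temps._get_dataloaders(col, ztarget, None, val_fraction=0.1)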
+     def _loss_function(self, mean, std, logmix, true):
+         """Negative log-likelihood of the target under the predicted
+         Gaussian mixture (log-sum-exp over components)."""
+         log_prob = (
+             logmix - 0.5 * (mean - true[:, None]).pow(2) / std.pow(2) - torch.log(std)
+         )
+         log_prob = torch.logsumexp(log_prob, dim=1)
          loss = -log_prob.mean()
+         return loss
+
+     def _loss_function_da(self, f1, f2):
+         """Compute the KL-divergence loss for domain adaptation."""
+         kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=True)
          loss = kl_loss(f1, f2)
+         return torch.log(loss)

+     def _to_numpy(self, x):
+         """Convert a tensor to a NumPy array."""
          return x.detach().cpu().numpy()

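The mixture loss above is the negative log-likelihood of the target under a Gaussian mixture, evaluated with a log-sum-exp over components. A small numpy cross-check of the same expression (illustration only; the constant from 1/sqrt(2*pi) is dropped in the training loss, which does not affect optimisation):

import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm

mu = np.array([[0.4, 1.2]]); sig = np.array([[0.1, 0.3]])
logmix = np.log(np.array([[0.7, 0.3]]))
ztrue = np.array([0.5])

# per-component log density, up to the constant -0.5*log(2*pi)
log_prob = logmix - 0.5 * (mu - ztrue[:, None]) ** 2 / sig**2 - np.log(sig)
nll = -logsumexp(log_prob, axis=1).mean()

# same quantity via scipy, constant included
p = (np.exp(logmix) * norm.pdf(ztrue[:, None], mu, sig)).sum(1)
assert np.isclose(-np.log(p).mean(), nll + 0.5 * np.log(2 * np.pi))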
+     def train(
+         self,
+         input_data,
+         input_data_da,
+         target_data,
+         nepochs=10,
+         step_size=100,
+         val_fraction=0.1,
+         lr=1e-3,
+         weight_decay=0,
+     ):
+         """Train the models using the provided data."""
+         self.model_z.train()
+         self.model_f.train()
+
+         loader_train, loader_val = self._get_dataloaders(
+             input_data, target_data, input_data_da, val_fraction
+         )
+         optimizer_z = optim.Adam(
+             self.model_z.parameters(), lr=lr, weight_decay=weight_decay
+         )
+         optimizer_f = optim.Adam(
+             self.model_f.parameters(), lr=lr, weight_decay=weight_decay
+         )
+
+         scheduler_z = lr_scheduler.StepLR(optimizer_z, step_size=step_size, gamma=0.1)
+         scheduler_f = lr_scheduler.StepLR(optimizer_f, step_size=step_size, gamma=0.1)
+
+         self.model_z.to(self.device)
+         self.model_f.to(self.device)
+
+         loss_train, loss_validation = [], []

+         for epoch in range(nepochs):
+             _loss_train, _loss_validation = [], []
+             logger.info(f"Epoch {epoch + 1}/{nepochs} starting...")
+             for input_data, input_data_da, target_data in tqdm(
+                 loader_train, desc="Training", unit="batch"
+             ):
+                 input_data, target_data = input_data.to(self.device), target_data.to(
+                     self.device
+                 )
                  if self.da:
                      input_data_da = input_data_da.to(self.device)

+                 optimizer_f.zero_grad()
+                 optimizer_z.zero_grad()

+                 features = self.model_f(input_data)
+                 features_da = self.model_f(input_data_da) if self.da else None

+                 mu, logsig, logmix_coeff = self.model_z(features)
+                 logsig = torch.clamp(logsig, -6, 2)
                  sig = torch.exp(logsig)

+                 loss_z = self._loss_function(mu, sig, logmix_coeff, target_data)
+                 loss = loss_z + (
+                     1e3
+                     * maximum_mean_discrepancy(
+                         features, features_da, kernel_type="rbf"
+                     ).sum()
+                     if self.da
+                     else 0
+                 )
+
+                 _loss_train.append(loss_z.item())
+                 loss.backward()
+                 optimizer_f.step()
+                 optimizer_z.step()

+             scheduler_f.step()
+             scheduler_z.step()

+             loss_train.append(np.mean(_loss_train))
+             _loss_validation = self._validate(loader_val, target_data)
+
+             logger.info(
+                 f"Epoch {epoch + 1}: training loss {np.mean(_loss_train):.4f}, "
+                 f"validation loss {np.mean(_loss_validation):.4f}"
+             )

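End to end, the training entry point above is driven roughly as follows; a sketch with toy arrays standing in for the labelled (spec-z) and unlabelled (target-domain) photometry, and a placeholder input width:

import numpy as np
from temps.temps import TempsModule
from temps.temps_arch import EncoderPhotometry, MeasureZ

temps = TempsModule(EncoderPhotometry(), MeasureZ(), batch_size=100, da=True)
col = np.random.rand(2000, 10).astype("float32")       # labelled colours
col_da = np.random.rand(2000, 10).astype("float32")    # unlabelled target-domain colours
ztarget = np.random.rand(2000).astype("float32")

temps.train(col, col_da, ztarget, nepochs=5, lr=1e-3)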
+     def _validate(self, loader_val, target_data):
+         """Validate the models on the validation dataset."""
+         self.model_z.eval()
+         self.model_f.eval()
+         _loss_validation = []
+
+         with torch.no_grad():
+             for input_data, _, target_data in tqdm(
+                 loader_val, desc="Validating", unit="batch"
+             ):
                  input_data = input_data.to(self.device)
                  target_data = target_data.to(self.device)

+                 features = self.model_f(input_data)
+                 mu, logsig, logmix_coeff = self.model_z(features)
+                 logsig = torch.clamp(logsig, -6, 2)
                  sig = torch.exp(logsig)

                  loss_val = self._loss_function(mu, sig, logmix_coeff, target_data)
                  _loss_validation.append(loss_val.item())

+         return _loss_validation

      def get_features(self, input_data):
+         """Get features from the feature-encoder model."""
+         self.model_f.eval()
          input_data = input_data.to(self.device)
+         features = self.model_f(input_data)
+         return self._to_numpy(features)

+     def get_pz(self, input_data, return_pz=True, return_flag=True, return_odds=False):
+         """Get the predicted z values and their uncertainties."""
+         logger.info("Predicting photo-z for the input galaxies...")
+         self.model_z.eval()
+         self.model_f.eval()

          input_data = input_data.to(self.device)
+         features = self.model_f(input_data)
+         mu, logsig, logmix_coeff = self.model_z(features)
+         logsig = torch.clamp(logsig, -6, 2)
          sig = torch.exp(logsig)

          mix_coeff = torch.exp(logmix_coeff)
+         z = (mix_coeff * mu).sum(dim=1)
+         zerr = torch.sqrt(
+             (mix_coeff * sig**2).sum(dim=1)
+             + (mix_coeff * (mu - mu.mean(dim=1, keepdim=True)) ** 2).sum(dim=1)
+         )

+         mu, mix_coeff, sig = map(self._to_numpy, (mu, mix_coeff, sig))

+         if return_pz:
+             logger.info("Returning p(z)")
+             return self._calculate_pdf(z, mu, sig, mix_coeff, return_flag)
          else:
+             return self._to_numpy(z), self._to_numpy(zerr)
+
+     def _calculate_pdf(self, z, mu, sig, mix_coeff, return_flag):
+         """Calculate the probability density function on a redshift grid."""
+         zgrid = np.linspace(0, 5, 1000)
+         pz = np.zeros((len(z), len(zgrid)))
+
+         for ii in range(len(z)):
+             for i in range(self.ngaussians):
+                 pz[ii] += mix_coeff[ii, i] * norm.pdf(zgrid, mu[ii, i], sig[ii, i])
+
+         if return_flag:
+             logger.info("Calculating and returning ODDS")
+             pz /= pz.sum(axis=1, keepdims=True)
+             return self._calculate_odds(z, pz, zgrid)
+         return self._to_numpy(z), pz
+
+     def _calculate_odds(self, z, pz, zgrid):
+         """Calculate the ODDS quality flag from the normalised PDFs."""
+         logger.info('Calculating ODDS values')
+         idx_peak = np.argmax(pz, axis=1)
+         zpeak = zgrid[idx_peak]
+         idx_upper = np.argmin(np.abs((zpeak + 0.05)[:, None] - zgrid[None, :]), axis=1)
+         idx_lower = np.argmin(np.abs((zpeak - 0.05)[:, None] - zgrid[None, :]), axis=1)
+
+         odds = []
+         for jj in range(len(pz)):
+             odds.append(pz[jj, idx_lower[jj]:(idx_upper[jj] + 1)].sum())
+
+         odds = np.array(odds)
+         return self._to_numpy(z), pz, odds
+
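In words, `_calculate_odds` integrates each normalised p(z) in a ±0.05 window around its peak, so sharply peaked PDFs score close to 1. The same quantity in a few lines of standalone numpy, on two synthetic PDFs:

import numpy as np

zgrid = np.linspace(0, 5, 1000)
pz = np.exp(-0.5 * ((zgrid[None, :] - np.array([[1.0], [2.5]])) / 0.1) ** 2)
pz /= pz.sum(axis=1, keepdims=True)      # normalised PDFs, shape (ngal, ngrid)

idx_peak = np.argmax(pz, axis=1)
zpeak = zgrid[idx_peak]
lo = np.argmin(np.abs((zpeak - 0.05)[:, None] - zgrid[None, :]), axis=1)
hi = np.argmin(np.abs((zpeak + 0.05)[:, None] - zgrid[None, :]), axis=1)
odds = np.array([pz[j, lo[j]:hi[j] + 1].sum() for j in range(len(pz))])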
+     def calculate_pit(self, input_data, target_data):
+         logger.info('Calculating PIT values')

          pit_list = []

+         self.model_f = self.model_f.eval()
+         self.model_f = self.model_f.to(self.device)
+         self.model_z = self.model_z.eval()
+         self.model_z = self.model_z.to(self.device)

          input_data = input_data.to(self.device)

+         features = self.model_f(input_data)
+         mu, logsig, logmix_coeff = self.model_z(features)

          logsig = torch.clamp(logsig,-6,2)
          sig = torch.exp(logsig)

          return pit_list

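PIT values are the CDF of each p(z) evaluated at the true redshift; a calibrated model yields a flat PIT histogram on [0, 1]. A self-contained numpy sketch of that definition (the method above computes it from the mixture parameters instead):

import numpy as np

rng = np.random.default_rng(1)
zgrid = np.linspace(0, 4, 1000)
ztrue = rng.uniform(0.5, 3.5, 100)
centres = ztrue + rng.normal(0, 0.2, ztrue.size)    # toy predicted-PDF centres
pdf = np.exp(-0.5 * ((zgrid[None, :] - centres[:, None]) / 0.2) ** 2)
cdf = np.cumsum(pdf / pdf.sum(axis=1, keepdims=True), axis=1)

pit = np.array([np.interp(zt, zgrid, c) for zt, c in zip(ztrue, cdf)])
# a histogram of `pit` should be approximately uniform for calibrated PDFs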
+     def calculate_crps(self, input_data, target_data):
+         logger.info('Calculating CRPS values')

          def measure_crps(cdf, t):
              zgrid = np.linspace(0,4,1000)

          crps_list = []

+         self.model_f = self.model_f.eval()
+         self.model_f = self.model_f.to(self.device)
+         self.model_z = self.model_z.eval()
+         self.model_z = self.model_z.to(self.device)

          input_data = input_data.to(self.device)

+         features = self.model_f(input_data)
+         mu, logsig, logmix_coeff = self.model_z(features)
          logsig = torch.clamp(logsig,-6,2)
          sig = torch.exp(logsig)

          z = (mix_coeff * mu).sum(1)

          x = np.linspace(0, 4, 1000)
+         pz = np.zeros(shape=(len(target_data), len(x)))
          for ii in range(len(input_data)):
              for i in range(6):
+                 pz[ii] += mix_coeff[ii,i] * norm.pdf(x, mu[ii,i], sig[ii,i])

+         pz = pz / pz.sum(1)[:,None]

+         cdf_z = np.cumsum(pz,1)

+         crps_value = measure_crps(cdf_z, target_data)

          return crps_value

temps/temps_arch.py CHANGED
@@ -20,52 +20,46 @@ class EncoderPhotometry(nn.Module):
          nn.Linear(50, 20),
          nn.Dropout(dropout_prob),
          nn.ReLU(),
-         nn.Linear(20, 10)
+         nn.Linear(20, 10),
      )

      def forward(self, x):
          f = self.features(x)
          f = F.log_softmax(f, dim=1)
          return f


  class MeasureZ(nn.Module):
      def __init__(self, num_gauss=10, dropout_prob=0):
          super(MeasureZ, self).__init__()

          self.ngaussians = num_gauss

          self.measure_mu = nn.Sequential(
              nn.Linear(10, 20),
              nn.Dropout(dropout_prob),
              nn.ReLU(),
-             nn.Linear(20, num_gauss)
+             nn.Linear(20, num_gauss),
          )

          self.measure_coeffs = nn.Sequential(
              nn.Linear(10, 20),
              nn.Dropout(dropout_prob),
              nn.ReLU(),
-             nn.Linear(20, num_gauss)
+             nn.Linear(20, num_gauss),
          )

          self.measure_sigma = nn.Sequential(
              nn.Linear(10, 20),
              nn.Dropout(dropout_prob),
              nn.ReLU(),
-             nn.Linear(20, num_gauss)
+             nn.Linear(20, num_gauss),
          )

      def forward(self, f):
          mu = self.measure_mu(f)
          sigma = self.measure_sigma(f)
          logmix_coeff = self.measure_coeffs(f)

          logmix_coeff = logmix_coeff - torch.logsumexp(logmix_coeff, 1)[:, None]

          return mu, sigma, logmix_coeff
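Wiring the two modules above together; a sketch in which the batch size and the number of input colours are placeholders (the encoder's first layers, and hence its input width, are not shown in this hunk):

import torch
from temps.temps_arch import EncoderPhotometry, MeasureZ

encoder = EncoderPhotometry()
head = MeasureZ(num_gauss=6)

n_colors = 10                         # placeholder: must match the encoder's input width
col = torch.rand(32, n_colors)        # hypothetical batch of colours
f = encoder(col)                      # 10-d log-softmax features
mu, logsig, logmix = head(f)          # mixture means, log-widths, log-weights
print(mu.shape, logsig.shape, logmix.shape)   # each (32, 6)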
temps/utils.py CHANGED
@@ -3,113 +3,22 @@ import pandas as pd
  import matplotlib.pyplot as plt
  from scipy import stats
  import torch
- from scipy.stats import gaussian_kde

- def nmad(data):
-     return 1.4826 * np.median(np.abs(data - np.median(data)))
-
- def sigma68(data): return 0.5*(pd.Series(data).quantile(q = 0.84) - pd.Series(data).quantile(q = 0.16))
-
- def plot_photoz(df_list, nbins, xvariable, metric, type_bin='bin',label_list=None, samp='zs', save=False):
-     #plot properties
-     plt.rcParams['font.family'] = 'serif'
-     plt.rcParams['font.size'] = 12
-
-     bin_edges = stats.mstats.mquantiles(df_list[0][xvariable].values, np.linspace(0.05, 1, nbins))
-     print(bin_edges)
-     cmap = plt.get_cmap('Dark2')  # Choose a colormap for coloring lines
-     plt.figure(figsize=(6, 5))
-
-     for i, df in enumerate(df_list):
-         ydata, xlab = [], []
-
-         for k in range(len(bin_edges)-1):
-             edge_min = bin_edges[k]
-             edge_max = bin_edges[k+1]
-
-             mean_mag = (edge_max + edge_min) / 2
-
-             if type_bin == 'bin':
-                 df_plot = df[(df[xvariable] > edge_min) & (df[xvariable] < edge_max)]
-             elif type_bin == 'cum':
-                 df_plot = df[(df[xvariable] < edge_max)]
-             else:
-                 raise ValueError("Only type_bin=='bin' for binned and 'cum' for cumulative are supported")
-
-             xlab.append(mean_mag)
-             if metric == 'sig68':
-                 ydata.append(sigma68(df_plot.zwerr))
-             elif metric == 'bias':
-                 ydata.append(np.mean(df_plot.zwerr))
-             elif metric == 'nmad':
-                 ydata.append(nmad(df_plot.zwerr))
-             elif metric == 'outliers':
-                 ydata.append(len(df_plot[np.abs(df_plot.zwerr) > 0.15]) / len(df_plot)*100)
-
-         print(ydata)
-         color = cmap(i)  # Get a different color for each dataframe
-         plt.plot(xlab, ydata, ls='-', marker='.', lw=1, label=f'{label_list[i]}', color=color)
-
-     if xvariable == 'VISmag':
-         xvariable_lab = 'VIS'
-
-     plt.ylabel(f'{metric} $[\\Delta z]$', fontsize=18)
-     plt.xlabel(f'{xvariable_lab}', fontsize=16)
-     plt.grid(False)
-     plt.legend()
-
-     if save==True:
-         plt.savefig(f'{metric}_{xvariable}_{samp}.pdf', dpi=300, bbox_inches='tight')
-     plt.show()
-
-
- def plot_nz(df, bins=np.arange(0,5,0.2)):
-     kwargs=dict( bins=bins,alpha=0.5)
-     plt.hist(df.zs.values, color='grey', ls='-' ,**kwargs)
-     counts, _, =np.histogram(df.z.values, bins=bins)
-
-     plt.plot((bins[:-1]+bins[1:])*0.5,counts, color ='purple')
-
-     #plt.legend(fontsize=14)
-     plt.xlabel(r'Redshift', fontsize=14)
-     plt.ylabel(r'Counts', fontsize=14)
-     plt.yscale('log')
-
-     plt.show()
-
-     return
-
-
- def plot_scatter(df, sample='specz', save=True):
-     # Calculate the point density
-     xy = np.vstack([df.zs.values,df.z.values])
-     zd = gaussian_kde(xy)(xy)
-
-     fig, ax = plt.subplots()
-     plt.scatter(df.zs.values, df.z.values,c=zd, s=1)
-     plt.xlim(0,5)
-     plt.ylim(0,5)
-
-     plt.xlabel(r'$z_{\rm s}$', fontsize = 14)
-     plt.ylabel('$z$', fontsize = 14)
-
-     plt.xticks(fontsize = 12)
-     plt.yticks(fontsize = 12)
-
-     if save==True:
-         plt.savefig(f'{sample}_scatter.pdf', dpi = 300, bbox_inches='tight')
-
-     plt.show()
-
-
- def maximum_mean_discrepancy(x, y, kernel_type='rbf', kernel_mul=2.0, kernel_num=5):
      """
      Compute the Maximum Mean Discrepancy (MMD) between two sets of samples.

@@ -130,7 +39,8 @@ def maximum_mean_discrepancy(x, y, kernel_type='rbf', kernel_mul=2.0, kernel_num
      mmd_loss = torch.mean(x_kernel) + torch.mean(y_kernel) - 2 * torch.mean(xy_kernel)
      return mmd_loss

- def compute_kernel(x, y, kernel_type='rbf', kernel_mul=2.0, kernel_num=5):
      """
      Compute the kernel matrix based on the chosen kernel type.

@@ -151,73 +61,77 @@ def compute_kernel(x, y, kernel_type='rbf', kernel_mul=2.0, kernel_num=5):
      x = x.unsqueeze(1).expand(x_size, y_size, dim)
      y = y.unsqueeze(0).expand(x_size, y_size, dim)

-     kernel_input = (x - y).pow(2).mean(2)  # Pairwise squared Euclidean distances

-     if kernel_type == 'linear':
          kernel_matrix = kernel_input
-     elif kernel_type == 'poly':
          kernel_matrix = (1 + kernel_input / kernel_mul).pow(kernel_num)
-     elif kernel_type == 'rbf':
          kernel_matrix = torch.exp(-kernel_input / (2 * kernel_mul**2))
-     elif kernel_type == 'sigmoid':
          kernel_matrix = torch.tanh(kernel_mul * kernel_input)
      else:
-         raise ValueError("Invalid kernel type. Supported types are 'linear', 'poly', 'rbf', and 'sigmoid'.")

      return kernel_matrix


- def select_cut(df,
-                completenss_lim=None,
-                nmad_lim=None,
-                outliers_lim=None,
-                return_df=False):
-
-     if (completenss_lim is None)&(nmad_lim is None)&(outliers_lim is None):
-         raise(ValueError("Select at least one cut"))
      elif sum(c is not None for c in [completenss_lim, nmad_lim, outliers_lim]) > 1:
          raise ValueError("Select only one cut at a time")
-
      else:
-         bin_edges = stats.mstats.mquantiles(df.zflag, np.arange(0,1.01,0.1))
-         scatter, eta, cmptnss, nobj = [],[],[], []

-         for k in range(len(bin_edges)-1):
              edge_min = bin_edges[k]
-             edge_max = bin_edges[k+1]

-             df_bin = df[(df.zflag > edge_min)]

-             cmptnss.append(np.round(len(df_bin)/len(df),2)*100)
              scatter.append(nmad(df_bin.zwerr))
-             eta.append(len(df_bin[np.abs(df_bin.zwerr)>0.15])/len(df_bin)*100)
              nobj.append(len(df_bin))
-
-         dfcuts = pd.DataFrame(data=np.c_[np.round(bin_edges[:-1],5), np.round(nobj,1), np.round(cmptnss,1), np.round(scatter,3), np.round(eta,2)], columns=['flagcut', 'Nobj','completeness', 'nmad', 'eta'])
-
          if completenss_lim is not None:
-             print('Selecting cut based on completeness')
-             selected_cut = dfcuts[dfcuts['completeness'] <= completenss_lim].iloc[0]
-
          elif nmad_lim is not None:
-             print('Selecting cut based on nmad')
-             selected_cut = dfcuts[dfcuts['nmad'] <= nmad_lim].iloc[0]

          elif outliers_lim is not None:
-             print('Selecting cut based on outliers')
-             selected_cut = dfcuts[dfcuts['eta'] <= outliers_lim].iloc[0]

-         print(f"This cut provides completeness of {selected_cut['completeness']}, nmad={selected_cut['nmad']} and eta={selected_cut['eta']}")
-
-         df_cut = df[(df.zflag > selected_cut['flagcut'])]
-         if return_df==True:
-             return df_cut, selected_cut['flagcut'], dfcuts
      else:
-         return selected_cut['flagcut'], dfcuts
 
  import matplotlib.pyplot as plt
  from scipy import stats
  import torch
+ from loguru import logger


+ def calculate_eta(df):
+     return len(df[np.abs(df.zwerr) > 0.15]) / len(df) * 100


+ def nmad(data):
+     return 1.4826 * np.median(np.abs(data - np.median(data)))


+ def sigma68(data):
+     return 0.5 * (pd.Series(data).quantile(q=0.84) - pd.Series(data).quantile(q=0.16))


+ def maximum_mean_discrepancy(x, y, kernel_type="rbf", kernel_mul=2.0, kernel_num=5):
      """
      Compute the Maximum Mean Discrepancy (MMD) between two sets of samples.

      mmd_loss = torch.mean(x_kernel) + torch.mean(y_kernel) - 2 * torch.mean(xy_kernel)
      return mmd_loss
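A quick sanity check of the helper above: two samples from the same distribution should give an MMD near zero, while a shifted sample gives a clearly larger value:

import torch
from temps.utils import maximum_mean_discrepancy

torch.manual_seed(0)
x = torch.randn(256, 10)
y_same = torch.randn(256, 10)
y_shift = torch.randn(256, 10) + 2.0

print(maximum_mean_discrepancy(x, y_same))     # ~0
print(maximum_mean_discrepancy(x, y_shift))    # noticeably larger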
 
+
+ def compute_kernel(x, y, kernel_type="rbf", kernel_mul=2.0, kernel_num=5):
      """
      Compute the kernel matrix based on the chosen kernel type.

      x = x.unsqueeze(1).expand(x_size, y_size, dim)
      y = y.unsqueeze(0).expand(x_size, y_size, dim)

+     kernel_input = (x - y).pow(2).mean(2)

+     if kernel_type == "linear":
          kernel_matrix = kernel_input
+     elif kernel_type == "poly":
          kernel_matrix = (1 + kernel_input / kernel_mul).pow(kernel_num)
+     elif kernel_type == "rbf":
          kernel_matrix = torch.exp(-kernel_input / (2 * kernel_mul**2))
+     elif kernel_type == "sigmoid":
          kernel_matrix = torch.tanh(kernel_mul * kernel_input)
      else:
+         raise ValueError(
+             "Invalid kernel type. Supported types are 'linear', 'poly', 'rbf', and 'sigmoid'."
+         )

      return kernel_matrix


+ df, completenss_lim=None, nmad_lim=None, outliers_lim=None, return_df=False
84
+ ):
85
+
86
+ if (completenss_lim is None) & (nmad_lim is None) & (outliers_lim is None):
87
+ raise (ValueError("Select at least one cut"))
 
 
 
88
  elif sum(c is not None for c in [completenss_lim, nmad_lim, outliers_lim]) > 1:
89
  raise ValueError("Select only one cut at a time")
90
+
91
  else:
92
+ bin_edges = stats.mstats.mquantiles(df.odds, np.arange(0, 1.01, 0.1))
93
+ scatter, eta, cmptnss, nobj = [], [], [], []
94
 
95
+ for k in range(len(bin_edges) - 1):
96
  edge_min = bin_edges[k]
97
+ edge_max = bin_edges[k + 1]
98
 
99
+ df_bin = df[(df.odds > edge_min)]
 
100
 
101
+ cmptnss.append(np.round(len(df_bin) / len(df), 2) * 100)
102
  scatter.append(nmad(df_bin.zwerr))
103
+ eta.append(len(df_bin[np.abs(df_bin.zwerr) > 0.15]) / len(df_bin) * 100)
104
  nobj.append(len(df_bin))
105
+
106
+ dfcuts = pd.DataFrame(
107
+ data=np.c_[
108
+ np.round(bin_edges[:-1], 5),
109
+ np.round(nobj, 1),
110
+ np.round(cmptnss, 1),
111
+ np.round(scatter, 3),
112
+ np.round(eta, 2),
113
+ ],
114
+ columns=["flagcut", "Nobj", "completeness", "nmad", "eta"],
115
+ )
116
+
117
  if completenss_lim is not None:
118
+ logger.info("Selecting cut based on completeness")
119
+ selected_cut = dfcuts[dfcuts["completeness"] <= completenss_lim].iloc[0]
120
+
 
121
  elif nmad_lim is not None:
122
+ logger.info("Selecting cut based on nmad")
123
+ selected_cut = dfcuts[dfcuts["nmad"] <= nmad_lim].iloc[0]
124
 
 
125
  elif outliers_lim is not None:
126
+ logger.info("Selecting cut based on outliers")
127
+ selected_cut = dfcuts[dfcuts["eta"] <= outliers_lim].iloc[0]
128
 
129
+ logger.info(
130
+ f"This cut provides completeness of {selected_cut['completeness']}, nmad={selected_cut['nmad']} and eta={selected_cut['eta']}"
131
+ )
132
 
133
+ df_cut = df[(df.odds > selected_cut["flagcut"])]
134
+ if return_df == True:
135
+ return df_cut, selected_cut["flagcut"], dfcuts
 
 
136
  else:
137
+ return selected_cut["flagcut"], dfcuts
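The reworked `select_cut` expects a dataframe with `odds` and `zwerr` columns and returns the ODDS threshold that meets the requested quality target; a hedged usage sketch on toy data:

import numpy as np
import pandas as pd
from temps.utils import select_cut

rng = np.random.default_rng(2)
df = pd.DataFrame({
    "odds": rng.uniform(0, 1, 10000),        # toy ODDS values
    "zwerr": rng.normal(0, 0.05, 10000),     # toy weighted redshift errors
})

flagcut, dfcuts = select_cut(df, completenss_lim=70)   # keep objects with df.odds > flagcut
df_sel = df[df.odds > flagcut]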