lauracabayol committed on
Commit
c212435
·
0 Parent(s):

Archive code and network training

Browse files
insight/.ipynb_checkpoints/archive-checkpoint.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ from astropy.io import fits
4
+ import os
5
+ from astropy.table import Table
6
+ from scipy.spatial import KDTree
7
+
8
+ import matplotlib.pyplot as plt
9
+
10
+ from matplotlib import rcParams
11
+ rcParams["mathtext.fontset"] = "stix"
12
+ rcParams["font.family"] = "STIXGeneral"
13
+
14
+
15
class archive():
    """Load the calibration/validation catalogues and expose them as arrays.

    On construction the class reads the calibration and validation FITS
    tables (and the gold-sample CSV) found under ``path``, applies the
    requested selection cuts and stores photometry, photometric uncertainties
    and spectroscopic redshifts as numpy arrays. They are retrieved with
    ``get_training_data`` and ``get_testing_data``; the filtered DataFrames
    are kept as ``cat_train`` and ``cat_test``.
    """

    # Band suffixes used to build the flux, flux-error and extinction columns.
    _BANDS = ('G', 'R', 'I', 'Z', 'Y', 'J', 'H')

    def __init__(self, path, aperture=2, drop_stars=True, clean_photometry=True, convert_colors=True, extinction_corr=True, only_zspec=True, reliable_zspec=True):
        """Read the catalogues under ``path`` and build the train/test arrays.

        Parameters
        ----------
        path : str
            Directory containing the two FITS catalogues and the gold CSV.
        aperture : int or str
            Suffix of the ``FLUX_<band>_<aperture>`` columns to use.
        drop_stars : bool
            Keep only calibration rows with ``mu_class_L07 == 1``.
        clean_photometry : bool
            Keep only rows with ``FLAG_PHOT == 0`` (both catalogues).
        convert_colors : bool
            Store adjacent-band flux ratios instead of the raw fluxes.
        extinction_corr : bool
            Multiply the fluxes by the ``EB_V_corr_FLUX_<band>`` columns.
        only_zspec : bool
            Keep only rows with a positive ``z_spec_S15``.
        reliable_zspec : None, bool or str
            Reliability cut for the training sample, see
            ``_clean_zspec_sample``. The testing sample always uses 'Total'.
        """
        self.aperture = aperture

        # Inclusive (low, high) ranges of the Q_f_S15 flag -> loss weight.
        self.weight_dict = {(-99, 0.99): 0,
                            (1, 1.99): 0.5,
                            (2, 2.99): 0.75,
                            (3, 4): 1,
                            (9, 9.99): 0.25,
                            (10, 10.99): 0,
                            (11, 11.99): 0.5,
                            (12, 12.99): 0.75,
                            (13, 14): 1,
                            (14.01, 40): 0
                            }

        filename_calib = 'euclid_cosmos_DC2_S1_v2.1_calib_clean.fits'
        filename_valid = 'euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
        filename_gold = 'Export_Gold_2023_07_03.csv'

        cat = self._read_fits_table(os.path.join(path, filename_calib))
        cat_test = self._read_fits_table(os.path.join(path, filename_valid))

        # The gold sample is still read (a missing file fails here, as it
        # always did) but matching against it is currently disabled; enable
        # it with: cat_test = self._match_gold_sample(cat_test, gold_sample)
        gold_sample = pd.read_csv(os.path.join(path, filename_gold))

        if drop_stars:
            cat = cat[cat.mu_class_L07 == 1]

        if clean_photometry:
            cat = self._clean_photometry(cat)
            cat_test = self._clean_photometry(cat_test)

        # The cuts above return slices of the original frame; copy before
        # adding the weight column so pandas does not write into a view.
        cat = cat.copy()
        self._get_loss_weights(cat)

        # Objects with zero weight carry no information for the loss.
        cat = cat[cat.w_Q_f_S15 > 0]

        self._set_training_data(cat, only_zspec=only_zspec, reliable_zspec=reliable_zspec, extinction_corr=extinction_corr, convert_colors=convert_colors)

        self._set_testing_data(cat_test, only_zspec=only_zspec, reliable_zspec='Total', extinction_corr=extinction_corr, convert_colors=convert_colors)

    @staticmethod
    def _read_fits_table(filepath):
        """Return HDU 1 of a FITS file as a DataFrame, closing the file afterwards."""
        with fits.open(filepath) as hdu_list:
            return Table(hdu_list[1].data).to_pandas()

    def _extract_fluxes(self, catalogue):
        """Return ``(flux, flux_err)`` arrays with one column per band."""
        columns_f = [f'FLUX_{x}_{self.aperture}' for x in self._BANDS]
        columns_ferr = [f'FLUXERR_{x}_{self.aperture}' for x in self._BANDS]

        f = catalogue[columns_f].values
        ferr = catalogue[columns_ferr].values
        return f, ferr

    def _to_colors(self, flux, fluxerr):
        """Convert fluxes to adjacent-band flux ratios.

        Returns ``(color, color_err)`` where ``color[:, i] = flux[:, i] /
        flux[:, i + 1]`` and ``color_err`` is the first-order propagated
        *variance* of that ratio (no square root is taken):
        ``var = err_a**2 / b**2 + a**2 / b**4 * err_b**2``.
        """
        numerator, denominator = flux[:, :-1], flux[:, 1:]
        numerator_err, denominator_err = fluxerr[:, :-1], fluxerr[:, 1:]

        color = numerator / denominator
        # The second term must use the uncertainty of the denominator band.
        color_err = numerator_err**2 / denominator**2 + numerator**2 / denominator**4 * denominator_err**2
        return color, color_err

    def _clean_photometry(self, catalogue):
        """Drop every object with ``FLAG_PHOT != 0``."""
        catalogue = catalogue[catalogue['FLAG_PHOT'] == 0]

        return catalogue

    def _correct_extinction(self, catalogue, f):
        """Multiply the fluxes by the per-band ``EB_V_corr_FLUX_<band>`` factors."""
        ext_correction_cols = [f'EB_V_corr_FLUX_{x}' for x in self._BANDS]
        ext_correction = catalogue[ext_correction_cols].values

        f = f * ext_correction
        return f

    def _take_only_zspec(self, catalogue, cat_flag=None):
        """Keep only rows with ``z_spec_S15 > 0``.

        The cut is applied for ``cat_flag`` 'Calib' or 'Valid'; any other
        value returns the catalogue unchanged.
        """
        if cat_flag in ('Calib', 'Valid'):
            catalogue = catalogue[catalogue.z_spec_S15 > 0]
        return catalogue

    def _clean_zspec_sample(self, catalogue, kind=None):
        """Apply a spectroscopic-reliability cut.

        ``kind`` may be:
        * ``None`` or ``False`` - no cut;
        * ``'Total'`` or ``True`` - keep ``reliable_S15 > 0``;
        * ``'Partial'`` - keep ``w_Q_f_S15 > 0.5``.

        ``True`` is accepted because it is the constructor default; it used
        to fall through every branch and return ``None``.

        Raises
        ------
        ValueError
            If ``kind`` is none of the values above.
        """
        if kind is None or kind is False:
            return catalogue
        if kind is True or kind == 'Total':
            return catalogue[catalogue['reliable_S15'] > 0]
        if kind == 'Partial':
            return catalogue[(catalogue['w_Q_f_S15'] > 0.5)]
        raise ValueError(f"Unknown reliability selection {kind!r}; use None, 'Total' or 'Partial'")

    def _map_weight(self, Qz):
        """Return the loss weight for quality flag ``Qz``.

        Flags that fall in none of the ``weight_dict`` ranges (including
        NaN) get weight 0, so they are removed by the ``> 0`` cut instead of
        producing ``None`` entries.
        """
        for (low, high), weight in self.weight_dict.items():
            if low <= Qz <= high:
                return weight
        return 0

    def _get_loss_weights(self, catalogue):
        """Add the ``w_Q_f_S15`` weight column to ``catalogue`` in place."""
        catalogue['w_Q_f_S15'] = catalogue['Q_f_S15'].apply(self._map_weight)

    def _match_gold_sample(self, catalogue_valid, catalogue_gold, max_distance_arcsec=2):
        """Attach the nearest gold-sample redshift to each validation object.

        Adds a ``z_spec_gold`` column to ``catalogue_valid`` (in place) and
        returns it. Objects whose nearest gold neighbour is not closer than
        ``max_distance_arcsec`` get -99. The distance is the plain Euclidean
        distance in (RA, DEC) degrees; no cos(DEC) factor is applied.
        """
        max_distance_deg = max_distance_arcsec / 3600.0

        gold_sample_radec = np.c_[catalogue_gold.RIGHT_ASCENSION, catalogue_gold.DECLINATION]
        valid_sample_radec = np.c_[catalogue_valid['RA'], catalogue_valid['DEC']]

        kdtree = KDTree(gold_sample_radec)
        distances, indices = kdtree.query(valid_sample_radec, k=1)

        specz_match_gold = catalogue_gold.FINAL_SPEC_Z.values[indices]

        catalogue_valid['z_spec_gold'] = np.where(distances < max_distance_deg, specz_match_gold, -99)

        return catalogue_valid

    def _prepare_sample(self, catalogue, cat_flag, only_zspec, reliable_zspec, extinction_corr, convert_colors):
        """Apply the redshift cuts and return ``(catalogue, phot, photerr)``."""
        if only_zspec:
            catalogue = self._take_only_zspec(catalogue, cat_flag=cat_flag)
            catalogue = self._clean_zspec_sample(catalogue, kind=reliable_zspec)

        f, ferr = self._extract_fluxes(catalogue)

        if extinction_corr:
            f = self._correct_extinction(catalogue, f)

        if convert_colors:
            f, ferr = self._to_colors(f, ferr)

        return catalogue, f, ferr

    def _set_training_data(self, catalogue, only_zspec=True, reliable_zspec=True, extinction_corr=True, convert_colors=True):
        """Fill ``cat_train``, ``phot_train``, ``photerr_train``, ``target_z_train`` and ``target_qz_train``."""
        catalogue, phot, photerr = self._prepare_sample(catalogue, 'Calib', only_zspec, reliable_zspec, extinction_corr, convert_colors)

        self.cat_train = catalogue
        self.phot_train = phot
        self.photerr_train = photerr

        self.target_z_train = catalogue['z_spec_S15'].values
        self.target_qz_train = catalogue['w_Q_f_S15'].values

    def _set_testing_data(self, catalogue, only_zspec=True, reliable_zspec=True, extinction_corr=True, convert_colors=True):
        """Fill ``cat_test``, ``phot_test``, ``photerr_test`` and ``target_z_test``."""
        catalogue, phot, photerr = self._prepare_sample(catalogue, 'Valid', only_zspec, reliable_zspec, extinction_corr, convert_colors)

        self.cat_test = catalogue
        self.phot_test = phot
        self.photerr_test = photerr

        self.target_z_test = catalogue['z_spec_S15'].values

    def get_training_data(self):
        """Return ``(phot, photerr, z_spec, weight)`` of the training sample."""
        return self.phot_train, self.photerr_train, self.target_z_train, self.target_qz_train

    def get_testing_data(self):
        """Return ``(phot, photerr, z_spec)`` of the testing sample."""
        return self.phot_test, self.photerr_test, self.target_z_test

    def get_VIS_mag(self, catalogue):
        """Return the ``MAG_VIS`` column of ``catalogue`` as an (N, 1) array."""
        return catalogue[['MAG_VIS']].values

    def plot_zdistribution(self, plot_test=False, bins=50):
        """Plot the spec-z histogram of the training (and optionally test) sample."""
        # get_training_data returns four arrays; the redshift is the third.
        _, _, specz, _ = self.get_training_data()
        plt.hist(specz, bins=bins, histtype='step', color='navy', label=r'Training sample')

        if plot_test:
            _, _, specz_test = self.get_testing_data()
            plt.hist(specz_test, bins=bins, histtype='step', color='goldenrod', label=r'Test sample', ls='--')

        plt.xticks(fontsize=12)
        plt.yticks(fontsize=12)

        plt.xlabel(r'Redshift', fontsize=14)
        plt.ylabel('Counts', fontsize=14)

        # The histograms carry labels, so show them.
        plt.legend(fontsize=12)

        plt.show()
insight/.ipynb_checkpoints/insight-checkpoint.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import DataLoader, dataset, TensorDataset
3
+ from torch import nn, optim
4
+ from torch.optim import lr_scheduler
5
+ import numpy as np
6
+ import pandas as pd
7
+ from astropy.io import fits
8
+ import os
9
+ from astropy.table import Table
10
+ from scipy.spatial import KDTree
11
+
12
class Insight_module():
    """Train a mixture-density photo-z network and derive estimates from it.

    ``model`` must be a torch module whose forward pass returns
    ``(mu, logsig, logmix_coeff)``, each of shape (batch, n_components).
    """

    def __init__(self, model):
        """Store ``model`` and pick the GPU when one is available."""
        self.model = model
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _get_dataloaders(self, input_data, target_data, target_weights, val_fraction=0.1):
        """Split the data randomly and return ``(loader_train, loader_val)``.

        The validation set receives whatever is left after the training set,
        so the two lengths always add up to the dataset size (the previous
        ``int(...)`` / ``int(...) + 1`` pair could exceed it by one, which
        ``random_split`` rejects).
        """
        input_data = torch.Tensor(input_data)
        target_data = torch.Tensor(target_data)
        target_weights = torch.Tensor(target_weights)

        full_dataset = TensorDataset(input_data, target_data, target_weights)

        n_total = len(full_dataset)
        n_train = int(n_total * (1 - val_fraction))
        n_val = n_total - n_train

        training_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [n_train, n_val])
        loader_train = DataLoader(training_dataset, batch_size=64, shuffle=True)
        # Validation order does not matter, so no shuffling is needed.
        loader_val = DataLoader(val_dataset, batch_size=64, shuffle=False)

        return loader_train, loader_val

    def _loss_function(self, mean, std, logmix, true, target_weights):
        """Return the mean negative log-likelihood of a Gaussian mixture.

        The constant ``-0.5 * log(2 * pi)`` is omitted. ``target_weights`` is
        accepted but currently unused: the per-object weighting is disabled.
        """
        log_prob = logmix - 0.5 * (mean - true[:, None]).pow(2) / std.pow(2) - torch.log(std)

        log_prob = torch.logsumexp(log_prob, 1)

        #log_prob = log_prob * target_weights
        loss = -log_prob.mean()

        return loss

    def _to_numpy(self, x):
        """Detach a tensor and return it as a numpy array on the CPU."""
        return x.detach().cpu().numpy()

    def _as_tensor(self, data):
        """Return ``data`` on ``self.device``; non-tensors become float32 tensors."""
        if not torch.is_tensor(data):
            data = torch.as_tensor(data, dtype=torch.float32)
        return data.to(self.device)

    def _forward(self, input_data):
        """Run the model and return ``(mu, sig, logmix_coeff)``.

        The log-sigma output is clamped to [-6, 2] before exponentiation to
        keep the widths in a numerically safe range.
        """
        mu, logsig, logmix_coeff = self.model(input_data)
        logsig = torch.clamp(logsig, -6, 2)
        sig = torch.exp(logsig)
        return mu, sig, logmix_coeff

    def train(self, input_data, target_data, target_weights, nepochs=10, val_fraction=0.1, lr=1e-3):
        """Fit the model with Adam and record the per-batch losses.

        After the call ``self.loss_train`` and ``self.loss_validation`` hold
        one value per processed batch. The last training and validation
        batch loss of every epoch is printed.
        """
        loader_train, loader_val = self._get_dataloaders(input_data, target_data, target_weights, val_fraction=val_fraction)
        optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-4)

        self.model = self.model.to(self.device)

        loss_train, loss_validation = [], []

        for epoch in range(nepochs):
            self.model.train()
            last_train_loss = None
            for batch_input, batch_target, batch_weights in loader_train:
                batch_input = batch_input.to(self.device)
                batch_target = batch_target.to(self.device)
                batch_weights = batch_weights.to(self.device)

                optimizer.zero_grad()

                mu, sig, logmix_coeff = self._forward(batch_input)
                loss = self._loss_function(mu, sig, logmix_coeff, batch_target, batch_weights)

                loss.backward()
                optimizer.step()

                last_train_loss = loss.item()
                loss_train.append(last_train_loss)

            # Validation: dropout off, no gradient bookkeeping.
            self.model.eval()
            last_val_loss = None
            with torch.no_grad():
                for batch_input, batch_target, batch_weights in loader_val:
                    batch_input = batch_input.to(self.device)
                    batch_target = batch_target.to(self.device)
                    batch_weights = batch_weights.to(self.device)

                    mu, sig, logmix_coeff = self._forward(batch_input)
                    loss_val = self._loss_function(mu, sig, logmix_coeff, batch_target, batch_weights)

                    last_val_loss = loss_val.item()
                    loss_validation.append(last_val_loss)

            print(f'training_loss:{last_train_loss}', f'testing_loss:{last_val_loss}')

        # Leave the model in training mode, as this method always did.
        self.model.train()

        self.loss_train = loss_train
        self.loss_validation = loss_validation

    def get_photoz(self, input_data, target_data=None):
        """Return ``(z, zerr)`` numpy arrays for ``input_data``.

        ``z`` is the mixture mean. ``zerr`` is the square root of the
        weighted component variances plus the weighted squared distance of
        the component means from a reference redshift. When ``target_data``
        is given that reference is the target (the historical behaviour);
        when it is omitted the reference is the predicted mean ``z``, which
        needs no true redshift.

        Inputs may be tensors or anything ``torch.as_tensor`` accepts.
        """
        self.model = self.model.eval()
        self.model = self.model.to(self.device)

        input_data = self._as_tensor(input_data)

        # One forward pass over the whole input is enough.
        with torch.no_grad():
            mu, sig, logmix_coeff = self._forward(input_data)

        mix_coeff = torch.exp(logmix_coeff)

        z = (mix_coeff * mu).sum(1)
        reference = z if target_data is None else self._as_tensor(target_data)
        zerr = torch.sqrt((mix_coeff * sig**2).sum(1) + (mix_coeff * (mu - reference[:, None])**2).sum(1))

        return self._to_numpy(z), self._to_numpy(zerr)

    def plot_photoz(self, df, nbins, xvariable, metric, type_bin='bin'):
        """Plot a photo-z metric of ``df.zwerr`` in quantile bins of ``df[xvariable]``.

        Parameters
        ----------
        df : pandas.DataFrame
            Must contain the columns ``xvariable`` and ``zwerr``.
        nbins : int
            Number of quantile edges (taken between the 10% and 100% quantiles).
        xvariable : str
            Column to bin on.
        metric : {'sig68', 'bias', 'nmad', 'outliers'}
            Statistic of ``zwerr`` computed in every bin.
        type_bin : {'bin', 'cum'}
            'bin' uses objects between two edges, 'cum' all objects below
            the upper edge.

        Raises
        ------
        ValueError
            For an unsupported ``type_bin`` or ``metric``.
        """
        # Imported here because this module does not import the plotting
        # stack at file level.
        import matplotlib.pyplot as plt
        from scipy import stats

        metric_functions = {
            'sig68': lambda dz: 0.5 * (dz.quantile(q=0.84) - dz.quantile(q=0.16)),
            'bias': np.mean,
            'nmad': lambda dz: 1.4826 * np.median(np.abs(dz - np.median(dz))),
            'outliers': lambda dz: np.mean(np.abs(dz) > 0.15),
        }

        if type_bin not in ('bin', 'cum'):
            raise ValueError("Only type_bin=='bin' for binned and 'cum' for cumulative are supported")
        if metric not in metric_functions:
            raise ValueError(f"Unsupported metric {metric!r}; choose one of {sorted(metric_functions)}")

        bin_edges = stats.mstats.mquantiles(df[xvariable].values, np.linspace(0.1, 1, nbins))
        ydata, xlab = [], []

        for edge_min, edge_max in zip(bin_edges[:-1], bin_edges[1:]):
            if type_bin == 'bin':
                df_plot = df[(df[xvariable] > edge_min) & (df[xvariable] < edge_max)]
            else:
                df_plot = df[(df[xvariable] < edge_max)]

            xlab.append((edge_max + edge_min) / 2)
            # An empty bin has no defined statistic; plot a gap instead of failing.
            ydata.append(metric_functions[metric](df_plot.zwerr) if len(df_plot) else np.nan)

        plt.plot(xlab, ydata, ls='-', marker='.', color='navy', lw=1, label='')
        plt.ylabel(rf'{metric}$[\Delta z]$', fontsize=18)
        plt.xlabel(f'{xvariable}', fontsize=16)

        plt.xticks(fontsize=14)
        plt.yticks(fontsize=14)

        plt.grid(False)

        plt.show()
163
+
164
+
165
+
166
+
insight/.ipynb_checkpoints/insight_arch-checkpoint.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn, optim
2
+ import torch
3
class Photoz_network(nn.Module):
    """Mixture-density network with a shared trunk and three output heads.

    The trunk maps ``input_dim`` features to 100 hidden units; each head maps
    those to ``num_gauss`` values: component means, component log-widths and
    normalised log mixture coefficients.
    """

    def __init__(self, num_gauss=10, dropout_prob=0, input_dim=6):
        """Build the network.

        Parameters
        ----------
        num_gauss : int
            Number of mixture components (outputs per head).
        dropout_prob : float
            Dropout probability applied after every hidden linear layer.
        input_dim : int
            Number of input features. The default of 6 is the previously
            hard-coded width.
        """
        super(Photoz_network, self).__init__()

        head_widths = [100, 80, 70, 60, 50, num_gauss]

        # Creation order is kept (features, mu, coeffs, sigma) so parameter
        # initialisation under a fixed seed is unchanged.
        self.features = self._make_mlp([input_dim, 10, 30, 50, 70, 100], dropout_prob)
        self.measure_mu = self._make_mlp(head_widths, dropout_prob)
        self.measure_coeffs = self._make_mlp(head_widths, dropout_prob)
        self.measure_sigma = self._make_mlp(head_widths, dropout_prob)

    @staticmethod
    def _make_mlp(widths, dropout_prob):
        """Return Linear -> Dropout -> ReLU blocks followed by a final Linear.

        The layer order matches the original hand-written stacks, so the
        ``state_dict`` keys are unchanged.
        """
        layers = []
        for n_in, n_out in zip(widths[:-2], widths[1:-1]):
            layers += [nn.Linear(n_in, n_out), nn.Dropout(dropout_prob), nn.ReLU()]
        layers.append(nn.Linear(widths[-2], widths[-1]))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``(mu, logsig, logmix_coeff)`` for a batch ``x``.

        The second output is the raw head output; ``Insight_module`` clamps
        and exponentiates it, i.e. treats it as a log-width. The mixture
        coefficients are normalised in log space so their exponentials sum
        to one per row.
        """
        f = self.features(x)
        mu = self.measure_mu(f)
        logsig = self.measure_sigma(f)
        logmix_coeff = self.measure_coeffs(f)

        logmix_coeff = logmix_coeff - torch.logsumexp(logmix_coeff, 1)[:, None]

        return mu, logsig, logmix_coeff
80
+
81
+
insight/.ipynb_checkpoints/utils-checkpoint.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ import matplotlib.pyplot as plt
4
+ from scipy import stats
5
+
6
def nmad(data):
    """Return the normalised median absolute deviation of ``data``.

    This is 1.4826 times the median of the absolute deviations from the
    median.
    """
    centre = np.median(data)
    absolute_deviation = np.abs(data - centre)
    return 1.4826 * np.median(absolute_deviation)
8
+
9
def sigma68(data):
    """Return half the distance between the 16% and 84% quantiles of ``data``."""
    series = pd.Series(data)
    upper = series.quantile(q=0.84)
    lower = series.quantile(q=0.16)
    return 0.5 * (upper - lower)
10
+
11
+
12
def plot_photoz(df, nbins, xvariable, metric, type_bin='bin'):
    """Plot a photo-z metric of ``df.zwerr`` in quantile bins of ``df[xvariable]``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``xvariable`` and ``zwerr``.
    nbins : int
        Number of quantile edges (taken between the 10% and 100% quantiles).
    xvariable : str
        Column to bin on. The selection now uses this column of ``df``; it
        previously referenced an undefined ``df_test`` and a fixed ``imag``
        column.
    metric : {'sig68', 'bias', 'nmad', 'outliers'}
        Statistic of ``zwerr`` computed in every bin.
    type_bin : {'bin', 'cum'}
        'bin' uses objects between two edges, 'cum' all objects below the
        upper edge.

    Raises
    ------
    ValueError
        For an unsupported ``type_bin`` or ``metric``.
    """
    metric_functions = {
        'sig68': sigma68,
        'bias': np.mean,
        'nmad': nmad,
        'outliers': lambda dz: np.mean(np.abs(dz) > 0.15),
    }

    if type_bin not in ('bin', 'cum'):
        raise ValueError("Only type_bin=='bin' for binned and 'cum' for cumulative are supported")
    if metric not in metric_functions:
        raise ValueError(f"Unsupported metric {metric!r}; choose one of {sorted(metric_functions)}")

    bin_edges = stats.mstats.mquantiles(df[xvariable].values, np.linspace(0.1, 1, nbins))
    ydata, xlab = [], []

    for edge_min, edge_max in zip(bin_edges[:-1], bin_edges[1:]):
        if type_bin == 'bin':
            df_plot = df[(df[xvariable] > edge_min) & (df[xvariable] < edge_max)]
        else:
            df_plot = df[(df[xvariable] < edge_max)]

        xlab.append((edge_max + edge_min) / 2)
        # An empty bin has no defined statistic; plot a gap instead of failing.
        ydata.append(metric_functions[metric](df_plot.zwerr) if len(df_plot) else np.nan)

    plt.plot(xlab, ydata, ls='-', marker='.', color='navy', lw=1, label='')
    plt.ylabel(rf'{metric}$[\Delta z]$', fontsize=18)
    plt.xlabel(f'{xvariable}', fontsize=16)

    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)

    plt.grid(False)

    plt.show()
51
+
insight/__pycache__/archive.cpython-310.pyc ADDED
Binary file (6.92 kB). View file
 
insight/__pycache__/archive.cpython-39.pyc ADDED
Binary file (6.2 kB). View file
 
insight/__pycache__/insight.cpython-310.pyc ADDED
Binary file (4.48 kB). View file
 
insight/__pycache__/insight.cpython-39.pyc ADDED
Binary file (4.4 kB). View file
 
insight/__pycache__/insight_arch.cpython-310.pyc ADDED
Binary file (1.63 kB). View file
 
insight/__pycache__/insight_arch.cpython-39.pyc ADDED
Binary file (1.58 kB). View file
 
insight/__pycache__/utils.cpython-310.pyc ADDED
Binary file (1.71 kB). View file
 
insight/__pycache__/utils.cpython-39.pyc ADDED
Binary file (1.67 kB). View file
 
insight/archive.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ from astropy.io import fits
4
+ import os
5
+ from astropy.table import Table
6
+ from scipy.spatial import KDTree
7
+
8
+ import matplotlib.pyplot as plt
9
+
10
+ from matplotlib import rcParams
11
+ rcParams["mathtext.fontset"] = "stix"
12
+ rcParams["font.family"] = "STIXGeneral"
13
+
14
+
15
class archive():
    """Load the calibration/validation catalogues and expose them as arrays.

    On construction the class reads the calibration and validation FITS
    tables (and the gold-sample CSV) found under ``path``, applies the
    requested selection cuts and stores photometry, photometric uncertainties
    and spectroscopic redshifts as numpy arrays. They are retrieved with
    ``get_training_data`` and ``get_testing_data``; the filtered DataFrames
    are kept as ``cat_train`` and ``cat_test``.
    """

    # Band suffixes used to build the flux, flux-error and extinction columns.
    _BANDS = ('G', 'R', 'I', 'Z', 'Y', 'J', 'H')

    def __init__(self, path, aperture=2, drop_stars=True, clean_photometry=True, convert_colors=True, extinction_corr=True, only_zspec=True, reliable_zspec=True):
        """Read the catalogues under ``path`` and build the train/test arrays.

        Parameters
        ----------
        path : str
            Directory containing the two FITS catalogues and the gold CSV.
        aperture : int or str
            Suffix of the ``FLUX_<band>_<aperture>`` columns to use.
        drop_stars : bool
            Keep only calibration rows with ``mu_class_L07 == 1``.
        clean_photometry : bool
            Keep only rows with ``FLAG_PHOT == 0`` (both catalogues).
        convert_colors : bool
            Store adjacent-band flux ratios instead of the raw fluxes.
        extinction_corr : bool
            Multiply the fluxes by the ``EB_V_corr_FLUX_<band>`` columns.
        only_zspec : bool
            Keep only rows with a positive ``z_spec_S15``.
        reliable_zspec : None, bool or str
            Reliability cut for the training sample, see
            ``_clean_zspec_sample``. The testing sample always uses 'Total'.
        """
        self.aperture = aperture

        # Inclusive (low, high) ranges of the Q_f_S15 flag -> loss weight.
        self.weight_dict = {(-99, 0.99): 0,
                            (1, 1.99): 0.5,
                            (2, 2.99): 0.75,
                            (3, 4): 1,
                            (9, 9.99): 0.25,
                            (10, 10.99): 0,
                            (11, 11.99): 0.5,
                            (12, 12.99): 0.75,
                            (13, 14): 1,
                            (14.01, 40): 0
                            }

        filename_calib = 'euclid_cosmos_DC2_S1_v2.1_calib_clean.fits'
        filename_valid = 'euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
        filename_gold = 'Export_Gold_2023_07_03.csv'

        cat = self._read_fits_table(os.path.join(path, filename_calib))
        cat_test = self._read_fits_table(os.path.join(path, filename_valid))

        # The gold sample is still read (a missing file fails here, as it
        # always did) but matching against it is currently disabled; enable
        # it with: cat_test = self._match_gold_sample(cat_test, gold_sample)
        gold_sample = pd.read_csv(os.path.join(path, filename_gold))

        if drop_stars:
            cat = cat[cat.mu_class_L07 == 1]

        if clean_photometry:
            cat = self._clean_photometry(cat)
            cat_test = self._clean_photometry(cat_test)

        # The cuts above return slices of the original frame; copy before
        # adding the weight column so pandas does not write into a view.
        cat = cat.copy()
        self._get_loss_weights(cat)

        # Objects with zero weight carry no information for the loss.
        cat = cat[cat.w_Q_f_S15 > 0]

        self._set_training_data(cat, only_zspec=only_zspec, reliable_zspec=reliable_zspec, extinction_corr=extinction_corr, convert_colors=convert_colors)

        self._set_testing_data(cat_test, only_zspec=only_zspec, reliable_zspec='Total', extinction_corr=extinction_corr, convert_colors=convert_colors)

    @staticmethod
    def _read_fits_table(filepath):
        """Return HDU 1 of a FITS file as a DataFrame, closing the file afterwards."""
        with fits.open(filepath) as hdu_list:
            return Table(hdu_list[1].data).to_pandas()

    def _extract_fluxes(self, catalogue):
        """Return ``(flux, flux_err)`` arrays with one column per band."""
        columns_f = [f'FLUX_{x}_{self.aperture}' for x in self._BANDS]
        columns_ferr = [f'FLUXERR_{x}_{self.aperture}' for x in self._BANDS]

        f = catalogue[columns_f].values
        ferr = catalogue[columns_ferr].values
        return f, ferr

    def _to_colors(self, flux, fluxerr):
        """Convert fluxes to adjacent-band flux ratios.

        Returns ``(color, color_err)`` where ``color[:, i] = flux[:, i] /
        flux[:, i + 1]`` and ``color_err`` is the first-order propagated
        *variance* of that ratio (no square root is taken):
        ``var = err_a**2 / b**2 + a**2 / b**4 * err_b**2``.
        """
        numerator, denominator = flux[:, :-1], flux[:, 1:]
        numerator_err, denominator_err = fluxerr[:, :-1], fluxerr[:, 1:]

        color = numerator / denominator
        # The second term must use the uncertainty of the denominator band.
        color_err = numerator_err**2 / denominator**2 + numerator**2 / denominator**4 * denominator_err**2
        return color, color_err

    def _clean_photometry(self, catalogue):
        """Drop every object with ``FLAG_PHOT != 0``."""
        catalogue = catalogue[catalogue['FLAG_PHOT'] == 0]

        return catalogue

    def _correct_extinction(self, catalogue, f):
        """Multiply the fluxes by the per-band ``EB_V_corr_FLUX_<band>`` factors."""
        ext_correction_cols = [f'EB_V_corr_FLUX_{x}' for x in self._BANDS]
        ext_correction = catalogue[ext_correction_cols].values

        f = f * ext_correction
        return f

    def _take_only_zspec(self, catalogue, cat_flag=None):
        """Keep only rows with ``z_spec_S15 > 0``.

        The cut is applied for ``cat_flag`` 'Calib' or 'Valid'; any other
        value returns the catalogue unchanged.
        """
        if cat_flag in ('Calib', 'Valid'):
            catalogue = catalogue[catalogue.z_spec_S15 > 0]
        return catalogue

    def _clean_zspec_sample(self, catalogue, kind=None):
        """Apply a spectroscopic-reliability cut.

        ``kind`` may be:
        * ``None`` or ``False`` - no cut;
        * ``'Total'`` or ``True`` - keep ``reliable_S15 > 0``;
        * ``'Partial'`` - keep ``w_Q_f_S15 > 0.5``.

        ``True`` is accepted because it is the constructor default; it used
        to fall through every branch and return ``None``.

        Raises
        ------
        ValueError
            If ``kind`` is none of the values above.
        """
        if kind is None or kind is False:
            return catalogue
        if kind is True or kind == 'Total':
            return catalogue[catalogue['reliable_S15'] > 0]
        if kind == 'Partial':
            return catalogue[(catalogue['w_Q_f_S15'] > 0.5)]
        raise ValueError(f"Unknown reliability selection {kind!r}; use None, 'Total' or 'Partial'")

    def _map_weight(self, Qz):
        """Return the loss weight for quality flag ``Qz``.

        Flags that fall in none of the ``weight_dict`` ranges (including
        NaN) get weight 0, so they are removed by the ``> 0`` cut instead of
        producing ``None`` entries.
        """
        for (low, high), weight in self.weight_dict.items():
            if low <= Qz <= high:
                return weight
        return 0

    def _get_loss_weights(self, catalogue):
        """Add the ``w_Q_f_S15`` weight column to ``catalogue`` in place."""
        catalogue['w_Q_f_S15'] = catalogue['Q_f_S15'].apply(self._map_weight)

    def _match_gold_sample(self, catalogue_valid, catalogue_gold, max_distance_arcsec=2):
        """Attach the nearest gold-sample redshift to each validation object.

        Adds a ``z_spec_gold`` column to ``catalogue_valid`` (in place) and
        returns it. Objects whose nearest gold neighbour is not closer than
        ``max_distance_arcsec`` get -99. The distance is the plain Euclidean
        distance in (RA, DEC) degrees; no cos(DEC) factor is applied.
        """
        max_distance_deg = max_distance_arcsec / 3600.0

        gold_sample_radec = np.c_[catalogue_gold.RIGHT_ASCENSION, catalogue_gold.DECLINATION]
        valid_sample_radec = np.c_[catalogue_valid['RA'], catalogue_valid['DEC']]

        kdtree = KDTree(gold_sample_radec)
        distances, indices = kdtree.query(valid_sample_radec, k=1)

        specz_match_gold = catalogue_gold.FINAL_SPEC_Z.values[indices]

        catalogue_valid['z_spec_gold'] = np.where(distances < max_distance_deg, specz_match_gold, -99)

        return catalogue_valid

    def _prepare_sample(self, catalogue, cat_flag, only_zspec, reliable_zspec, extinction_corr, convert_colors):
        """Apply the redshift cuts and return ``(catalogue, phot, photerr)``."""
        if only_zspec:
            catalogue = self._take_only_zspec(catalogue, cat_flag=cat_flag)
            catalogue = self._clean_zspec_sample(catalogue, kind=reliable_zspec)

        f, ferr = self._extract_fluxes(catalogue)

        if extinction_corr:
            f = self._correct_extinction(catalogue, f)

        if convert_colors:
            f, ferr = self._to_colors(f, ferr)

        return catalogue, f, ferr

    def _set_training_data(self, catalogue, only_zspec=True, reliable_zspec=True, extinction_corr=True, convert_colors=True):
        """Fill ``cat_train``, ``phot_train``, ``photerr_train``, ``target_z_train`` and ``target_qz_train``."""
        catalogue, phot, photerr = self._prepare_sample(catalogue, 'Calib', only_zspec, reliable_zspec, extinction_corr, convert_colors)

        self.cat_train = catalogue
        self.phot_train = phot
        self.photerr_train = photerr

        self.target_z_train = catalogue['z_spec_S15'].values
        self.target_qz_train = catalogue['w_Q_f_S15'].values

    def _set_testing_data(self, catalogue, only_zspec=True, reliable_zspec=True, extinction_corr=True, convert_colors=True):
        """Fill ``cat_test``, ``phot_test``, ``photerr_test`` and ``target_z_test``."""
        catalogue, phot, photerr = self._prepare_sample(catalogue, 'Valid', only_zspec, reliable_zspec, extinction_corr, convert_colors)

        self.cat_test = catalogue
        self.phot_test = phot
        self.photerr_test = photerr

        self.target_z_test = catalogue['z_spec_S15'].values

    def get_training_data(self):
        """Return ``(phot, photerr, z_spec, weight)`` of the training sample."""
        return self.phot_train, self.photerr_train, self.target_z_train, self.target_qz_train

    def get_testing_data(self):
        """Return ``(phot, photerr, z_spec)`` of the testing sample."""
        return self.phot_test, self.photerr_test, self.target_z_test

    def get_VIS_mag(self, catalogue):
        """Return the ``MAG_VIS`` column of ``catalogue`` as an (N, 1) array."""
        return catalogue[['MAG_VIS']].values

    def plot_zdistribution(self, plot_test=False, bins=50):
        """Plot the spec-z histogram of the training (and optionally test) sample."""
        # get_training_data returns four arrays; the redshift is the third.
        _, _, specz, _ = self.get_training_data()
        plt.hist(specz, bins=bins, histtype='step', color='navy', label=r'Training sample')

        if plot_test:
            _, _, specz_test = self.get_testing_data()
            plt.hist(specz_test, bins=bins, histtype='step', color='goldenrod', label=r'Test sample', ls='--')

        plt.xticks(fontsize=12)
        plt.yticks(fontsize=12)

        plt.xlabel(r'Redshift', fontsize=14)
        plt.ylabel('Counts', fontsize=14)

        # The histograms carry labels, so show them.
        plt.legend(fontsize=12)

        plt.show()
insight/insight.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import DataLoader, dataset, TensorDataset
3
+ from torch import nn, optim
4
+ from torch.optim import lr_scheduler
5
+ import numpy as np
6
+ import pandas as pd
7
+ from astropy.io import fits
8
+ import os
9
+ from astropy.table import Table
10
+ from scipy.spatial import KDTree
11
+
12
class Insight_module():
    """Train a mixture-density photo-z network and derive estimates from it.

    ``model`` must be a torch module whose forward pass returns
    ``(mu, logsig, logmix_coeff)``, each of shape (batch, n_components).
    """

    def __init__(self, model):
        """Store ``model`` and pick the GPU when one is available."""
        self.model = model
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _get_dataloaders(self, input_data, target_data, target_weights, val_fraction=0.1):
        """Split the data randomly and return ``(loader_train, loader_val)``.

        The validation set receives whatever is left after the training set,
        so the two lengths always add up to the dataset size (the previous
        ``int(...)`` / ``int(...) + 1`` pair could exceed it by one, which
        ``random_split`` rejects).
        """
        input_data = torch.Tensor(input_data)
        target_data = torch.Tensor(target_data)
        target_weights = torch.Tensor(target_weights)

        full_dataset = TensorDataset(input_data, target_data, target_weights)

        n_total = len(full_dataset)
        n_train = int(n_total * (1 - val_fraction))
        n_val = n_total - n_train

        training_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [n_train, n_val])
        loader_train = DataLoader(training_dataset, batch_size=64, shuffle=True)
        # Validation order does not matter, so no shuffling is needed.
        loader_val = DataLoader(val_dataset, batch_size=64, shuffle=False)

        return loader_train, loader_val

    def _loss_function(self, mean, std, logmix, true, target_weights):
        """Return the mean negative log-likelihood of a Gaussian mixture.

        The constant ``-0.5 * log(2 * pi)`` is omitted. ``target_weights`` is
        accepted but currently unused: the per-object weighting is disabled.
        """
        log_prob = logmix - 0.5 * (mean - true[:, None]).pow(2) / std.pow(2) - torch.log(std)

        log_prob = torch.logsumexp(log_prob, 1)

        #log_prob = log_prob * target_weights
        loss = -log_prob.mean()

        return loss

    def _to_numpy(self, x):
        """Detach a tensor and return it as a numpy array on the CPU."""
        return x.detach().cpu().numpy()

    def _as_tensor(self, data):
        """Return ``data`` on ``self.device``; non-tensors become float32 tensors."""
        if not torch.is_tensor(data):
            data = torch.as_tensor(data, dtype=torch.float32)
        return data.to(self.device)

    def _forward(self, input_data):
        """Run the model and return ``(mu, sig, logmix_coeff)``.

        The log-sigma output is clamped to [-6, 2] before exponentiation to
        keep the widths in a numerically safe range.
        """
        mu, logsig, logmix_coeff = self.model(input_data)
        logsig = torch.clamp(logsig, -6, 2)
        sig = torch.exp(logsig)
        return mu, sig, logmix_coeff

    def train(self, input_data, target_data, target_weights, nepochs=10, val_fraction=0.1, lr=1e-3):
        """Fit the model with Adam and record the per-batch losses.

        After the call ``self.loss_train`` and ``self.loss_validation`` hold
        one value per processed batch. The last training and validation
        batch loss of every epoch is printed.
        """
        loader_train, loader_val = self._get_dataloaders(input_data, target_data, target_weights, val_fraction=val_fraction)
        optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-4)

        self.model = self.model.to(self.device)

        loss_train, loss_validation = [], []

        for epoch in range(nepochs):
            self.model.train()
            last_train_loss = None
            for batch_input, batch_target, batch_weights in loader_train:
                batch_input = batch_input.to(self.device)
                batch_target = batch_target.to(self.device)
                batch_weights = batch_weights.to(self.device)

                optimizer.zero_grad()

                mu, sig, logmix_coeff = self._forward(batch_input)
                loss = self._loss_function(mu, sig, logmix_coeff, batch_target, batch_weights)

                loss.backward()
                optimizer.step()

                last_train_loss = loss.item()
                loss_train.append(last_train_loss)

            # Validation: dropout off, no gradient bookkeeping.
            self.model.eval()
            last_val_loss = None
            with torch.no_grad():
                for batch_input, batch_target, batch_weights in loader_val:
                    batch_input = batch_input.to(self.device)
                    batch_target = batch_target.to(self.device)
                    batch_weights = batch_weights.to(self.device)

                    mu, sig, logmix_coeff = self._forward(batch_input)
                    loss_val = self._loss_function(mu, sig, logmix_coeff, batch_target, batch_weights)

                    last_val_loss = loss_val.item()
                    loss_validation.append(last_val_loss)

            print(f'training_loss:{last_train_loss}', f'testing_loss:{last_val_loss}')

        # Leave the model in training mode, as this method always did.
        self.model.train()

        self.loss_train = loss_train
        self.loss_validation = loss_validation

    def get_photoz(self, input_data, target_data=None):
        """Return ``(z, zerr)`` numpy arrays for ``input_data``.

        ``z`` is the mixture mean. ``zerr`` is the square root of the
        weighted component variances plus the weighted squared distance of
        the component means from a reference redshift. When ``target_data``
        is given that reference is the target (the historical behaviour);
        when it is omitted the reference is the predicted mean ``z``, which
        needs no true redshift.

        Inputs may be tensors or anything ``torch.as_tensor`` accepts.
        """
        self.model = self.model.eval()
        self.model = self.model.to(self.device)

        input_data = self._as_tensor(input_data)

        # One forward pass over the whole input is enough.
        with torch.no_grad():
            mu, sig, logmix_coeff = self._forward(input_data)

        mix_coeff = torch.exp(logmix_coeff)

        z = (mix_coeff * mu).sum(1)
        reference = z if target_data is None else self._as_tensor(target_data)
        zerr = torch.sqrt((mix_coeff * sig**2).sum(1) + (mix_coeff * (mu - reference[:, None])**2).sum(1))

        return self._to_numpy(z), self._to_numpy(zerr)

    def plot_photoz(self, df, nbins, xvariable, metric, type_bin='bin'):
        """Plot a photo-z metric of ``df.zwerr`` in quantile bins of ``df[xvariable]``.

        Parameters
        ----------
        df : pandas.DataFrame
            Must contain the columns ``xvariable`` and ``zwerr``.
        nbins : int
            Number of quantile edges (taken between the 10% and 100% quantiles).
        xvariable : str
            Column to bin on.
        metric : {'sig68', 'bias', 'nmad', 'outliers'}
            Statistic of ``zwerr`` computed in every bin.
        type_bin : {'bin', 'cum'}
            'bin' uses objects between two edges, 'cum' all objects below
            the upper edge.

        Raises
        ------
        ValueError
            For an unsupported ``type_bin`` or ``metric``.
        """
        # Imported here because this module does not import the plotting
        # stack at file level.
        import matplotlib.pyplot as plt
        from scipy import stats

        metric_functions = {
            'sig68': lambda dz: 0.5 * (dz.quantile(q=0.84) - dz.quantile(q=0.16)),
            'bias': np.mean,
            'nmad': lambda dz: 1.4826 * np.median(np.abs(dz - np.median(dz))),
            'outliers': lambda dz: np.mean(np.abs(dz) > 0.15),
        }

        if type_bin not in ('bin', 'cum'):
            raise ValueError("Only type_bin=='bin' for binned and 'cum' for cumulative are supported")
        if metric not in metric_functions:
            raise ValueError(f"Unsupported metric {metric!r}; choose one of {sorted(metric_functions)}")

        bin_edges = stats.mstats.mquantiles(df[xvariable].values, np.linspace(0.1, 1, nbins))
        ydata, xlab = [], []

        for edge_min, edge_max in zip(bin_edges[:-1], bin_edges[1:]):
            if type_bin == 'bin':
                df_plot = df[(df[xvariable] > edge_min) & (df[xvariable] < edge_max)]
            else:
                df_plot = df[(df[xvariable] < edge_max)]

            xlab.append((edge_max + edge_min) / 2)
            # An empty bin has no defined statistic; plot a gap instead of failing.
            ydata.append(metric_functions[metric](df_plot.zwerr) if len(df_plot) else np.nan)

        plt.plot(xlab, ydata, ls='-', marker='.', color='navy', lw=1, label='')
        plt.ylabel(rf'{metric}$[\Delta z]$', fontsize=18)
        plt.xlabel(f'{xvariable}', fontsize=16)

        plt.xticks(fontsize=14)
        plt.yticks(fontsize=14)

        plt.grid(False)

        plt.show()
163
+
164
+
165
+
166
+
insight/insight_arch.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn, optim
2
+ import torch
3
class Photoz_network(nn.Module):
    """Mixture-density network with a shared trunk and three output heads.

    The trunk maps ``input_dim`` features to 100 hidden units; each head maps
    those to ``num_gauss`` values: component means, component log-widths and
    normalised log mixture coefficients.
    """

    def __init__(self, num_gauss=10, dropout_prob=0, input_dim=6):
        """Build the network.

        Parameters
        ----------
        num_gauss : int
            Number of mixture components (outputs per head).
        dropout_prob : float
            Dropout probability applied after every hidden linear layer.
        input_dim : int
            Number of input features. The default of 6 is the previously
            hard-coded width.
        """
        super(Photoz_network, self).__init__()

        head_widths = [100, 80, 70, 60, 50, num_gauss]

        # Creation order is kept (features, mu, coeffs, sigma) so parameter
        # initialisation under a fixed seed is unchanged.
        self.features = self._make_mlp([input_dim, 10, 30, 50, 70, 100], dropout_prob)
        self.measure_mu = self._make_mlp(head_widths, dropout_prob)
        self.measure_coeffs = self._make_mlp(head_widths, dropout_prob)
        self.measure_sigma = self._make_mlp(head_widths, dropout_prob)

    @staticmethod
    def _make_mlp(widths, dropout_prob):
        """Return Linear -> Dropout -> ReLU blocks followed by a final Linear.

        The layer order matches the original hand-written stacks, so the
        ``state_dict`` keys are unchanged.
        """
        layers = []
        for n_in, n_out in zip(widths[:-2], widths[1:-1]):
            layers += [nn.Linear(n_in, n_out), nn.Dropout(dropout_prob), nn.ReLU()]
        layers.append(nn.Linear(widths[-2], widths[-1]))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``(mu, logsig, logmix_coeff)`` for a batch ``x``.

        The second output is the raw head output; ``Insight_module`` clamps
        and exponentiates it, i.e. treats it as a log-width. The mixture
        coefficients are normalised in log space so their exponentials sum
        to one per row.
        """
        f = self.features(x)
        mu = self.measure_mu(f)
        logsig = self.measure_sigma(f)
        logmix_coeff = self.measure_coeffs(f)

        logmix_coeff = logmix_coeff - torch.logsumexp(logmix_coeff, 1)[:, None]

        return mu, logsig, logmix_coeff
80
+
81
+
insight/utils.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ import matplotlib.pyplot as plt
4
+ from scipy import stats
5
+
6
def nmad(data):
    """Return the normalised median absolute deviation of ``data``.

    This is 1.4826 times the median of the absolute deviations from the
    median.
    """
    centre = np.median(data)
    absolute_deviation = np.abs(data - centre)
    return 1.4826 * np.median(absolute_deviation)
8
+
9
def sigma68(data):
    """Return half the distance between the 16% and 84% quantiles of ``data``."""
    series = pd.Series(data)
    upper = series.quantile(q=0.84)
    lower = series.quantile(q=0.16)
    return 0.5 * (upper - lower)
10
+
11
+
12
def plot_photoz(df, nbins, xvariable, metric, type_bin='bin'):
    """Plot a photo-z metric of ``df.zwerr`` in quantile bins of ``df[xvariable]``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``xvariable`` and ``zwerr``.
    nbins : int
        Number of quantile edges (taken between the 10% and 100% quantiles).
    xvariable : str
        Column to bin on. The selection now uses this column of ``df``; it
        previously referenced an undefined ``df_test`` and a fixed ``imag``
        column.
    metric : {'sig68', 'bias', 'nmad', 'outliers'}
        Statistic of ``zwerr`` computed in every bin.
    type_bin : {'bin', 'cum'}
        'bin' uses objects between two edges, 'cum' all objects below the
        upper edge.

    Raises
    ------
    ValueError
        For an unsupported ``type_bin`` or ``metric``.
    """
    metric_functions = {
        'sig68': sigma68,
        'bias': np.mean,
        'nmad': nmad,
        'outliers': lambda dz: np.mean(np.abs(dz) > 0.15),
    }

    if type_bin not in ('bin', 'cum'):
        raise ValueError("Only type_bin=='bin' for binned and 'cum' for cumulative are supported")
    if metric not in metric_functions:
        raise ValueError(f"Unsupported metric {metric!r}; choose one of {sorted(metric_functions)}")

    bin_edges = stats.mstats.mquantiles(df[xvariable].values, np.linspace(0.1, 1, nbins))
    ydata, xlab = [], []

    for edge_min, edge_max in zip(bin_edges[:-1], bin_edges[1:]):
        if type_bin == 'bin':
            df_plot = df[(df[xvariable] > edge_min) & (df[xvariable] < edge_max)]
        else:
            df_plot = df[(df[xvariable] < edge_max)]

        xlab.append((edge_max + edge_min) / 2)
        # An empty bin has no defined statistic; plot a gap instead of failing.
        ydata.append(metric_functions[metric](df_plot.zwerr) if len(df_plot) else np.nan)

    plt.plot(xlab, ydata, ls='-', marker='.', color='navy', lw=1, label='')
    plt.ylabel(rf'{metric}$[\Delta z]$', fontsize=18)
    plt.xlabel(f'{xvariable}', fontsize=16)

    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)

    plt.grid(False)

    plt.show()
51
+