lauracabayol committed on
Commit
a57776c
·
1 Parent(s): e50129b

changes to enable training with any catalog

Browse files
Files changed (4) hide show
  1. notebooks/NMAD.py +10 -5
  2. pyproject.toml +1 -0
  3. temps/archive.py +57 -115
  4. temps/temps.py +3 -1
notebooks/NMAD.py CHANGED
@@ -33,6 +33,7 @@ import os
33
  from astropy.io import fits
34
  from astropy.table import Table
35
  import torch
 
36
 
37
  # %%
38
  #matplotlib settings
@@ -63,6 +64,8 @@ eval_methods=True
63
  #define here the directory containing the photometric catalogues
64
  parent_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')
65
  modules_dir = Path('../data/models/')
 
 
66
 
67
  # %%
68
  filename_valid='euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
@@ -83,9 +86,11 @@ VISmag = cat['MAG_VIS']
83
  zsflag = cat['reliable_S15']
84
 
85
  # %%
86
- photoz_archive = Archive(path = parent_dir,only_zspec=False)
87
- f, ferr = photoz_archive._extract_fluxes(catalogue= cat)
88
- col, colerr = photoz_archive._to_colors(f, ferr)
 
 
89
 
90
  # %% [markdown]
91
  # ### EVALUATE USING TRAINED MODELS
@@ -97,9 +102,9 @@ if eval_methods:
97
  for il, lab in enumerate(['z','L15','DA']):
98
 
99
  nn_features = EncoderPhotometry()
100
- nn_features.load_state_dict(modules_dir / f'modelF_{lab}.pt',map_location=torch.device('cpu')))
101
  nn_z = MeasureZ(num_gauss=6)
102
- nn_z.load_state_dict(modules_dir / f'modelZ_{lab}.pt',map_location=torch.device('cpu')))
103
 
104
  temps_module = TempsModule(nn_features, nn_z)
105
 
 
33
  from astropy.io import fits
34
  from astropy.table import Table
35
  import torch
36
+ from pathlib import Path
37
 
38
  # %%
39
  #matplotlib settings
 
64
  #define here the directory containing the photometric catalogues
65
  parent_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')
66
  modules_dir = Path('../data/models/')
67
+ filename_calib = 'euclid_cosmos_DC2_S1_v2.1_calib_clean.fits'
68
+ filename_valid = 'euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
69
 
70
  # %%
71
  filename_valid='euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
 
86
  zsflag = cat['reliable_S15']
87
 
88
  # %%
89
+ photoz_archive = Archive(path_calib = parent_dir/filename_calib,
90
+ path_valid = parent_dir/filename_valid,
91
+ only_zspec=False)
92
+ f = photoz_archive._extract_fluxes(catalogue= cat)
93
+ col = photoz_archive._to_colors(f)
94
 
95
  # %% [markdown]
96
  # ### EVALUATE USING TRAINED MODELS
 
102
  for il, lab in enumerate(['z','L15','DA']):
103
 
104
  nn_features = EncoderPhotometry()
105
+ nn_features.load_state_dict(torch.load(modules_dir / f'modelF_{lab}.pt',map_location=torch.device('cpu')))
106
  nn_z = MeasureZ(num_gauss=6)
107
+ nn_z.load_state_dict(torch.load(modules_dir / f'modelZ_{lab}.pt',map_location=torch.device('cpu')))
108
 
109
  temps_module = TempsModule(nn_features, nn_z)
110
 
pyproject.toml CHANGED
@@ -28,6 +28,7 @@ dependencies = [
28
  "pathlib",
29
  "astropy",
30
  "gradio",
 
31
  ]
32
 
33
  classifiers = [
 
28
  "pathlib",
29
  "astropy",
30
  "gradio",
31
+ "jupytext"
32
  ]
33
 
34
  classifiers = [
temps/archive.py CHANGED
@@ -1,6 +1,7 @@
1
  import numpy as np
2
  import pandas as pd
3
  from astropy.io import fits
 
4
  from scipy.spatial import KDTree
5
  from matplotlib import pyplot as plt
6
  from matplotlib import rcParams
@@ -12,37 +13,47 @@ rcParams["mathtext.fontset"] = "stix"
12
  rcParams["font.family"] = "STIXGeneral"
13
 
14
  class Archive:
15
- def __init__(self, path,
16
- aperture=2,
 
17
  drop_stars=True,
18
  clean_photometry=True,
19
  convert_colors=True,
20
  extinction_corr=True,
21
  only_zspec=True,
22
- all_apertures=False,
23
- target_test='specz', flags_kept=[3, 3.1, 3.4, 3.5, 4]):
 
 
 
 
24
 
25
-
26
  logger.info("Starting archive")
27
- self.aperture = aperture
28
- self.all_apertures = all_apertures
29
  self.flags_kept = flags_kept
 
 
 
30
 
31
- filename_calib = 'euclid_cosmos_DC2_S1_v2.1_calib_clean.fits'
32
- filename_valid = 'euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
33
-
34
- # Use Path for file handling
35
- path_calib = Path(path) / filename_calib
36
- path_valid = Path(path) / filename_valid
37
-
38
- # Open the calibration FITS file
39
- with fits.open(path_calib) as hdu_list:
40
- cat = Table(hdu_list[1].data).to_pandas()
41
- cat = cat[(cat['z_spec_S15'] > 0) | (cat['photo_z_L15'] > 0)]
 
 
 
 
 
 
 
42
 
43
- # Open the validation FITS file
44
- with fits.open(path_valid) as hdu_list:
45
- cat_test = Table(hdu_list[1].data).to_pandas()
46
 
47
  # Store the catalogs for later use
48
  self.cat = cat
@@ -85,57 +96,18 @@ class Archive:
85
 
86
 
87
  def _extract_fluxes(self,catalogue):
88
- if self.all_apertures:
89
- columns_f = [f'FLUX_{x}_{a}' for a in [1,2,3] for x in ['G','R','I','Z','Y','J','H']]
90
- columns_ferr = [f'FLUXERR_{x}_{a}' for a in [1,2,3] for x in ['G','R','I','Z','Y','J','H'] ]
91
- else:
92
- columns_f = [f'FLUX_{x}_{self.aperture}' for x in ['G','R','I','Z','Y','J','H']]
93
- columns_ferr = [f'FLUXERR_{x}_{self.aperture}' for x in ['G','R','I','Z','Y','J','H']]
94
-
95
- f = catalogue[columns_f].values
96
- ferr = catalogue[columns_ferr].values
97
- return f, ferr
98
-
99
- def _extract_magnitudes(self,catalogue):
100
- if self.all_apertures:
101
- columns_m = [f'MAG_{x}_{a}' for a in [1,2,3] for x in ['G','R','I','Z','Y','J','H']]
102
- columns_merr = [f'MAGERR_{x}_{a}' for a in [1,2,3] for x in ['G','R','I','Z','Y','J','H'] ]
103
- else:
104
- columns_m = [f'MAG_{x}_{self.aperture}' for x in ['G','R','I','Z','Y','J','H']]
105
- columns_merr = [f'MAGERR_{x}_{self.aperture}' for x in ['G','R','I','Z','Y','J','H']]
106
 
107
- m = catalogue[columns_m].values
108
- merr = catalogue[columns_merr].values
109
- return m, merr
110
-
111
- def _to_colors(self, flux, fluxerr):
112
  """ Convert fluxes to colors"""
113
-
114
- if self.all_apertures:
115
-
116
- for a in range(3):
117
- lim1 = 7*a
118
- lim2 = 7*(a+1)
119
- c = flux[:,lim1:(lim2-1)] / flux[:,(lim1+1):lim2]
120
- cerr = np.sqrt((fluxerr[:,lim1:(lim2-1)]/ flux[:,(lim1+1):lim2])**2 + (flux[:,lim1:(lim2-1)] / flux[:,(lim1+1):lim2]**2)**2 * fluxerr[:,(lim1+1):lim2]**2)
121
-
122
- if a==0:
123
- color = c
124
- color_err = cerr
125
- else:
126
- color = np.concatenate((color,c),axis=1)
127
- color_err = np.concatenate((color_err,cerr),axis=1)
128
-
129
- else:
130
- color = flux[:,:-1] / flux[:,1:]
131
-
132
- color_err = np.sqrt((fluxerr[:,:-1]/ flux[:,1:])**2 + (flux[:,:-1] / flux[:,1:]**2)**2 * fluxerr[:,1:]**2)
133
- return color,color_err
134
 
135
  def _set_combiend_target(self, catalogue):
136
- catalogue['target_z'] = catalogue.apply(lambda row: row['z_spec_S15']
137
- if row['z_spec_S15'] > 0
138
- else row['photo_z_L15'], axis=1)
139
 
140
  return catalogue
141
 
@@ -148,13 +120,7 @@ class Archive:
148
 
149
  def _correct_extinction(self,catalogue, f, return_ext_corr=False):
150
  """Corrects for extinction"""
151
- ext_correction_cols = [f'EB_V_corr_FLUX_{x}' for x in ['G','R','I','Z','Y','J','H']]
152
- if self.all_apertures:
153
- ext_correction = catalogue[ext_correction_cols].values
154
- ext_correction = np.concatenate((ext_correction,ext_correction,ext_correction),axis=1)
155
- else:
156
- ext_correction = catalogue[ext_correction_cols].values
157
-
158
  f = f * ext_correction
159
  if return_ext_corr:
160
  return f, ext_correction
@@ -164,14 +130,14 @@ class Archive:
164
  def _select_only_zspec(self,catalogue,cat_flag=None):
165
  """Selects only galaxies with spectroscopic redshift"""
166
  if cat_flag=='Calib':
167
- catalogue = catalogue[catalogue.z_spec_S15>0]
168
  elif cat_flag=='Valid':
169
- catalogue = catalogue[catalogue.z_spec_S15>0]
170
  return catalogue
171
 
172
  def _exclude_only_zspec(self,catalogue):
173
  """Selects only galaxies without spectroscopic redshift"""
174
- catalogue = catalogue[(catalogue.z_spec_S15<0)&(catalogue.photo_z_L15>0)&(catalogue.photo_z_L15<4)]
175
  return catalogue
176
 
177
  def _select_L15_sample(self,catalogue):
@@ -187,7 +153,7 @@ class Archive:
187
  if cat_flag=='Calib':
188
  catalogue = catalogue[catalogue.target_z>0]
189
  elif cat_flag=='Valid':
190
- catalogue = catalogue[catalogue.z_spec_S15>0]
191
  return catalogue
192
 
193
  def _clean_zspec_sample(self,catalogue ,flags_kept=[3,3.1,3.4,3.5,4]):
@@ -222,7 +188,7 @@ class Archive:
222
  def _set_training_data(self,catalogue, catalogue_da, only_zspec=True, extinction_corr=True, convert_colors=True):
223
 
224
  cat_da = self._exclude_only_zspec(catalogue_da)
225
- target_z_train_DA = cat_da['photo_z_L15'].values
226
 
227
 
228
  if only_zspec:
@@ -235,11 +201,10 @@ class Archive:
235
 
236
 
237
  self.cat_train=catalogue
238
- f, ferr = self._extract_fluxes(catalogue)
239
-
240
- f_DA, ferr_DA = self._extract_fluxes(cat_da)
241
  idx = np.random.randint(0, len(f_DA), len(f))
242
- f_DA, ferr_DA = f_DA[idx], ferr_DA[idx]
243
  target_z_train_DA = target_z_train_DA[idx]
244
  self.target_z_train_DA = target_z_train_DA
245
 
@@ -250,21 +215,17 @@ class Archive:
250
 
251
  if convert_colors==True:
252
  logger.info("Converting to colors")
253
- col, colerr = self._to_colors(f, ferr)
254
- col_DA, colerr_DA = self._to_colors(f_DA, ferr_DA)
255
 
256
  self.phot_train = col
257
- self.photerr_train = colerr
258
  self.phot_train_DA = col_DA
259
- self.photerr_train_DA = colerr_DA
260
  else:
261
  self.phot_train = f
262
- self.photerr_train = ferr
263
  self.phot_train_DA = f_DA
264
- self.photerr_train_DA = ferr_DA
265
 
266
  if only_zspec==True:
267
- self.target_z_train = catalogue['z_spec_S15'].values
268
  else:
269
  self.target_z_train = catalogue['target_z'].values
270
 
@@ -275,7 +236,7 @@ class Archive:
275
  if target=='specz':
276
  catalogue = self._select_only_zspec(catalogue, cat_flag='Valid')
277
  catalogue = self._clean_zspec_sample(catalogue)
278
- self.target_z_test = catalogue['z_spec_S15'].values
279
 
280
  elif target=='L15':
281
  catalogue = self._select_L15_sample(catalogue)
@@ -284,45 +245,26 @@ class Archive:
284
 
285
  self.cat_test=catalogue
286
 
287
- f, ferr = self._extract_fluxes(catalogue)
288
 
289
  if extinction_corr==True:
290
  f = self._correct_extinction(catalogue,f)
291
 
292
  if convert_colors==True:
293
- col, colerr = self._to_colors(f, ferr)
294
  self.phot_test = col
295
- self.photerr_test = colerr
296
  else:
297
  self.phot_test = f
298
- self.photerr_test = ferr
299
 
300
 
301
  self.VIS_mag_test = catalogue['MAG_VIS'].values
302
 
303
 
304
  def get_training_data(self):
305
- return self.phot_train, self.photerr_train, self.target_z_train, self.VIS_mag_train, self.phot_train_DA, self.photerr_train_DA, self.target_z_train_DA
306
 
307
  def get_testing_data(self):
308
- return self.phot_test, self.photerr_test, self.target_z_test, self.VIS_mag_test
309
 
310
  def get_VIS_mag(self, catalogue):
311
  return catalogue[['MAG_VIS']].values
312
-
313
- def plot_zdistribution(self, plot_test=False, bins=50):
314
- _,_,specz = photoz_archive.get_training_data()
315
- plt.hist(specz, bins = bins, hisstype='step', color='navy', label=r'Training sample')
316
-
317
- if plot_test:
318
- _,_,specz_test = photoz_archive.get_training_data()
319
- plt.hist(specz, bins = bins, hisstype='step', color='goldenrod', label=r'Test sample',ls='--')
320
-
321
-
322
- plt.xticks(fontsize=12)
323
- plt.yticks(fontsize=12)
324
-
325
- plt.xlabel(r'Redshift', fontsize=14)
326
- plt.ylabel('Counts', fontsize=14)
327
-
328
- plt.show()
 
1
  import numpy as np
2
  import pandas as pd
3
  from astropy.io import fits
4
+ from astropy.table import Table
5
  from scipy.spatial import KDTree
6
  from matplotlib import pyplot as plt
7
  from matplotlib import rcParams
 
13
  rcParams["font.family"] = "STIXGeneral"
14
 
15
  class Archive:
16
+ def __init__(self,
17
+ path_calib,
18
+ path_valid=None,
19
  drop_stars=True,
20
  clean_photometry=True,
21
  convert_colors=True,
22
  extinction_corr=True,
23
  only_zspec=True,
24
+ columns_photometry = ['FLUX_G_2','FLUX_R_2','FLUX_I_2','FLUX_Z_2','FLUX_Y_2','FLUX_J_2','FLUX_H_2'],
25
+ columns_ebv = ['EB_V_corr_FLUX_G','EB_V_corr_FLUX_R','EB_V_corr_FLUX_I','EB_V_corr_FLUX_Z','EB_V_corr_FLUX_Y','EB_V_corr_FLUX_J','EB_V_corr_FLUX_H'],
26
+ photoz_name="photo_z_L15",
27
+ specz_name="z_spec_S15",
28
+ target_test='specz',
29
+ flags_kept=[3, 3.1, 3.4, 3.5, 4]):
30
 
 
31
  logger.info("Starting archive")
 
 
32
  self.flags_kept = flags_kept
33
+ self.columns_photometry=columns_photometry
34
+ self.columns_ebv=columns_ebv
35
+
36
 
37
+ if path_calib.suffix == ".fits":
38
+ with fits.open(path_calib) as hdu_list:
39
+ cat = Table(hdu_list[1].data).to_pandas()
40
+ if path_valid != None:
41
+ with fits.open(path_valid) as hdu_list:
42
+ cat_test = Table(hdu_list[1].data).to_pandas()
43
+
44
+ elif path_calib.suffix == ".csv":
45
+ cat = pd.read_csv(path_calib)
46
+ if path_valid != None:
47
+ cat_test = pd.read_csv(path_valid)
48
+ else:
49
+ raise ValueError("Unsupported file format. Please provide a .fits or .csv file.")
50
+
51
+ cat = cat.rename(columns ={f"{specz_name}":"specz",
52
+ f"{photoz_name}":"photo_z"})
53
+ cat_test = cat_test.rename(columns ={f"{specz_name}":"specz",
54
+ f"{photoz_name}":"photo_z"})
55
 
56
+ cat = cat[(cat['specz'] > 0) | (cat['photo_z'] > 0)]
 
 
57
 
58
  # Store the catalogs for later use
59
  self.cat = cat
 
96
 
97
 
98
  def _extract_fluxes(self,catalogue):
99
+ f = catalogue[self.columns_photometry].values
100
+ return f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
+ def _to_colors(self, flux):
 
 
 
 
103
  """ Convert fluxes to colors"""
104
+ color = flux[:,:-1] / flux[:,1:]
105
+ return color
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
  def _set_combiend_target(self, catalogue):
108
+ catalogue['target_z'] = catalogue.apply(lambda row: row['specz']
109
+ if row['specz'] > 0
110
+ else row['photo_z'], axis=1)
111
 
112
  return catalogue
113
 
 
120
 
121
  def _correct_extinction(self,catalogue, f, return_ext_corr=False):
122
  """Corrects for extinction"""
123
+ ext_correction = catalogue[self.columns_ebv].values
 
 
 
 
 
 
124
  f = f * ext_correction
125
  if return_ext_corr:
126
  return f, ext_correction
 
130
  def _select_only_zspec(self,catalogue,cat_flag=None):
131
  """Selects only galaxies with spectroscopic redshift"""
132
  if cat_flag=='Calib':
133
+ catalogue = catalogue[catalogue.specz>0]
134
  elif cat_flag=='Valid':
135
+ catalogue = catalogue[catalogue.specz>0]
136
  return catalogue
137
 
138
  def _exclude_only_zspec(self,catalogue):
139
  """Selects only galaxies without spectroscopic redshift"""
140
+ catalogue = catalogue[(catalogue.specz<0)&(catalogue.photo_z>0)&(catalogue.photo_z<4)]
141
  return catalogue
142
 
143
  def _select_L15_sample(self,catalogue):
 
153
  if cat_flag=='Calib':
154
  catalogue = catalogue[catalogue.target_z>0]
155
  elif cat_flag=='Valid':
156
+ catalogue = catalogue[catalogue.specz>0]
157
  return catalogue
158
 
159
  def _clean_zspec_sample(self,catalogue ,flags_kept=[3,3.1,3.4,3.5,4]):
 
188
  def _set_training_data(self,catalogue, catalogue_da, only_zspec=True, extinction_corr=True, convert_colors=True):
189
 
190
  cat_da = self._exclude_only_zspec(catalogue_da)
191
+ target_z_train_DA = cat_da['photo_z'].values
192
 
193
 
194
  if only_zspec:
 
201
 
202
 
203
  self.cat_train=catalogue
204
+ f = self._extract_fluxes(catalogue)
205
+ f_DA = self._extract_fluxes(cat_da)
 
206
  idx = np.random.randint(0, len(f_DA), len(f))
207
+ f_DA = f_DA[idx]
208
  target_z_train_DA = target_z_train_DA[idx]
209
  self.target_z_train_DA = target_z_train_DA
210
 
 
215
 
216
  if convert_colors==True:
217
  logger.info("Converting to colors")
218
+ col = self._to_colors(f)
219
+ col_DA = self._to_colors(f_DA)
220
 
221
  self.phot_train = col
 
222
  self.phot_train_DA = col_DA
 
223
  else:
224
  self.phot_train = f
 
225
  self.phot_train_DA = f_DA
 
226
 
227
  if only_zspec==True:
228
+ self.target_z_train = catalogue['specz'].values
229
  else:
230
  self.target_z_train = catalogue['target_z'].values
231
 
 
236
  if target=='specz':
237
  catalogue = self._select_only_zspec(catalogue, cat_flag='Valid')
238
  catalogue = self._clean_zspec_sample(catalogue)
239
+ self.target_z_test = catalogue['specz'].values
240
 
241
  elif target=='L15':
242
  catalogue = self._select_L15_sample(catalogue)
 
245
 
246
  self.cat_test=catalogue
247
 
248
+ f = self._extract_fluxes(catalogue)
249
 
250
  if extinction_corr==True:
251
  f = self._correct_extinction(catalogue,f)
252
 
253
  if convert_colors==True:
254
+ col = self._to_colors(f)
255
  self.phot_test = col
 
256
  else:
257
  self.phot_test = f
 
258
 
259
 
260
  self.VIS_mag_test = catalogue['MAG_VIS'].values
261
 
262
 
263
  def get_training_data(self):
264
+ return self.phot_train, self.target_z_train, self.VIS_mag_train, self.phot_train_DA, self.target_z_train_DA
265
 
266
  def get_testing_data(self):
267
+ return self.phot_test, self.target_z_test, self.VIS_mag_test
268
 
269
  def get_VIS_mag(self, catalogue):
270
  return catalogue[['MAG_VIS']].values
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
temps/temps.py CHANGED
@@ -5,6 +5,8 @@ from torch.utils.data import DataLoader, TensorDataset
5
  from torch.optim import lr_scheduler
6
  from loguru import logger
7
  import pandas as pd
 
 
8
  from tqdm import tqdm # Import tqdm for progress bars
9
 
10
  from temps.utils import maximum_mean_discrepancy
@@ -47,7 +49,7 @@ class TempsModule:
47
  dataset = TensorDataset(input_data, input_data_da, target_data)
48
  train_dataset, val_dataset = torch.utils.data.random_split(
49
  dataset,
50
- [int(len(dataset) * (1 - val_fraction)), int(len(dataset) * val_fraction)],
51
  )
52
  loader_train = DataLoader(
53
  train_dataset, batch_size=self.batch_size, shuffle=True
 
5
  from torch.optim import lr_scheduler
6
  from loguru import logger
7
  import pandas as pd
8
+ from scipy.stats import norm
9
+
10
  from tqdm import tqdm # Import tqdm for progress bars
11
 
12
  from temps.utils import maximum_mean_discrepancy
 
49
  dataset = TensorDataset(input_data, input_data_da, target_data)
50
  train_dataset, val_dataset = torch.utils.data.random_split(
51
  dataset,
52
+ [int(len(dataset) * (1 - val_fraction)), int(len(dataset) * val_fraction)+1],
53
  )
54
  loader_train = DataLoader(
55
  train_dataset, batch_size=self.batch_size, shuffle=True