lauracabayol committed
Commit 546e741 (unverified) · 2 Parent(s): f313d2c c9354dd

Merge pull request #3 from lauracabayol/improve_code

Files changed (7):
  1. notebooks/NMAD.py +3 -1
  2. pyproject.toml +1 -0
  3. temps/archive.py +230 -158
  4. temps/plots.py +260 -225
  5. temps/temps.py +207 -151
  6. temps/temps_arch.py +59 -6
  7. temps/utils.py +165 -40
notebooks/NMAD.py CHANGED
@@ -61,6 +61,7 @@ eval_methods=True
  # ### LOAD DATA
  
  # %%
+ 
  #define here the directory containing the photometric catalogues
  parent_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')
  modules_dir = Path('../data/models/')
@@ -68,7 +69,6 @@ filename_calib = 'euclid_cosmos_DC2_S1_v2.1_calib_clean.fits'
  filename_valid = 'euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
  
  # %%
- filename_valid='euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
  path_file = parent_dir / filename_valid  # Creating the path to the file
  hdu_list = fits.open(path_file)
  cat = Table(hdu_list[1].data).to_pandas()
@@ -158,3 +158,5 @@ plot_photoz(df_list,
  save=False,
  samp='L15'
  )
+ 
+ # %%
pyproject.toml CHANGED
@@ -30,6 +30,7 @@ dependencies = [
  "gradio",
  "jupytext",
  "mkdocs",
+ "typing"
  ]
  
  classifiers = [
temps/archive.py CHANGED
@@ -1,3 +1,4 @@
+ from dataclasses import dataclass, field
  import numpy as np
  import pandas as pd
  from astropy.io import fits
@@ -5,199 +6,263 @@ from astropy.table import Table
  from scipy.spatial import KDTree
  from matplotlib import pyplot as plt
  from matplotlib import rcParams
- from pathlib import Path
+ from pathlib import Path
  from loguru import logger
+ from typing import Optional, Tuple, Union, List
  
+ # Set matplotlib configuration
  rcParams["mathtext.fontset"] = "stix"
  rcParams["font.family"] = "STIXGeneral"
  
+ @dataclass
  class Archive:
-     def __init__(self,
-                  path_calib,
-                  path_valid=None,
-                  drop_stars=True,
-                  clean_photometry=True,
-                  convert_colors=True,
-                  extinction_corr=True,
-                  only_zspec=True,
-                  columns_photometry=['FLUX_G_2','FLUX_R_2','FLUX_I_2','FLUX_Z_2','FLUX_Y_2','FLUX_J_2','FLUX_H_2'],
-                  columns_ebv=['EB_V_corr_FLUX_G','EB_V_corr_FLUX_R','EB_V_corr_FLUX_I','EB_V_corr_FLUX_Z','EB_V_corr_FLUX_Y','EB_V_corr_FLUX_J','EB_V_corr_FLUX_H'],
-                  photoz_name="photo_z_L15",
-                  specz_name="z_spec_S15",
-                  target_test='specz',
-                  flags_kept=[3, 3.1, 3.4, 3.5, 4]):
+     path_calib: Path
+     path_valid: Optional[Path] = None
+     drop_stars: bool = True
+     clean_photometry: bool = True
+     convert_colors: bool = True
+     extinction_corr: bool = True
+     only_zspec: bool = True
+     columns_photometry: List[str] = field(default_factory=lambda: [
+         "FLUX_G_2",
+         "FLUX_R_2",
+         "FLUX_I_2",
+         "FLUX_Z_2",
+         "FLUX_Y_2",
+         "FLUX_J_2",
+         "FLUX_H_2",
+     ])
+     columns_ebv: List[str] = field(default_factory=lambda: [
+         "EB_V_corr_FLUX_G",
+         "EB_V_corr_FLUX_R",
+         "EB_V_corr_FLUX_I",
+         "EB_V_corr_FLUX_Z",
+         "EB_V_corr_FLUX_Y",
+         "EB_V_corr_FLUX_J",
+         "EB_V_corr_FLUX_H",
+     ])
+     photoz_name: str = "photo_z_L15"
+     specz_name: str = "z_spec_S15"
+     target_test: str = "specz"
+     flags_kept: List[float] = field(default_factory=lambda: [3, 3.1, 3.4, 3.5, 4])
  
+     def __post_init__(self):
          logger.info("Starting archive")
-         self.flags_kept = flags_kept
-         self.columns_photometry=columns_photometry
-         self.columns_ebv=columns_ebv
  
-         if path_calib.suffix == ".fits":
-             with fits.open(path_calib) as hdu_list:
-                 cat = Table(hdu_list[1].data).to_pandas()
-             if path_valid != None:
-                 with fits.open(path_valid) as hdu_list:
-                     cat_test = Table(hdu_list[1].data).to_pandas()
-
-         elif path_calib.suffix == ".csv":
-             cat = pd.read_csv(path_calib)
-             if path_valid != None:
-                 cat_test = pd.read_csv(path_valid)
+         # Load data based on the file format
+         if self.path_calib.suffix == ".fits":
+             with fits.open(self.path_calib) as hdu_list:
+                 self.cat = Table(hdu_list[1].data).to_pandas()
+             if self.path_valid is not None:
+                 with fits.open(self.path_valid) as hdu_list:
+                     self.cat_test = Table(hdu_list[1].data).to_pandas()
+
+         elif self.path_calib.suffix == ".csv":
+             self.cat = pd.read_csv(self.path_calib)
+             if self.path_valid is not None:
+                 self.cat_test = pd.read_csv(self.path_valid)
          else:
              raise ValueError("Unsupported file format. Please provide a .fits or .csv file.")
  
-         cat = cat.rename(columns={f"{specz_name}": "specz",
-                                   f"{photoz_name}": "photo_z"})
-         cat_test = cat_test.rename(columns={f"{specz_name}": "specz",
-                                             f"{photoz_name}": "photo_z"})
-
-         cat = cat[(cat['specz'] > 0) | (cat['photo_z'] > 0)]
-
-         # Store the catalogs for later use
-         self.cat = cat
-         self.cat_test = cat_test
-
-         if drop_stars==True:
-             logger.info("dropping stars...")
-             cat = cat[cat.mu_class_L07==1]
-             cat_test = cat_test[cat_test.mu_class_L07==1]
-
-         if clean_photometry==True:
-             logger.info("cleaning stars...")
-             cat = self._clean_photometry(cat)
-             cat_test = self._clean_photometry(cat_test)
-
-         cat = self._set_combiend_target(cat)
-         cat_test = self._set_combiend_target(cat_test)
-
-         cat = cat[cat.MAG_VIS<25]
-         cat_test = cat_test[cat_test.MAG_VIS<25]
-
-         cat = cat[cat.target_z<5]
-         cat_test = cat_test[cat_test.target_z<5]
-
-         self._set_training_data(cat,
-                                 cat_test,
-                                 only_zspec=only_zspec,
-                                 extinction_corr=extinction_corr,
-                                 convert_colors=convert_colors)
-         self._set_testing_data(cat_test,
-                                target=target_test,
-                                extinction_corr=extinction_corr,
-                                convert_colors=convert_colors)
-
-     def _extract_fluxes(self,catalogue):
+         self.cat = self.cat.rename(
+             columns={f"{self.specz_name}": "specz", f"{self.photoz_name}": "photo_z"}
+         )
+         self.cat_test = self.cat_test.rename(
+             columns={f"{self.specz_name}": "specz", f"{self.photoz_name}": "photo_z"}
+         )
+
+         self.cat = self.cat[(self.cat["specz"] > 0) | (self.cat["photo_z"] > 0)]
+
+         # Apply operations based on the initialized parameters
+         if self.drop_stars:
+             logger.info("Dropping stars...")
+             self.cat = self.cat[self.cat.mu_class_L07 == 1]
+             self.cat_test = self.cat_test[self.cat_test.mu_class_L07 == 1]
+
+         if self.clean_photometry:
+             logger.info("Cleaning photometry...")
+             self.cat = self._clean_photometry(catalogue=self.cat)
+             self.cat_test = self._clean_photometry(catalogue=self.cat_test)
+
+         self.cat = self._set_combined_target(self.cat)
+         self.cat_test = self._set_combined_target(self.cat_test)
+
+         # Apply magnitude and redshift cuts
+         self.cat = self.cat[self.cat.MAG_VIS < 25]
+         self.cat_test = self.cat_test[self.cat_test.MAG_VIS < 25]
+
+         self.cat = self.cat[self.cat.target_z < 5]
+         self.cat_test = self.cat_test[self.cat_test.target_z < 5]
+
+         self._set_training_data(
+             self.cat,
+             self.cat_test,
+             only_zspec=self.only_zspec,
+             extinction_corr=self.extinction_corr,
+             convert_colors=self.convert_colors,
+         )
+         self._set_testing_data(
+             self.cat_test,
+             target=self.target_test,
+             extinction_corr=self.extinction_corr,
+             convert_colors=self.convert_colors,
+         )
+
+     def _extract_fluxes(self, catalogue: pd.DataFrame) -> np.ndarray:
+         """Extract fluxes from the given catalogue.
+
+         Args:
+             catalogue (pd.DataFrame): The input catalogue.
+
+         Returns:
+             np.ndarray: An array of fluxes.
+         """
          f = catalogue[self.columns_photometry].values
          return f
  
-     def _to_colors(self, flux):
-         """ Convert fluxes to colors"""
-         color = flux[:,:-1] / flux[:,1:]
+     @staticmethod
+     def _to_colors(flux: np.ndarray) -> np.ndarray:
+         """Convert fluxes to colors.
+
+         Args:
+             flux (np.ndarray): The input fluxes.
+
+         Returns:
+             np.ndarray: An array of colors.
+         """
+         color = flux[:, :-1] / flux[:, 1:]
          return color
  
-     def _set_combiend_target(self, catalogue):
-         catalogue['target_z'] = catalogue.apply(lambda row: row['specz']
-                                                 if row['specz'] > 0
-                                                 else row['photo_z'], axis=1)
-
+     @staticmethod
+     def _set_combined_target(catalogue: pd.DataFrame) -> pd.DataFrame:
+         """Set the combined target redshift based on available data.
+
+         Args:
+             catalogue (pd.DataFrame): The input catalogue.
+
+         Returns:
+             pd.DataFrame: Updated catalogue with the combined target redshift.
+         """
+         catalogue["target_z"] = catalogue.apply(
+             lambda row: row["specz"] if row["specz"] > 0 else row["photo_z"], axis=1
+         )
          return catalogue
  
-     def _clean_photometry(self,catalogue):
-         """ Drops all object with FLAG_PHOT!=0"""
-         catalogue = catalogue[catalogue['FLAG_PHOT']==0]
-
+     @staticmethod
+     def _clean_photometry(catalogue: pd.DataFrame) -> pd.DataFrame:
+         """Drops all objects with FLAG_PHOT != 0.
+
+         Args:
+             catalogue (pd.DataFrame): The input catalogue.
+
+         Returns:
+             pd.DataFrame: Cleaned catalogue.
+         """
+         catalogue = catalogue[catalogue["FLAG_PHOT"] == 0]
          return catalogue
  
-     def _correct_extinction(self,catalogue, f, return_ext_corr=False):
-         """Corrects for extinction"""
+     def _correct_extinction(
+         self, catalogue: pd.DataFrame, f: np.ndarray, return_ext_corr: bool = False
+     ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
+         """Corrects for extinction based on the provided catalogue.
+
+         Args:
+             catalogue (pd.DataFrame): The input catalogue.
+             f (np.ndarray): The flux values to correct.
+             return_ext_corr (bool): Whether to return the extinction correction values.
+
+         Returns:
+             Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: Corrected fluxes, and optionally the extinction corrections.
+         """
          ext_correction = catalogue[self.columns_ebv].values
          f = f * ext_correction
          if return_ext_corr:
              return f, ext_correction
          else:
              return f
  
-     def _select_only_zspec(self,catalogue,cat_flag=None):
-         """Selects only galaxies with spectroscopic redshift"""
-         if cat_flag=='Calib':
-             catalogue = catalogue[catalogue.specz>0]
-         elif cat_flag=='Valid':
-             catalogue = catalogue[catalogue.specz>0]
-         return catalogue
-
-     def _exclude_only_zspec(self,catalogue):
-         """Selects only galaxies without spectroscopic redshift"""
-         catalogue = catalogue[(catalogue.specz<0)&(catalogue.photo_z>0)&(catalogue.photo_z<4)]
-         return catalogue
-
-     def _select_L15_sample(self,catalogue):
-         """Selects only galaxies without spectroscopic redshift"""
-         catalogue = catalogue[(catalogue.target_z>0)]
-         catalogue = catalogue[(catalogue.target_z<4)]
-
-         return catalogue
-
-     def _take_zspec_and_photoz(self,catalogue,cat_flag=None):
+     @staticmethod
+     def _select_only_zspec(
+         catalogue: pd.DataFrame, cat_flag: Optional[str] = None
+     ) -> pd.DataFrame:
+         """Selects only galaxies with spectroscopic redshift.
+
+         Args:
+             catalogue (pd.DataFrame): The input catalogue.
+             cat_flag (Optional[str]): Indicates the catalogue type ('Calib' or 'Valid').
+
+         Returns:
+             pd.DataFrame: Filtered catalogue.
+         """
+         if cat_flag == "Calib":
+             catalogue = catalogue[catalogue.specz > 0]
+         elif cat_flag == "Valid":
+             catalogue = catalogue[catalogue.specz > 0]
+         return catalogue
+
+     @staticmethod
+     def take_zspec_and_photoz(
+         catalogue: pd.DataFrame, cat_flag: Optional[str] = None
+     ) -> pd.DataFrame:
          """Selects only galaxies with spectroscopic redshift"""
          if cat_flag=='Calib':
              catalogue = catalogue[catalogue.target_z>0]
          elif cat_flag=='Valid':
              catalogue = catalogue[catalogue.specz>0]
          return catalogue
  
-     def _clean_zspec_sample(self,catalogue ,flags_kept=[3,3.1,3.4,3.5,4]):
-         #[ 2.5, 3.5, 4. , 1.5, 1.1, 13.5, 9. , 3. , 2.1, 9.5, 3.1,
-         #1. , 9.1, 2. , 9.3, 1.4, 3.4, 11.5, 2.4, 13. , 14. , 12.1,
-         #12.5, 13.1, 9.4, 11.1]
-
-         catalogue = catalogue[catalogue.Q_f_S15.isin(flags_kept)]
-         return catalogue
-
-     def _match_gold_sample(self,catalogue_valid, catalogue_gold, max_distance_arcsec=2):
-         max_distance_deg = max_distance_arcsec / 3600.0
-
-         gold_sample_radec = np.c_[catalogue_gold.RIGHT_ASCENSION,catalogue_gold.DECLINATION]
-         valid_sample_radec = np.c_[catalogue_valid['RA'],catalogue_valid['DEC']]
-
-         kdtree = KDTree(gold_sample_radec)
-         distances, indices = kdtree.query(valid_sample_radec, k=1)
-
-         specz_match_gold = catalogue_gold.FINAL_SPEC_Z.values[indices]
-
-         zs = [specz_match_gold[i] if distance < max_distance_deg else -99 for i, distance in enumerate(distances)]
-
-         catalogue_valid['z_spec_gold'] = zs
-
-         return catalogue_valid
-
-     def _set_training_data(self,catalogue, catalogue_da, only_zspec=True, extinction_corr=True, convert_colors=True):
-
-         cat_da = self._exclude_only_zspec(catalogue_da)
+     @staticmethod
+     def exclude_only_zspec(catalogue: pd.DataFrame) -> pd.DataFrame:
+         """Selects only galaxies without spectroscopic redshift.
+
+         Args:
+             catalogue (pd.DataFrame): The input catalogue.
+
+         Returns:
+             pd.DataFrame: Filtered catalogue.
+         """
+         catalogue = catalogue[
+             (catalogue.specz < 0) & (catalogue.photo_z > 0) & (catalogue.photo_z < 4)
+         ]
+         return catalogue
+
+     @staticmethod
+     def _clean_zspec_sample(catalogue, flags_kept=[3, 3.1, 3.4, 3.5, 4]):
+         catalogue = catalogue[catalogue.Q_f_S15.isin(flags_kept)]
+         return catalogue
+
+     @staticmethod
+     def _select_L15_sample(catalogue: pd.DataFrame) -> pd.DataFrame:
+         """Selects only galaxies within a specific redshift range.
+
+         Args:
+             catalogue (pd.DataFrame): The input catalogue.
+
+         Returns:
+             pd.DataFrame: Filtered catalogue.
+         """
+         catalogue = catalogue[(catalogue.target_z > 0) & (catalogue.target_z < 3)]
+         return catalogue
+
+     def _set_training_data(
+         self,
+         catalogue: pd.DataFrame,
+         catalogue_da: pd.DataFrame,
+         only_zspec: bool = True,
+         extinction_corr: bool = True,
+         convert_colors: bool = True,
+     ) -> None:
+
+         cat_da = Archive.exclude_only_zspec(catalogue_da)
          target_z_train_DA = cat_da['photo_z'].values
  
          if only_zspec:
              logger.info("Selecting only galaxies with spectroscopic redshift")
-             catalogue = self._select_only_zspec(catalogue, cat_flag='Calib')
-             catalogue = self._clean_zspec_sample(catalogue, flags_kept=self.flags_kept)
+             catalogue = Archive._select_only_zspec(catalogue, cat_flag='Calib')
+             catalogue = Archive._clean_zspec_sample(catalogue, flags_kept=self.flags_kept)
          else:
              logger.info("Selecting galaxies with spectroscopic redshift and high-precision photo-z")
-             catalogue = self._take_zspec_and_photoz(catalogue, cat_flag='Calib')
+             catalogue = Archive.take_zspec_and_photoz(catalogue, cat_flag='Calib')
  
          self.cat_train=catalogue
@@ -230,25 +295,32 @@ class Archive:
          self.target_z_train = catalogue['target_z'].values
  
          self.VIS_mag_train = catalogue['MAG_VIS'].values
-
-     def _set_testing_data(self,catalogue, target='specz', extinction_corr=True, convert_colors=True):
+
+     def _set_testing_data(
+         self,
+         cat_test: pd.DataFrame,
+         target: str = "specz",
+         extinction_corr: bool = True,
+         convert_colors: bool = True,
+     ) -> None:
  
          if target=='specz':
-             catalogue = self._select_only_zspec(catalogue, cat_flag='Valid')
-             catalogue = self._clean_zspec_sample(catalogue)
-             self.target_z_test = catalogue['specz'].values
+             cat_test = Archive._select_only_zspec(cat_test, cat_flag='Valid')
+             cat_test = Archive._clean_zspec_sample(cat_test)
+             self.target_z_test = cat_test['specz'].values
  
          elif target=='L15':
-             catalogue = self._select_L15_sample(catalogue)
-             self.target_z_test = catalogue['target_z'].values
+             cat_test = self._select_L15_sample(cat_test)
+             self.target_z_test = cat_test['target_z'].values
  
-         self.cat_test=catalogue
+         self.cat_test=cat_test
  
-         f = self._extract_fluxes(catalogue)
+         f = self._extract_fluxes(cat_test)
  
          if extinction_corr==True:
-             f = self._correct_extinction(catalogue,f)
+             f = self._correct_extinction(cat_test,f)
  
          if convert_colors==True:
              col = self._to_colors(f)
@@ -257,9 +329,9 @@ class Archive:
          self.phot_test = f
  
-         self.VIS_mag_test = catalogue['MAG_VIS'].values
-
-
+         self.VIS_mag_test = cat_test['MAG_VIS'].values
+
+
      def get_training_data(self):
          return self.phot_train, self.target_z_train, self.VIS_mag_train, self.phot_train_DA, self.target_z_train_DA
  
@@ -267,4 +339,4 @@ class Archive:
          return self.phot_test, self.target_z_test, self.VIS_mag_test
  
      def get_VIS_mag(self, catalogue):
-         return catalogue[['MAG_VIS']].values
+         return catalogue[['MAG_VIS']].values
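For orientation, here is a minimal usage sketch of the refactored Archive dataclass. It assumes catalogues that carry the columns the class filters on (FLAG_PHOT, mu_class_L07, MAG_VIS, and the default photometry and EB-V columns); the file names below are placeholders, not files shipped with the repository:

from pathlib import Path
from temps.archive import Archive

# Placeholder paths; in practice these are the Euclid DC2 calibration and
# validation catalogues loaded in notebooks/NMAD.py.
archive = Archive(
    path_calib=Path("euclid_calib.fits"),
    path_valid=Path("euclid_valid.fits"),
    only_zspec=True,      # train only on spectroscopic redshifts
    target_test="specz",  # evaluate against spec-z as well
)

# __post_init__ has already loaded, cleaned, and split the catalogues,
# so the training arrays are ready to use:
phot, z, vis_mag, phot_da, z_da = archive.get_training_data()

Because the loading now happens in __post_init__, constructing the object runs the whole preparation pipeline; there is no separate load step to call.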
temps/plots.py CHANGED
@@ -2,127 +2,185 @@ import numpy as np
  import pandas as pd
  import matplotlib.pyplot as plt
  from temps.utils import nmad
- import numpy as np
- import matplotlib.pyplot as plt
  from scipy import stats
+ from typing import List, Optional, Dict
  
- def plot_photoz(df_list, nbins, xvariable, metric, type_bin='bin',label_list=None, samp='zs', save=False):
-     #plot properties
-     plt.rcParams['font.family'] = 'serif'
-     plt.rcParams['font.size'] = 12
-
-     if xvariable == 'VISmag':
-         xvariable_lab = 'VIS'
-     if xvariable == 'zs':
-         xvariable_lab = r'$z_{\rm s}$'
-
-     bin_edges = stats.mstats.mquantiles(df_list[0][xvariable].values, np.linspace(0.05, 1, nbins))
-     cmap = plt.get_cmap('Dark2')  # Choose a colormap for coloring lines
-     #plt.figure(figsize=(6, 5))
-     ls = ['--',':','-']
-     fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 8), gridspec_kw={'height_ratios': [3, 1]})
-
-     ydata_dict = {}
+ def plot_photoz(
+     df_list: List[pd.DataFrame],
+     nbins: int,
+     xvariable: str,
+     metric: str,
+     type_bin: str = "bin",
+     label_list: Optional[List[str]] = None,
+     samp: str = "zs",
+     save: bool = False,
+ ) -> None:
+     """
+     Plot photo-z metrics for multiple dataframes.
+
+     Parameters:
+     - df_list (List[pd.DataFrame]): List of dataframes containing data for plotting.
+     - nbins (int): Number of bins for the histogram.
+     - xvariable (str): Variable to plot on the x-axis.
+     - metric (str): Metric to plot (e.g., 'sig68', 'bias', 'nmad', 'outliers').
+     - type_bin (str, optional): Type of binning ('bin' or 'cum'). Default is 'bin'.
+     - label_list (Optional[List[str]], optional): List of labels for each dataframe. Default is None.
+     - samp (str, optional): Sample label for saving. Default is 'zs'.
+     - save (bool, optional): If True, save the plot to a file. Default is False.
+
+     Returns:
+     None
+     """
+     # Plot properties
+     plt.rcParams["font.family"] = "serif"
+     plt.rcParams["font.size"] = 12
+
+     # Set x-axis label based on variable
+     xvariable_lab = "VIS" if xvariable == "VISmag" else r"$z_{\rm s}$"
+
+     # Calculate bin edges
+     bin_edges = stats.mstats.mquantiles(
+         df_list[0][xvariable].values, np.linspace(0.05, 1, nbins)
+     )
+     cmap = plt.get_cmap("Dark2")
+
+     # Create subplots
+     fig, (ax1, ax2) = plt.subplots(
+         2, 1, figsize=(8, 8), gridspec_kw={"height_ratios": [3, 1]}
+     )
+     ydata_dict: Dict[str, List[float]] = {}
+
+     # Loop through dataframes and calculate metrics
      for i, df in enumerate(df_list):
          ydata, xlab = [], []
-
+
          label = label_list[i]
-
-         if label == 'zs':
-             label_lab = r'$z_{\rm s}$'
-         if label == 'zs+L15':
-             label_lab = r'$z_{\rm s}$+L15'
-         if label == 'TEMPS':
-             label_lab = 'TEMPS'
-
-         for k in range(len(bin_edges)-1):
-             edge_min = bin_edges[k]
-             edge_max = bin_edges[k+1]
+         label_lab = {
+             "zs": r"$z_{\rm s}$",
+             "zs+L15": r"$z_{\rm s}$+L15",
+             "TEMPS": "TEMPS",
+         }.get(label, label)
+
+         for k in range(len(bin_edges) - 1):
+             edge_min = bin_edges[k]
+             edge_max = bin_edges[k + 1]
              mean_mag = (edge_max + edge_min) / 2
  
-             if type_bin == 'bin':
-                 df_plot = df[(df[xvariable] > edge_min) & (df[xvariable] < edge_max)]
-             elif type_bin == 'cum':
-                 df_plot = df[(df[xvariable] < edge_max)]
-             else:
-                 raise ValueError("Only type_bin=='bin' for binned and 'cum' for cumulative are supported")
+             df_plot = (
+                 df[(df[xvariable] > edge_min) & (df[xvariable] < edge_max)]
+                 if type_bin == "bin"
+                 else df[(df[xvariable] < edge_max)]
+             )
  
              xlab.append(mean_mag)
-             if metric == 'sig68':
+             if metric == "sig68":
                  ydata.append(sigma68(df_plot.zwerr))
-             elif metric == 'bias':
+             elif metric == "bias":
                  ydata.append(np.mean(df_plot.zwerr))
-             elif metric == 'nmad':
+             elif metric == "nmad":
                  ydata.append(nmad(df_plot.zwerr))
-             elif metric == 'outliers':
-                 ydata.append(len(df_plot[np.abs(df_plot.zwerr) > 0.15]) / len(df_plot)*100)
-
-         ydata_dict[f'{i}'] = ydata
-         color = cmap(i)  # Get a different color for each dataframe
-         ax1.plot(xlab, ydata, marker='.', lw=1, label=label_lab, color=color, ls=ls[i])
-
-     ax1.set_ylabel(f'{metric} $[\Delta z]$', fontsize=18)
-     #ax1.set_xlabel(f'{xvariable_lab}', fontsize=16)
+             elif metric == "outliers":
+                 ydata.append(
+                     len(df_plot[np.abs(df_plot.zwerr) > 0.15]) / len(df_plot) * 100
+                 )
+
+         ydata_dict[f"{i}"] = ydata
+         color = cmap(i)
+         ax1.plot(
+             xlab,
+             ydata,
+             marker=".",
+             lw=1,
+             label=label_lab,
+             color=color,
+             ls=["--", ":", "-"][i],
+         )
+
+     ax1.set_ylabel(f"{metric} $[\Delta z]$", fontsize=18)
      ax1.grid(False)
      ax1.legend()
-
-     # Plot ratios between lines in the upper panel
-
-     ax2.plot(xlab, np.array(ydata_dict['1'])/np.array(ydata_dict['0']), marker='.', color = cmap(1))
-     ax2.plot(xlab, np.array(ydata_dict['2'])/np.array(ydata_dict['0']), marker='.', color = cmap(2))
-     ax2.set_ylabel(r'Method $X$ / $z_{\rm z}$', fontsize=14)
-
-     ax2.set_xlabel(f'{xvariable_lab}', fontsize=16)
+
+     # Plot ratios
+     ax2.plot(
+         xlab,
+         np.array(ydata_dict["1"]) / np.array(ydata_dict["0"]),
+         marker=".",
+         color=cmap(1),
+     )
+     ax2.plot(
+         xlab,
+         np.array(ydata_dict["2"]) / np.array(ydata_dict["0"]),
+         marker=".",
+         color=cmap(2),
+     )
+     ax2.set_ylabel(r"Method $X$ / $z_{\rm z}$", fontsize=14)
+     ax2.set_xlabel(f"{xvariable_lab}", fontsize=16)
      ax2.grid(True)
  
-     if save==True:
-         plt.savefig(f'{metric}_{xvariable}_{samp}.pdf', dpi=300, bbox_inches='tight')
+     if save:
+         plt.savefig(f"{metric}_{xvariable}_{samp}.pdf", dpi=300, bbox_inches="tight")
      plt.show()
  
  
- def plot_pz(m, pz, specz):
-     # Create a figure and axis
-     fig, ax = plt.subplots(figsize=(8, 6))
-
-     # Plot the PDF with a label
-     ax.plot(np.linspace(0, 4, 1000), pz[m], label='PDF', color='navy')
-
-     # Add a vertical line for 'specz_test'
-     ax.axvline(specz[m], color='black', linestyle='--', label=r'$z_{\rm s}$')
-
-     # Add labels and a legend
-     ax.set_xlabel(r'$z$', fontsize = 18)
-     ax.set_ylabel('Probability Density', fontsize=16)
-     ax.legend(fontsize = 18)
-
-     # Display the plot
-     plt.show()
-
- def plot_zdistribution(archive, plot_test=False, bins=50):
-     _,_,specz = archive.get_training_data()
-     plt.hist(specz, bins = bins, hisstype='step', color='navy', label=r'Training sample')
-
-     if plot_test:
-         _,_,specz_test = archive.get_training_data()
-         plt.hist(specz, bins = bins, hisstype='step', color='goldenrod', label=r'Test sample',ls='--')
+ def plot_pz(m: int, pz: np.ndarray, specz: np.ndarray) -> None:
+     """
+     Plot the Probability Density Function (PDF) for a given galaxy and compare it with the spectroscopic redshift.
+
+     Parameters:
+     - m (int): Index of the galaxy.
+     - pz (np.ndarray): Probability density function values.
+     - specz (np.ndarray): Spectroscopic redshift values.
+
+     Returns:
+     None
+     """
+     fig, ax = plt.subplots(figsize=(8, 6))
+     ax.plot(np.linspace(0, 4, 1000), pz[m], label="PDF", color="navy")
+     ax.axvline(specz[m], color="black", linestyle="--", label=r"$z_{\rm s}$")
+     ax.set_xlabel(r"$z$", fontsize=18)
+     ax.set_ylabel("Probability Density", fontsize=16)
+     ax.legend(fontsize=18)
+     plt.show()
+
+ def plot_zdistribution(archive, plot_test: bool = False, bins: int = 50) -> None:
+     """
+     Plot the distribution of redshifts for training and optionally test samples.
+
+     Parameters:
+     - archive: Data archive object containing the training data.
+     - plot_test (bool, optional): If True, plot test sample distribution. Default is False.
+     - bins (int, optional): Number of histogram bins. Default is 50.
+
+     Returns:
+     None
+     """
+     _, _, specz = archive.get_training_data()
+     plt.hist(specz, bins=bins, histtype="step", color="navy", label=r"Training sample")
+
+     if plot_test:
+         _, _, specz_test = archive.get_training_data()
+         plt.hist(
+             specz_test,
+             bins=bins,
+             histtype="step",
+             color="goldenrod",
+             label=r"Test sample",
+             linestyle="--",
+         )
  
      plt.xticks(fontsize=12)
      plt.yticks(fontsize=12)
-
-     plt.xlabel(r'Redshift', fontsize=14)
-     plt.ylabel('Counts', fontsize=14)
-
-     plt.show()
-
- def plot_som_map(som_data, plot_arg = 'z', vmin=0, vmax=1):
+     plt.xlabel(r"Redshift", fontsize=14)
+     plt.ylabel("Counts", fontsize=14)
+     plt.legend()
+     plt.show()
+
+ def plot_som_map(
+     som_data: np.ndarray, plot_arg: str = "z", vmin: float = 0, vmax: float = 1
+ ) -> None:
      """
      Plot the Self-Organizing Map (SOM) data.
  
@@ -135,182 +193,159 @@ def plot_som_map(som_data, plot_arg = 'z', vmin=0, vmax=1):
      Returns:
      None
      """
-     plt.imshow(som_data, vmin=vmin, vmax=vmax, cmap='viridis')  # Choose an appropriate colormap
-     plt.colorbar(label=f'{plot_arg}')  # Add a colorbar with a label
-     plt.xlabel(r'$x$ [pixel]', fontsize=14)  # Add an appropriate X-axis label
-     plt.ylabel(r'$y$ [pixel]', fontsize=14)  # Add an appropriate Y-axis label
+     plt.imshow(som_data, vmin=vmin, vmax=vmax, cmap="viridis")
+     plt.colorbar(label=f"{plot_arg}")
+     plt.xlabel(r"$x$ [pixel]", fontsize=14)
+     plt.ylabel(r"$y$ [pixel]", fontsize=14)
      plt.show()
  
- def plot_PIT(pit_list_1, pit_list_2 = None, pit_list_3=None, sample='specz', labels=None, save =True):
-     #plot properties
-     plt.rcParams['font.family'] = 'serif'
-     plt.rcParams['font.size'] = 12
-     fig, ax = plt.subplots(figsize=(8, 6))
-     kwargs=dict(bins=30, histtype='step', density=True, range=(0,1))
-     cmap = plt.get_cmap('Dark2')
-
-     # Create a histogram
-     hist, bins, _ = ax.hist(pit_list_1, color=cmap(0), ls='--', **kwargs, label=labels[0])
-     if pit_list_2!= None:
-         hist, bins, _ = ax.hist(pit_list_2, color=cmap(1), ls=':', **kwargs, label=labels[1])
-     if pit_list_3!= None:
-         hist, bins, _ = ax.hist(pit_list_3, color=cmap(2), ls='-', **kwargs, label=labels[2])
-
-     # Add labels and a title
-     ax.set_xlabel('PIT values', fontsize = 18)
-     ax.set_ylabel('Frequency', fontsize = 18)
-
-     # Add grid lines
-     ax.grid(True, linestyle='--', alpha=0.7)
-
-     # Customize the x-axis
-     ax.set_xlim(0, 1)
-     #ax.set_ylim(0,3)
-
-     plt.legend(fontsize=12)
-
-     # Make ticks larger
-     ax.tick_params(axis='both', which='major', labelsize=14)
-     if save==True:
-         plt.savefig(f'{sample}_PIT.pdf', bbox_inches='tight')
-
-     # Show the plot
+ def plot_PIT(
+     pit_list_1: List[float],
+     pit_list_2: Optional[List[float]] = None,
+     pit_list_3: Optional[List[float]] = None,
+     sample: str = "specz",
+     labels: Optional[List[str]] = None,
+     save: bool = True,
+ ) -> None:
+     """
+     Plot Probability Integral Transform (PIT) values for given lists.
+
+     Parameters:
+     - pit_list_1 (List[float]): First list of PIT values.
+     - pit_list_2 (Optional[List[float]], optional): Second list of PIT values. Default is None.
+     - pit_list_3 (Optional[List[float]], optional): Third list of PIT values. Default is None.
+     - sample (str, optional): Sample label for saving. Default is 'specz'.
+     - labels (Optional[List[str]], optional): List of labels for each PIT list. Default is None.
+     - save (bool, optional): If True, save the plot to a file. Default is True.
+
+     Returns:
+     None
+     """
+     plt.rcParams["font.family"] = "serif"
+     plt.rcParams["font.size"] = 12
+     fig, ax = plt.subplots(figsize=(8, 6))
+     kwargs = dict(bins=30, histtype="step", density=True, range=(0, 1))
+     cmap = plt.get_cmap("Dark2")
+
+     # Create a histogram
+     ax.hist(pit_list_1, color=cmap(0), linestyle="--", **kwargs, label=labels[0])
+     if pit_list_2 is not None:
+         ax.hist(pit_list_2, color=cmap(1), linestyle="--", **kwargs, label=labels[1])
+     if pit_list_3 is not None:
+         ax.hist(pit_list_3, color=cmap(2), linestyle="--", **kwargs, label=labels[2])
+
+     ax.set_xlabel("PIT values", fontsize=14)
+     ax.set_ylabel("Normalized Counts", fontsize=14)
+     ax.legend(fontsize=12)
+
+     if save:
+         plt.savefig(f"PIT_{sample}.pdf", dpi=300, bbox_inches="tight")
      plt.show()
  
- def plot_nz(df_list,
-             zcuts = [0.1, 0.5, 1, 1.5, 2, 3, 4],
-             save=False):
-     # Plot properties
-     plt.rcParams['font.family'] = 'serif'
-     plt.rcParams['font.size'] = 16
-
-     cmap = plt.get_cmap('Dark2')  # Choose a colormap for coloring lines
-
-     # Create subplots
-     fig, axs = plt.subplots(3, 1, figsize=(20, 8), sharex=True)
-
-     for i, df in enumerate(df_list):
-         dfplot = df_list[i].copy()  # Assuming df_list contains dataframes
-         ax = axs[i]  # Selecting the appropriate subplot
-
-         for iz in range(len(zcuts)-1):
-             dfplot_z = dfplot[(dfplot['ztarget'] > zcuts[iz]) & (dfplot['ztarget'] < zcuts[iz + 1])]
-             color = cmap(iz)  # Get a different color for each redshift
-
-             zt_mean = np.median(dfplot_z.ztarget.values)
-             zp_mean = np.median(dfplot_z.z.values)
-
-             # Plot histogram on the selected subplot
-             ax.hist(dfplot_z.z, bins=50, color=color, histtype='step', linestyle='-', density=True, range=(0, 4))
-             ax.axvline(zt_mean, color=color, linestyle='-', lw=2)
-             ax.axvline(zp_mean, color=color, linestyle='--', lw=2)
-
-         ax.set_ylabel(f'Frequency', fontsize=14)
-         ax.grid(False)
-         ax.set_xlim(0, 3.5)
-
-     axs[-1].set_xlabel(f'$z$', fontsize=18)
-
-     if save:
-         plt.savefig(f'nz_hist.pdf', dpi=300, bbox_inches='tight')
-
-     plt.show()
+ def plot_outlier_ratio(
+     outliers: np.ndarray, num_samp: int = 100, plot_mean: bool = True
+ ) -> None:
+     """
+     Plot the outlier ratio as a function of the number of samples.
+
+     Parameters:
+     - outliers (np.ndarray): Outlier ratio data.
+     - num_samp (int, optional): Number of samples for plotting. Default is 100.
+     - plot_mean (bool, optional): If True, plot the mean of outliers. Default is True.
+
+     Returns:
+     None
+     """
+     plt.figure(figsize=(10, 6))
+     plt.plot(np.arange(1, num_samp + 1), outliers[:num_samp], label="Outlier Ratio")
+
+     if plot_mean:
+         plt.axhline(
+             np.mean(outliers), color="red", linestyle="--", label="Mean Outlier Ratio"
+         )
+
+     plt.xlabel("Number of Samples", fontsize=14)
+     plt.ylabel("Outlier Ratio", fontsize=14)
+     plt.legend()
+     plt.grid()
+     plt.show()
  
- def plot_crps(crps_list_1, crps_list_2 = None, crps_list_3=None, labels=None, sample='specz', save =True):
+ def plot_crps(
+     crps_list_1: List[float],
+     crps_list_2: Optional[List[float]] = None,
+     crps_list_3: Optional[List[float]] = None,
+     labels: Optional[List[str]] = None,
+     sample: str = "specz",
+     save: bool = True,
+ ) -> None:
      # Create a figure and axis
-     #plot properties
-     plt.rcParams['font.family'] = 'serif'
-     plt.rcParams['font.size'] = 12
+     # plot properties
+     plt.rcParams["font.family"] = "serif"
+     plt.rcParams["font.size"] = 12
      fig, ax = plt.subplots(figsize=(8, 6))
-     cmap = plt.get_cmap('Dark2')
+     cmap = plt.get_cmap("Dark2")
  
-     kwargs=dict(bins=50, histtype='step', density=True, range=(0,1))
+     kwargs = dict(bins=50, histtype="step", density=True, range=(0, 1))
  
      # Create a histogram
-     hist, bins, _ = ax.hist(crps_list_1, color=cmap(0), ls='--', **kwargs, label=labels[0])
+     hist, bins, _ = ax.hist(
+         crps_list_1, color=cmap(0), ls="--", **kwargs, label=labels[0]
+     )
      if crps_list_2 is not None:
-         hist, bins, _ = ax.hist(crps_list_2, color=cmap(1), ls=':', **kwargs, label=labels[1])
+         hist, bins, _ = ax.hist(
+             crps_list_2, color=cmap(1), ls=":", **kwargs, label=labels[1]
+         )
      if crps_list_3 is not None:
-         hist, bins, _ = ax.hist(crps_list_3, color=cmap(2), ls='-', **kwargs, label=labels[2])
+         hist, bins, _ = ax.hist(
+             crps_list_3, color=cmap(2), ls="-", **kwargs, label=labels[2]
+         )
  
      # Add labels and a title
-     ax.set_xlabel('CRPS Scores', fontsize = 18)
-     ax.set_ylabel('Frequency', fontsize = 18)
+     ax.set_xlabel("CRPS Scores", fontsize=18)
+     ax.set_ylabel("Frequency", fontsize=18)
  
      # Add grid lines
-     ax.grid(True, linestyle='--', alpha=0.7)
+     ax.grid(True, linestyle="--", alpha=0.7)
  
      # Customize the x-axis
      ax.set_xlim(0, 0.5)
  
      # Make ticks larger
-     ax.tick_params(axis='both', which='major', labelsize=14)
+     ax.tick_params(axis="both", which="major", labelsize=14)
  
      # Calculate the mean CRPS value
      mean_crps_1 = round(np.nanmean(crps_list_1), 2)
      mean_crps_2 = round(np.nanmean(crps_list_2), 2)
      mean_crps_3 = round(np.nanmean(crps_list_3), 2)
  
      # Add the mean CRPS value at the top-left corner
-     ax.annotate(f"Mean CRPS {labels[0]}: {mean_crps_1}", xy=(0.57, 0.9), xycoords='axes fraction', fontsize=14, color =cmap(0))
-     ax.annotate(f"Mean CRPS {labels[1]}: {mean_crps_2}", xy=(0.57, 0.85), xycoords='axes fraction', fontsize=14, color =cmap(1))
-     ax.annotate(f"Mean CRPS {labels[2]}: {mean_crps_3}", xy=(0.57, 0.8), xycoords='axes fraction', fontsize=14, color =cmap(2))
-
-     if save==True:
-         plt.savefig(f'{sample}_CRPS.pdf', bbox_inches='tight')
+     ax.annotate(
+         f"Mean CRPS {labels[0]}: {mean_crps_1}",
+         xy=(0.57, 0.9),
+         xycoords="axes fraction",
+         fontsize=14,
+         color=cmap(0),
+     )
+     ax.annotate(
+         f"Mean CRPS {labels[1]}: {mean_crps_2}",
+         xy=(0.57, 0.85),
+         xycoords="axes fraction",
+         fontsize=14,
+         color=cmap(1),
+     )
+     ax.annotate(
+         f"Mean CRPS {labels[2]}: {mean_crps_3}",
+         xy=(0.57, 0.8),
+         xycoords="axes fraction",
+         fontsize=14,
+         color=cmap(2),
+     )
+
+     if save == True:
+         plt.savefig(f"{sample}_CRPS.pdf", bbox_inches="tight")
  
      # Show the plot
      plt.show()
-
-
- def plot_nz(df, bins=np.arange(0,5,0.2)):
-     kwargs=dict(bins=bins, alpha=0.5)
-     plt.hist(df.zs.values, color='grey', ls='-', **kwargs)
-     counts, _ = np.histogram(df.z.values, bins=bins)
-
-     plt.plot((bins[:-1]+bins[1:])*0.5, counts, color='purple')
-
-     #plt.legend(fontsize=14)
-     plt.xlabel(r'Redshift', fontsize=14)
-     plt.ylabel(r'Counts', fontsize=14)
-     plt.yscale('log')
-
-     plt.show()
-
-     return
-
- def plot_scatter(df, sample='specz', save=True):
-     # Calculate the point density
-     xy = np.vstack([df.zs.values, df.z.values])
-     zd = gaussian_kde(xy)(xy)
-
-     fig, ax = plt.subplots()
-     plt.scatter(df.zs.values, df.z.values, c=zd, s=1)
-     plt.xlim(0,5)
-     plt.ylim(0,5)
-
-     plt.xlabel(r'$z_{\rm s}$', fontsize = 14)
-     plt.ylabel('$z$', fontsize = 14)
-
-     plt.xticks(fontsize = 12)
-     plt.yticks(fontsize = 12)
-
-     if save==True:
-         plt.savefig(f'{sample}_scatter.pdf', dpi = 300, bbox_inches='tight')
-
-     plt.show()
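As a usage illustration for the retyped plot_photoz, here is a sketch with synthetic dataframes; the zwerr and VISmag columns mirror what notebooks/NMAD.py passes in, but the values are random stand-ins. Note that the ratio panel indexes ydata_dict["1"] and ["2"], so exactly three dataframes are expected, and metric="nmad" is used because sigma68 is not imported in this module:

import numpy as np
import pandas as pd
from temps.plots import plot_photoz

rng = np.random.default_rng(0)
# Three random stand-ins for the three methods compared in the notebook.
df_list = [
    pd.DataFrame({
        "VISmag": rng.uniform(20, 25, 1000),  # VIS magnitude per galaxy
        "zwerr": rng.normal(0, 0.05, 1000),   # per-galaxy photo-z error
    })
    for _ in range(3)
]

plot_photoz(
    df_list,
    nbins=8,
    xvariable="VISmag",
    metric="nmad",
    label_list=["zs", "zs+L15", "TEMPS"],
    save=False,
)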
temps/temps.py CHANGED
@@ -6,38 +6,63 @@ from torch.optim import lr_scheduler
6
  from loguru import logger
7
  import pandas as pd
8
  from scipy.stats import norm
9
-
10
- from tqdm import tqdm # Import tqdm for progress bars
 
11
 
12
  from temps.utils import maximum_mean_discrepancy
13
 
14
 
 
15
  class TempsModule:
16
- """Class for managing temperature-related models and training."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- def __init__(
19
  self,
20
- model_f,
21
- model_z,
22
- batch_size=100,
23
- rejection_param=1,
24
- da=True,
25
- verbose=False,
26
- ):
27
- self.model_z = model_z
28
- self.model_f = model_f
29
- self.da = da
30
- self.verbose = verbose
31
- self.ngaussians = model_z.ngaussians
32
-
33
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34
- self.batch_size = batch_size
35
- self.rejection_parameter = rejection_param
36
 
37
- def _get_dataloaders(
38
- self, input_data, target_data, input_data_da=None, val_fraction=0.1
39
- ):
40
- """Create training and validation dataloaders."""
41
  input_data = torch.Tensor(input_data)
42
  target_data = torch.Tensor(target_data)
43
  input_data_da = (
@@ -47,10 +72,17 @@ class TempsModule:
47
  )
48
 
49
  dataset = TensorDataset(input_data, input_data_da, target_data)
 
 
 
 
 
 
50
  train_dataset, val_dataset = torch.utils.data.random_split(
51
  dataset,
52
- [int(len(dataset) * (1 - val_fraction)), int(len(dataset) * val_fraction)+1],
53
  )
 
54
  loader_train = DataLoader(
55
  train_dataset, batch_size=self.batch_size, shuffle=True
56
  )
@@ -58,8 +90,24 @@ class TempsModule:
58
 
59
  return loader_train, loader_val
60
 
61
- def _loss_function(self, mean, std, logmix, true):
62
- """Compute the loss function."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  log_prob = (
64
  logmix - 0.5 * (mean - true[:, None]).pow(2) / std.pow(2) - torch.log(std)
65
  )
@@ -67,28 +115,55 @@ class TempsModule:
67
  loss = -log_prob.mean()
68
  return loss
69
 
70
- def _loss_function_da(self, f1, f2):
71
- """Compute the KL divergence loss for domain adaptation."""
 
 
 
 
 
 
 
 
 
72
  kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=True)
73
  loss = kl_loss(f1, f2)
74
  return torch.log(loss)
75
 
76
- def _to_numpy(self, x):
77
- """Convert a tensor to a NumPy array."""
 
 
 
 
 
 
 
78
  return x.detach().cpu().numpy()
79
 
80
  def train(
81
  self,
82
- input_data,
83
- input_data_da,
84
- target_data,
85
- nepochs=10,
86
- step_size=100,
87
- val_fraction=0.1,
88
- lr=1e-3,
89
- weight_decay=0,
90
- ):
91
- """Train the models using provided data."""
 
 
 
 
 
 
 
 
 
 
 
92
  self.model_z.train()
93
  self.model_f.train()
94
 
@@ -157,8 +232,11 @@ class TempsModule:
157
  f"Epoch {epoch + 1}: Training Loss: {np.mean(_loss_train):.4f}, Validation Loss: {np.mean(_loss_validation):.4f}"
158
  )
159
 
160
- def _validate(self, loader_val, target_data):
 
 
161
  """Validate the model on the validation dataset."""
 
162
  self.model_z.eval()
163
  self.model_f.eval()
164
  _loss_validation = []
@@ -180,15 +258,49 @@ class TempsModule:
180
 
181
  return _loss_validation
182
 
183
- def get_features(self, input_data):
184
- """Get features from the model."""
 
 
 
 
 
 
 
 
185
  self.model_f.eval()
186
  input_data = input_data.to(self.device)
187
  features = self.model_f(input_data)
188
  return self._to_numpy(features)
189
 
190
- def get_pz(self, input_data, return_pz=True, return_flag=True, return_odds=False):
191
- """Get the predicted z values and their uncertainties."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  logger.info("Predicting photo-z for the input galaxies...")
193
  self.model_z.eval().to(self.device)
194
  self.model_f.eval().to(self.device)
@@ -206,6 +318,7 @@ class TempsModule:
206
  + (mix_coeff * (mu - mu.mean(dim=1, keepdim=True)) ** 2).sum(dim=1)
207
  )
208
 
 
209
  mu, mix_coeff, sig = map(self._to_numpy, (mu, mix_coeff, sig))
210
 
211
  if return_pz:
@@ -214,118 +327,61 @@ class TempsModule:
214
  else:
215
  return self._to_numpy(z), self._to_numpy(zerr)
216
 
217
- def _calculate_pdf(self, z, mu, sig, mix_coeff, return_flag):
218
- """Calculate the probability density function."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
  zgrid = np.linspace(0, 5, 1000)
220
  pz = np.zeros((len(z), len(zgrid)))
221
 
222
  for ii in range(len(z)):
223
  for i in range(self.ngaussians):
224
- pz[ii] += mix_coeff[ii, i] * norm.pdf(
225
- zgrid, mu[ii, i], sig[ii, i]
226
- )
227
 
228
  if return_flag:
229
  logger.info("Calculating and returning ODDS")
230
  pz /= pz.sum(axis=1, keepdims=True)
231
  return self._calculate_odds(z, pz, zgrid)
232
- return self._to_numpy(z), pz
233
-
234
- def _calculate_odds(self, z, pz, zgrid):
235
- """Calculate odds based on the PDF."""
236
- logger.info('Calculating ODDS values')
237
- diff_matrix = np.abs(self._to_numpy(z)[:, None] - zgrid[None, :])
238
- idx_peak = np.argmax(pz, axis=1)
239
- zpeak = zgrid[idx_peak]
240
- idx_upper = np.argmin(np.abs((zpeak + 0.05)[:, None] - zgrid[None, :]), axis=1)
241
- idx_lower = np.argmin(np.abs((zpeak - 0.05)[:, None] - zgrid[None, :]), axis=1)
242
-
243
- odds = []
244
- for jj in range(len(pz)):
245
- odds.append(pz[jj,idx_lower[jj]:(idx_upper[jj]+1)].sum())
246
-
247
- odds = np.array(odds)
248
- return self._to_numpy(z), pz, odds
249
-
250
- def calculate_pit(self, input_data, target_data):
251
- logger.info('Calculating PIT values')
252
-
253
- pit_list = []
254
-
255
- self.model_f = self.model_f.eval()
256
- self.model_f = self.model_f.to(self.device)
257
- self.model_z = self.model_z.eval()
258
- self.model_z = self.model_z.to(self.device)
259
-
260
- input_data = input_data.to(self.device)
261
-
262
-
263
- features = self.model_f(input_data)
264
- mu, logsig, logmix_coeff = self.model_z(features)
265
-
266
- logsig = torch.clamp(logsig,-6,2)
267
- sig = torch.exp(logsig)
268
-
269
- mix_coeff = torch.exp(logmix_coeff)
270
-
271
- mu, mix_coeff, sig = mu.detach().cpu().numpy(), mix_coeff.detach().cpu().numpy(), sig.detach().cpu().numpy()
272
-
273
- for ii in range(len(input_data)):
274
- pit = (mix_coeff[ii] * norm.cdf(target_data[ii]*np.ones(mu[ii].shape),mu[ii], sig[ii])).sum()
275
- pit_list.append(pit)
276
-
277
-
278
- return pit_list
279
-
280
- def calculate_crps(self, input_data, target_data):
281
- logger.info('Calculating CRPS values')
282
-
283
- def measure_crps(cdf, t):
284
- zgrid = np.linspace(0,4,1000)
285
- Deltaz = zgrid[None,:] - t[:,None]
286
- DeltaZ_heaviside = np.where(Deltaz < 0,0,1)
287
- integral = (cdf-DeltaZ_heaviside)**2
288
- crps_value = integral.sum(1) / 1000
289
-
290
- return crps_value
291
-
292
-
293
- crps_list = []
294
-
295
- self.model_f = self.model_f.eval()
296
- self.model_f = self.model_f.to(self.device)
297
- self.model_z = self.model_z.eval()
298
- self.model_z = self.model_z.to(self.device)
299
-
300
- input_data = input_data.to(self.device)
301
-
302
-
303
- features = self.model_f(input_data)
304
- mu, logsig, logmix_coeff = self.model_z(features)
305
- logsig = torch.clamp(logsig,-6,2)
306
- sig = torch.exp(logsig)
307
-
308
- mix_coeff = torch.exp(logmix_coeff)
309
-
310
-
311
- mu, mix_coeff, sig = mu.detach().cpu().numpy(), mix_coeff.detach().cpu().numpy(), sig.detach().cpu().numpy()
312
-
313
- z = (mix_coeff * mu).sum(1)
314
-
315
- x = np.linspace(0, 4, 1000)
316
- pz = np.zeros(shape=(len(target_data), len(x)))
317
- for ii in range(len(input_data)):
318
- for i in range(6):
319
- pz[ii] += mix_coeff[ii,i] * norm.pdf(x, mu[ii,i], sig[ii,i])
320
-
321
- pz = pz / pz.sum(1)[:,None]
322
-
323
-
324
- cdf_z = np.cumsum(pz,1)
325
-
326
- crps_value = measure_crps(cdf_z, target_data)
327
-
328
-
329
-
330
- return crps_value
331
-
 
6
  from loguru import logger
7
  import pandas as pd
8
  from scipy.stats import norm
9
+ from dataclasses import dataclass, field
10
+ from tqdm import tqdm
11
+ from typing import Optional, Tuple, List, Union
12
 
13
  from temps.utils import maximum_mean_discrepancy
14
 
15
 
16
+ @dataclass
17
  class TempsModule:
18
+ """Attributes:
19
+ model_f (nn.Module): The feature extraction model.
20
+ model_z (nn.Module): The model for predicting z values.
21
+ batch_size (int): Size of each batch for training. Default is 100.
22
+ rejection_param (int): Parameter for rejection sampling. Default is 1.
23
+ da (bool): Flag for enabling domain adaptation. Default is True.
24
+ verbose (bool): Flag for verbose logging. Default is False.
25
+ device (torch.device): Device to run the model on (CPU or GPU).
26
+ ngaussians (int): Number of Gaussian components in the mixture model.
27
+ """
28
+
29
+ model_f: nn.Module
30
+ model_z: nn.Module
31
+ batch_size: int = 100
32
+ rejection_param: int = 1
33
+ da: bool = True
34
+ verbose: bool = False
35
+ device: torch.device = field(init=False)
36
+ ngaussians: int = field(init=False)
37
+
38
+ def __post_init__(self) -> None:
39
+ """Post-initialization for setting up additional attributes."""
40
+ self.device: torch.device = torch.device(
41
+ "cuda" if torch.cuda.is_available() else "cpu"
42
+ )
43
+ self.ngaussians: int = (
44
+ self.model_z.ngaussians
45
+ ) # Assuming ngaussians is an integer
46
 
47
+ def _get_dataloaders(
48
  self,
49
+ input_data: np.ndarray,
50
+ target_data: np.ndarray,
51
+ input_data_da: Optional[np.ndarray] = None,
52
+ val_fraction: float = 0.1,
53
+ ) -> Tuple[DataLoader, DataLoader]:
54
+ """Create training and validation dataloaders.
55
+
56
+ Args:
57
+ input_data (np.ndarray): The input features for training.
58
+ target_data (np.ndarray): The target outputs for training.
59
+ input_data_da (Optional[np.ndarray]): Input data for domain adaptation (if any).
60
+ val_fraction (float): Fraction of data to use for validation. Default is 0.1.
61
+
62
+ Returns:
63
+ Tuple[DataLoader, DataLoader]: Training and validation data loaders.
64
+ """
65
 
 
 
 
 
66
  input_data = torch.Tensor(input_data)
67
  target_data = torch.Tensor(target_data)
68
  input_data_da = (
 
72
  )
73
 
74
  dataset = TensorDataset(input_data, input_data_da, target_data)
75
+
76
+ # Calculate sizes for training and validation sets
77
+ total_size = len(dataset)
78
+ val_size = int(total_size * val_fraction)
79
+ train_size = total_size - val_size
80
+
81
  train_dataset, val_dataset = torch.utils.data.random_split(
82
  dataset,
83
+ [train_size, val_size],
84
  )
85
+
86
  loader_train = DataLoader(
87
  train_dataset, batch_size=self.batch_size, shuffle=True
88
  )
 
90
 
91
  return loader_train, loader_val
92
 
93
+ def _loss_function(
94
+ self,
95
+ mean: torch.Tensor,
96
+ std: torch.Tensor,
97
+ logmix: torch.Tensor,
98
+ true: torch.Tensor,
99
+ ) -> torch.Tensor:
100
+ """Compute the loss function for the model.
101
+
102
+ Args:
103
+ mean (torch.Tensor): Mean values predicted by the model.
104
+ std (torch.Tensor): Standard deviation values predicted by the model.
105
+ logmix (torch.Tensor): Logarithm of the mixture coefficients.
106
+ true (torch.Tensor): True target values.
107
+
108
+ Returns:
109
+ torch.Tensor: The computed loss value.
110
+ """
111
  log_prob = (
112
  logmix - 0.5 * (mean - true[:, None]).pow(2) / std.pow(2) - torch.log(std)
113
  )
 
115
  loss = -log_prob.mean()
116
  return loss
117
 
118
+ def _loss_function_da(self, f1: torch.Tensor, f2: torch.Tensor) -> torch.Tensor:
119
+ """Compute the KL divergence loss for domain adaptation.
120
+
121
+ Args:
122
+ f1 (torch.Tensor): Features from the primary domain.
123
+ f2 (torch.Tensor): Features from the domain for adaptation.
124
+
125
+ Returns:
126
+ torch.Tensor: The KL divergence loss value.
127
+ """
128
+
129
  kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=True)
130
  loss = kl_loss(f1, f2)
131
  return torch.log(loss)
132
 
133
+ def _to_numpy(self, x: torch.Tensor) -> np.ndarray:
134
+ """Convert a tensor to a NumPy array.
135
+
136
+ Args:
137
+ x (torch.Tensor): The input tensor to convert.
138
+
139
+ Returns:
140
+ np.ndarray: The converted NumPy array.
141
+ """
142
  return x.detach().cpu().numpy()
143
 
144
  def train(
145
  self,
146
+ input_data: np.ndarray,
147
+ input_data_da: np.ndarray,
148
+ target_data: np.ndarray,
149
+ nepochs: int = 10,
150
+ step_size: int = 100,
151
+ val_fraction: float = 0.1,
152
+ lr: float = 1e-3,
153
+ weight_decay: float = 0,
154
+ ) -> None:
155
+ """Train the models using provided data.
156
+
157
+ Args:
158
+ input_data (np.ndarray): The input features for training.
159
+ input_data_da (np.ndarray): Input data for domain adaptation.
160
+ target_data (np.ndarray): The target outputs for training.
161
+ nepochs (int): Number of training epochs. Default is 10.
162
+ step_size (int): Step size for learning rate scheduling. Default is 100.
163
+ val_fraction (float): Fraction of data to use for validation. Default is 0.1.
164
+ lr (float): Learning rate for the optimizer. Default is 1e-3.
165
+ weight_decay (float): Weight decay for regularization. Default is 0.
166
+ """
167
  self.model_z.train()
168
  self.model_f.train()
169
 
 
232
  f"Epoch {epoch + 1}: Training Loss: {np.mean(_loss_train):.4f}, Validation Loss: {np.mean(_loss_validation):.4f}"
233
  )
234
 
235
+ def _validate(
236
+ self, loader_val: DataLoader, target_data: torch.Tensor
237
+ ) -> List[float]:
238
  """Validate the model on the validation dataset."""
239
+
240
  self.model_z.eval()
241
  self.model_f.eval()
242
  _loss_validation = []
 
258
 
259
  return _loss_validation
260
 
261
+ def get_features(self, input_data: torch.Tensor) -> np.ndarray:
262
+ """Extract features from the model for the given input data.
263
+
264
+ Args:
265
+ input_data (torch.Tensor): Input tensor containing the data for which features are to be extracted.
266
+
267
+ Returns:
268
+ np.ndarray: Numpy array of extracted features from the model.
269
+ """
270
+
271
  self.model_f.eval()
272
  input_data = input_data.to(self.device)
273
  features = self.model_f(input_data)
274
  return self._to_numpy(features)
275
 
276
+ def get_pz(
+ self,
+ input_data: torch.Tensor,
+ return_pz: bool = True,
+ return_flag: bool = True,
+ return_odds: bool = False,
+ ) -> Union[
+ Tuple[np.ndarray, np.ndarray], # z and zerr, or z and pz
+ Tuple[np.ndarray, np.ndarray, np.ndarray], # z, pz, and odds
+ ]:
+ """Get the predicted redshift (z) values and their uncertainties from the model.
+
+ This function predicts the photo-z for the input galaxies, computes the mean and standard
+ deviation of the predicted redshifts, and optionally calculates the probability density function (PDF).
+
+ Args:
+ input_data (torch.Tensor): Input tensor containing galaxy data for which to predict redshifts.
+ return_pz (bool, optional): Whether to return the probability density function. Defaults to True.
+ return_flag (bool, optional): Whether to also compute the ODDS quality flag. Defaults to True.
+ return_odds (bool, optional): Whether to return the ODDS values. Defaults to False.
+
+ Returns:
+ Union[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, np.ndarray]]:
+ - If return_pz is True, returns the redshifts and the PDF, plus the ODDS when requested.
+ - If return_pz is False, returns a tuple with the predicted redshifts and their uncertainties.
+ """
+
  logger.info("Predicting photo-z for the input galaxies...")
  self.model_z.eval().to(self.device)
  self.model_f.eval().to(self.device)
 
  + (mix_coeff * (mu - mu.mean(dim=1, keepdim=True)) ** 2).sum(dim=1)
  )
 
+ z = self._to_numpy(z)
  mu, mix_coeff, sig = map(self._to_numpy, (mu, mix_coeff, sig))
 
  if return_pz:
 
  else:
+ # z is already a NumPy array here; converting it again would fail
+ return z, self._to_numpy(zerr)
 
+ def _calculate_pdf(
+ self,
+ z: np.ndarray,
+ mu: np.ndarray,
+ sig: np.ndarray,
+ mix_coeff: np.ndarray,
+ return_flag: bool,
+ ) -> Union[
+ Tuple[np.ndarray, np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]
+ ]:
+ """Calculate the probability density function (PDF) of the predicted redshifts.
+
+ Args:
+ z (np.ndarray): Predicted redshift values.
+ mu (np.ndarray): Mean values of the Gaussian components.
+ sig (np.ndarray): Standard deviations of the Gaussian components.
+ mix_coeff (np.ndarray): Mixture coefficients of the Gaussian components.
+ return_flag (bool): Whether to calculate and return the ODDS.
+
+ Returns:
+ Union[Tuple[np.ndarray, np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]:
+ - If return_flag is True, returns the redshift values, PDF, and ODDS.
+ - If return_flag is False, returns the redshift values and PDF.
+ """
+
  zgrid = np.linspace(0, 5, 1000)
  pz = np.zeros((len(z), len(zgrid)))
 
  for ii in range(len(z)):
  for i in range(self.ngaussians):
+ pz[ii] += mix_coeff[ii, i] * norm.pdf(zgrid, mu[ii, i], sig[ii, i])
 
  if return_flag:
  logger.info("Calculating and returning ODDS")
  pz /= pz.sum(axis=1, keepdims=True)
  return self._calculate_odds(z, pz, zgrid)
+ return z, pz
+
+ def _calculate_odds(
+ self, z: np.ndarray, pz: np.ndarray, zgrid: np.ndarray
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """Calculate the ODDS for the estimated redshifts from the cumulative distribution.
+
+ Args:
+ z (np.ndarray): Predicted redshift values.
+ pz (np.ndarray): Probability density function values.
+ zgrid (np.ndarray): Grid of redshift values for evaluation.
+
+ Returns:
+ Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing the predicted redshift values,
+ the PDF values, and the calculated ODDS.
+ """
+
+ cumulative = np.cumsum(pz, axis=1)
+ odds = np.array(
+ [np.max(np.abs(cumulative[i] - 0.68)) for i in range(cumulative.shape[0])]
+ )
+ return z, pz, odds
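
To make the PDF and ODDS computation above concrete, here is a self-contained sketch for a single galaxy; the three-component mixture parameters are invented for illustration and are not real model output:

import numpy as np
from scipy.stats import norm

mu = np.array([0.8, 1.1, 1.15])        # component means (illustrative)
sig = np.array([0.05, 0.08, 0.2])      # component standard deviations
mix_coeff = np.array([0.2, 0.7, 0.1])  # mixture weights, summing to one

zgrid = np.linspace(0, 5, 1000)
pz = sum(w * norm.pdf(zgrid, m, s) for w, m, s in zip(mix_coeff, mu, sig))
pz /= pz.sum()                         # grid normalization, as in _calculate_pdf

cumulative = np.cumsum(pz)
odds = np.max(np.abs(cumulative - 0.68))  # quality flag, as in _calculate_odds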
 
temps/temps_arch.py CHANGED
@@ -1,10 +1,24 @@
  import torch
- from torch import nn, optim
  import torch.nn.functional as F
 
 
  class EncoderPhotometry(nn.Module):
- def __init__(self, input_dim=6, dropout_prob=0):
  super(EncoderPhotometry, self).__init__()
 
  self.features = nn.Sequential(
@@ -23,14 +37,39 @@ class EncoderPhotometry(nn.Module):
  nn.Linear(20, 10),
  )
 
- def forward(self, x):
  f = self.features(x)
  f = F.log_softmax(f, dim=1)
  return f
 
 
  class MeasureZ(nn.Module):
- def __init__(self, num_gauss=10, dropout_prob=0):
  super(MeasureZ, self).__init__()
 
  self.ngaussians = num_gauss
@@ -55,11 +94,25 @@ class MeasureZ(nn.Module):
  nn.Linear(20, num_gauss),
  )
 
- def forward(self, f):
  mu = self.measure_mu(f)
  sigma = self.measure_sigma(f)
  logmix_coeff = self.measure_coeffs(f)
 
- logmix_coeff = logmix_coeff - torch.logsumexp(logmix_coeff, 1)[:, None]
 
  return mu, sigma, logmix_coeff
 
  import torch
+ from torch import nn
  import torch.nn.functional as F
 
 
  class EncoderPhotometry(nn.Module):
+ """Encoder for photometric data.
+
+ This neural network encodes photometric features into a lower-dimensional representation.
+
+ Attributes:
+ features (nn.Sequential): A sequential container of layers used for encoding.
+ """
+
+ def __init__(self, input_dim: int = 6, dropout_prob: float = 0) -> None:
+ """Initializes the EncoderPhotometry module.
+
+ Args:
+ input_dim (int): Number of input features (default is 6).
+ dropout_prob (float): Probability of dropout (default is 0).
+ """
  super(EncoderPhotometry, self).__init__()
 
  self.features = nn.Sequential(
 
  nn.Linear(20, 10),
  )
 
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward pass through the encoder.
+
+ Args:
+ x (torch.Tensor): Input tensor of shape (batch_size, input_dim).
+
+ Returns:
+ torch.Tensor: Log softmax output of shape (batch_size, 10).
+ """
  f = self.features(x)
  f = F.log_softmax(f, dim=1)
  return f
 
 
  class MeasureZ(nn.Module):
+ """Model to measure redshift parameters.
+
+ This model estimates the parameters of a mixture of Gaussians used for measuring redshift.
+
+ Attributes:
+ ngaussians (int): Number of Gaussian components in the mixture.
+ measure_mu (nn.Sequential): Sequential model to measure the mean (mu).
+ measure_coeffs (nn.Sequential): Sequential model to measure the mixing coefficients.
+ measure_sigma (nn.Sequential): Sequential model to measure the log standard deviation (log sigma).
+ """
+
+ def __init__(self, num_gauss: int = 10, dropout_prob: float = 0) -> None:
+ """Initializes the MeasureZ module.
+
+ Args:
+ num_gauss (int): Number of Gaussian components (default is 10).
+ dropout_prob (float): Probability of dropout (default is 0).
+ """
  super(MeasureZ, self).__init__()
 
  self.ngaussians = num_gauss
 
  nn.Linear(20, num_gauss),
  )
 
+ def forward(
+ self, f: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Forward pass to measure redshift parameters.
+
+ Args:
+ f (torch.Tensor): Input tensor of shape (batch_size, 10).
+
+ Returns:
+ tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing:
+ - mu (torch.Tensor): Mean parameters of shape (batch_size, num_gauss).
+ - sigma (torch.Tensor): Log standard deviations of shape (batch_size, num_gauss); callers clamp and exponentiate them.
+ - logmix_coeff (torch.Tensor): Log mixing coefficients of shape (batch_size, num_gauss).
+ """
  mu = self.measure_mu(f)
  sigma = self.measure_sigma(f)
  logmix_coeff = self.measure_coeffs(f)
 
+ # Normalize logmix_coeff to get valid mixture coefficients
+ logmix_coeff = logmix_coeff - torch.logsumexp(logmix_coeff, dim=1, keepdim=True)
 
  return mu, sigma, logmix_coeff
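
As a quick smoke test of the refactored architecture, the two modules compose as follows; the batch of random photometry is purely illustrative:

import torch
from temps.temps_arch import EncoderPhotometry, MeasureZ

encoder = EncoderPhotometry(input_dim=6)
measure_z = MeasureZ(num_gauss=10)

x = torch.randn(32, 6)                  # a batch of six-band photometry (fake data)
f = encoder(x)                          # log-softmax features, shape (32, 10)
mu, sigma, logmix_coeff = measure_z(f)  # mixture parameters, each (32, 10)

# After the logsumexp normalization, each row of coefficients sums to one
coeff_sums = torch.exp(logmix_coeff).sum(dim=1)
assert torch.allclose(coeff_sums, torch.ones(32), atol=1e-5)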
temps/utils.py CHANGED
@@ -4,21 +4,31 @@ import matplotlib.pyplot as plt
  from scipy import stats
  import torch
  from loguru import logger
 
 
 
- def caluclate_eta(df):
- return len(df[np.abs(df.zwerr)>0.15])/len(df) *100
-
 
- def nmad(data):
  return 1.4826 * np.median(np.abs(data - np.median(data)))
 
 
- def sigma68(data):
  return 0.5 * (pd.Series(data).quantile(q=0.84) - pd.Series(data).quantile(q=0.16))
 
 
- def maximum_mean_discrepancy(x, y, kernel_type="rbf", kernel_mul=2.0, kernel_num=5):
  """
  Compute the Maximum Mean Discrepancy (MMD) between two sets of samples.
 
@@ -40,7 +50,13 @@ def maximum_mean_discrepancy(x, y, kernel_type="rbf", kernel_mul=2.0, kernel_num
  return mmd_loss
 
 
- def compute_kernel(x, y, kernel_type="rbf", kernel_mul=2.0, kernel_num=5):
  """
  Compute the kernel matrix based on the chosen kernel type.
 
@@ -61,7 +77,7 @@ def compute_kernel(x, y, kernel_type="rbf", kernel_mul=2.0, kernel_num=5):
  x = x.unsqueeze(1).expand(x_size, y_size, dim)
  y = y.unsqueeze(0).expand(x_size, y_size, dim)
 
- kernel_input = (x - y).pow(2).mean(2)
 
  if kernel_type == "linear":
  kernel_matrix = kernel_input
@@ -80,46 +96,62 @@ def compute_kernel(x, y, kernel_type="rbf", kernel_mul=2.0, kernel_num=5):
 
 
  def select_cut(
- df, completenss_lim=None, nmad_lim=None, outliers_lim=None, return_df=False
- ):
 
- if (completenss_lim is None) & (nmad_lim is None) & (outliers_lim is None):
- raise (ValueError("Select at least one cut"))
  elif sum(c is not None for c in [completenss_lim, nmad_lim, outliers_lim]) > 1:
  raise ValueError("Select only one cut at a time")
 
- else:
- bin_edges = stats.mstats.mquantiles(df.odds, np.arange(0, 1.01, 0.1))
- scatter, eta, cmptnss, nobj = [], [], [], []
-
- for k in range(len(bin_edges) - 1):
- edge_min = bin_edges[k]
- edge_max = bin_edges[k + 1]
-
- df_bin = df[(df.odds > edge_min)]
-
- cmptnss.append(np.round(len(df_bin) / len(df), 2) * 100)
- scatter.append(nmad(df_bin.zwerr))
- eta.append(len(df_bin[np.abs(df_bin.zwerr) > 0.15]) / len(df_bin) * 100)
- nobj.append(len(df_bin))
-
- dfcuts = pd.DataFrame(
- data=np.c_[
- np.round(bin_edges[:-1], 5),
- np.round(nobj, 1),
- np.round(cmptnss, 1),
- np.round(scatter, 3),
- np.round(eta, 2),
- ],
- columns=["flagcut", "Nobj", "completeness", "nmad", "eta"],
- )
 
  if completenss_lim is not None:
  logger.info("Selecting cut based on completeness")
  selected_cut = dfcuts[dfcuts["completeness"] <= completenss_lim].iloc[0]
 
  elif nmad_lim is not None:
- logger.info("Selecting cut based on nmad")
  selected_cut = dfcuts[dfcuts["nmad"] <= nmad_lim].iloc[0]
 
  elif outliers_lim is not None:
@@ -127,11 +159,104 @@ def select_cut(
  selected_cut = dfcuts[dfcuts["eta"] <= outliers_lim].iloc[0]
 
  logger.info(
- f"This cut provides completeness of {selected_cut['completeness']}, nmad={selected_cut['nmad']} and eta={selected_cut['eta']}"
  )
 
  df_cut = df[(df.odds > selected_cut["flagcut"])]
- if return_df == True:
  return df_cut, selected_cut["flagcut"], dfcuts
  else:
  return selected_cut["flagcut"], dfcuts
 
  from scipy import stats
  import torch
  from loguru import logger
+ from typing import List, Optional, Tuple, Union
+ from scipy.stats import norm
+ from torch import nn, Tensor
 
 
+ def calculate_eta(df: pd.DataFrame) -> float:
+ """Calculate the percentage of outliers in the DataFrame, based on the zwerr column."""
+ return len(df[np.abs(df.zwerr) > 0.15]) / len(df) * 100
+
 
+ def nmad(data: Union[np.ndarray, pd.Series]) -> float:
+ """Calculate the normalized median absolute deviation (NMAD) of the data."""
  return 1.4826 * np.median(np.abs(data - np.median(data)))
 
 
+ def sigma68(data: Union[np.ndarray, pd.Series]) -> float:
+ """Calculate the sigma68 metric, a robust measure of dispersion."""
  return 0.5 * (pd.Series(data).quantile(q=0.84) - pd.Series(data).quantile(q=0.16))
 
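Both estimators target the width of a roughly Gaussian error distribution and should agree on clean data; a quick check with synthetic residuals (the 0.02 scatter is arbitrary):

import numpy as np
from temps.utils import nmad, sigma68

rng = np.random.default_rng(0)
residuals = rng.normal(0, 0.02, size=10_000)  # synthetic photo-z errors

print(nmad(residuals))     # both should recover ~0.02
print(sigma68(residuals))
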
+ def maximum_mean_discrepancy(
+ x: torch.Tensor,
+ y: torch.Tensor,
+ kernel_type: str = "rbf",
+ kernel_mul: float = 2.0,
+ kernel_num: int = 5,
+ ) -> torch.Tensor:
  """
  Compute the Maximum Mean Discrepancy (MMD) between two sets of samples.
 
  return mmd_loss
 
 
+ def compute_kernel(
+ x: torch.Tensor,
+ y: torch.Tensor,
+ kernel_type: str = "rbf",
+ kernel_mul: float = 2.0,
+ kernel_num: int = 5,
+ ) -> torch.Tensor:
  """
  Compute the kernel matrix based on the chosen kernel type.
 
  x = x.unsqueeze(1).expand(x_size, y_size, dim)
  y = y.unsqueeze(0).expand(x_size, y_size, dim)
 
+ kernel_input = (x - y).pow(2).mean(2)
 
  if kernel_type == "linear":
  kernel_matrix = kernel_input
 
 
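The two functions above split the MMD computation from the kernel construction. For intuition, a minimal single-bandwidth RBF variant is sketched below; the bandwidth choice (median heuristic) and the shapes are assumptions, and the library version generalizes this to kernel_num scales via kernel_mul:

import torch

def mmd_rbf_sketch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    xy = torch.cat([x, y], dim=0)
    d2 = torch.cdist(xy, xy).pow(2)   # pairwise squared distances
    bandwidth = d2.detach().median()  # median heuristic (assumption)
    k = torch.exp(-d2 / bandwidth)
    n = x.size(0)
    # Biased MMD^2 estimate: within-source + within-target - 2 * cross terms
    return k[:n, :n].mean() + k[n:, n:].mean() - 2 * k[:n, n:].mean()

source = torch.randn(64, 10)
target = torch.randn(64, 10) + 0.5    # shifted domain
print(mmd_rbf_sketch(source, target)) # larger shifts give larger MMD
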
  def select_cut(
+ df: pd.DataFrame,
+ completenss_lim: Optional[float] = None,
+ nmad_lim: Optional[float] = None,
+ outliers_lim: Optional[float] = None,
+ return_df: bool = False,
+ ) -> Union[Tuple[pd.DataFrame, float, pd.DataFrame], Tuple[float, pd.DataFrame]]:
+ """
+ Selects a cut based on one of the provided limits (completeness, NMAD, or outliers).
+
+ Args:
+ - df: DataFrame containing the data
+ - completenss_lim: float, optional limit on completeness
+ - nmad_lim: float, optional limit on NMAD
+ - outliers_lim: float, optional limit on outliers (eta)
+ - return_df: bool, whether to return the filtered DataFrame
+
+ Returns:
+ - If return_df is False, returns the cut value and a DataFrame of cuts.
+ If return_df is True, returns the filtered DataFrame, the cut value, and the cuts DataFrame.
+ """
 
+ if (completenss_lim is None) and (nmad_lim is None) and (outliers_lim is None):
+ raise ValueError("Select at least one cut")
  elif sum(c is not None for c in [completenss_lim, nmad_lim, outliers_lim]) > 1:
  raise ValueError("Select only one cut at a time")
 
+ bin_edges = stats.mstats.mquantiles(df.odds, np.arange(0, 1.01, 0.1))
+ scatter, eta, cmptnss, nobj = [], [], [], []
+
+ for k in range(len(bin_edges) - 1):
+ edge_min = bin_edges[k]
+ edge_max = bin_edges[k + 1]
+
+ df_bin = df[(df.odds > edge_min)]
+ cmptnss.append(np.round(len(df_bin) / len(df), 2) * 100)
+ scatter.append(nmad(df_bin.zwerr))
+ eta.append(len(df_bin[np.abs(df_bin.zwerr) > 0.15]) / len(df_bin) * 100)
+ nobj.append(len(df_bin))
+
+ dfcuts = pd.DataFrame(
+ data=np.c_[
+ np.round(bin_edges[:-1], 5),
+ np.round(nobj, 1),
+ np.round(cmptnss, 1),
+ np.round(scatter, 3),
+ np.round(eta, 2),
+ ],
+ columns=["flagcut", "Nobj", "completeness", "nmad", "eta"],
+ )
 
  if completenss_lim is not None:
  logger.info("Selecting cut based on completeness")
  selected_cut = dfcuts[dfcuts["completeness"] <= completenss_lim].iloc[0]
 
  elif nmad_lim is not None:
+ logger.info("Selecting cut based on NMAD")
  selected_cut = dfcuts[dfcuts["nmad"] <= nmad_lim].iloc[0]
 
  elif outliers_lim is not None:
 
  selected_cut = dfcuts[dfcuts["eta"] <= outliers_lim].iloc[0]
 
  logger.info(
+ f"This cut provides completeness of {selected_cut['completeness']}, "
+ f"nmad={selected_cut['nmad']} and eta={selected_cut['eta']}"
  )
 
  df_cut = df[(df.odds > selected_cut["flagcut"])]
+
+ if return_df:
  return df_cut, selected_cut["flagcut"], dfcuts
  else:
  return selected_cut["flagcut"], dfcuts
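A usage sketch for the refactored select_cut; the odds and zwerr columns are the ones the function reads, and the synthetic values below exist only to make the call runnable:

import numpy as np
import pandas as pd
from temps.utils import select_cut

rng = np.random.default_rng(1)
df = pd.DataFrame({
    "odds": rng.uniform(0, 1, 5000),
    "zwerr": rng.normal(0, 0.05, 5000),
})

# Loosest odds cut whose outlier fraction stays below 1%
df_cut, flagcut, dfcuts = select_cut(df, outliers_lim=1.0, return_df=True)
print(flagcut, len(df_cut))
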
+
+ def calculate_pit(model_f: nn.Module,
+ model_z: nn.Module,
+ input_data: Tensor,
+ target_data: Tensor,
+ device: str = "cpu",
+ ) -> List[float]:
+ """Calculate the probability integral transform (PIT) of each target redshift under its predicted PDF."""
+
+ logger.info('Calculating PIT values')
+
+ pit_list = []
+
+ # These helpers live at module level, so the device is an explicit argument
+ # rather than the (undefined) self.device of the original draft.
+ model_f = model_f.eval().to(device)
+ model_z = model_z.eval().to(device)
+
+ input_data = input_data.to(device)
+
+ features = model_f(input_data)
+ mu, logsig, logmix_coeff = model_z(features)
+
+ logsig = torch.clamp(logsig, -6, 2)
+ sig = torch.exp(logsig)
+
+ mix_coeff = torch.exp(logmix_coeff)
+
+ mu, mix_coeff, sig = mu.detach().cpu().numpy(), mix_coeff.detach().cpu().numpy(), sig.detach().cpu().numpy()
+ target_data = target_data.detach().cpu().numpy()
+
+ for ii in range(len(input_data)):
+ pit = (mix_coeff[ii] * norm.cdf(target_data[ii] * np.ones(mu[ii].shape), mu[ii], sig[ii])).sum()
+ pit_list.append(pit)
+
+ return pit_list
+
+ def calculate_crps(model_f: nn.Module,
+ model_z: nn.Module,
+ input_data: Tensor,
+ target_data: Tensor,
+ device: str = "cpu",
+ ) -> np.ndarray:
+ """Calculate the continuous ranked probability score (CRPS) of each target redshift."""
+
+ logger.info('Calculating CRPS values')
+
+ def measure_crps(cdf, t):
+ zgrid = np.linspace(0, 4, 1000)
+ deltaz = zgrid[None, :] - t[:, None]
+ deltaz_heaviside = np.where(deltaz < 0, 0, 1)
+ integral = (cdf - deltaz_heaviside) ** 2
+ # Approximate the integral with the grid spacing rather than the number of points
+ crps_value = integral.sum(1) * (zgrid[1] - zgrid[0])
+
+ return crps_value
+
+ model_f = model_f.eval().to(device)
+ model_z = model_z.eval().to(device)
+
+ input_data = input_data.to(device)
+
+ features = model_f(input_data)
+ mu, logsig, logmix_coeff = model_z(features)
+ logsig = torch.clamp(logsig, -6, 2)
+ sig = torch.exp(logsig)
+
+ mix_coeff = torch.exp(logmix_coeff)
+
+ mu, mix_coeff, sig = mu.detach().cpu().numpy(), mix_coeff.detach().cpu().numpy(), sig.detach().cpu().numpy()
+ target_data = target_data.detach().cpu().numpy()
+
+ x = np.linspace(0, 4, 1000)
+ pz = np.zeros(shape=(len(target_data), len(x)))
+ for ii in range(len(input_data)):
+ # Sum over every mixture component instead of a hard-coded six
+ for i in range(mu.shape[1]):
+ pz[ii] += mix_coeff[ii, i] * norm.pdf(x, mu[ii, i], sig[ii, i])
+
+ pz = pz / pz.sum(1)[:, None]
+
+ cdf_z = np.cumsum(pz, 1)
+
+ crps_value = measure_crps(cdf_z, target_data)
+
+ return crps_value
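
Finally, a hedged sketch of how the two diagnostics would be called after training; the untrained models and random inputs below are placeholders, and with a trained, well-calibrated model the PIT values should be roughly uniform on [0, 1] while lower CRPS is better:

import torch
from temps.temps_arch import EncoderPhotometry, MeasureZ
from temps.utils import calculate_pit, calculate_crps

model_f = EncoderPhotometry(input_dim=6)
model_z = MeasureZ(num_gauss=10)
# in practice, load trained weights into both modules here

colors = torch.randn(100, 6)  # placeholder photometry
zspec = torch.rand(100) * 2   # placeholder spectroscopic redshifts

pit = calculate_pit(model_f, model_z, colors, zspec)
crps = calculate_crps(model_f, model_z, colors, zspec)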