Merge pull request #3 from lauracabayol/improve_code
Files changed:
- notebooks/NMAD.py +3 -1
- pyproject.toml +1 -0
- temps/archive.py +230 -158
- temps/plots.py +260 -225
- temps/temps.py +207 -151
- temps/temps_arch.py +59 -6
- temps/utils.py +165 -40
notebooks/NMAD.py
CHANGED

The notebook drops a duplicated `filename_valid` assignment and adds a blank line and a trailing `# %%` cell marker:

```diff
@@ -61,6 +61,7 @@ eval_methods=True
 # ### LOAD DATA
 
 # %%
+
 #define here the directory containing the photometric catalogues
 parent_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')
 modules_dir = Path('../data/models/')
@@ -68,7 +69,6 @@ filename_calib = 'euclid_cosmos_DC2_S1_v2.1_calib_clean.fits'
 filename_valid = 'euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
 
 # %%
-filename_valid='euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
 path_file = parent_dir / filename_valid  # Creating the path to the file
 hdu_list = fits.open(path_file)
 cat = Table(hdu_list[1].data).to_pandas()
@@ -158,3 +158,5 @@ plot_photoz(df_list,
             save=False,
             samp='L15'
             )
+
+# %%
```
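For context, the cell above opens a FITS catalogue with astropy and converts its first table extension to a pandas DataFrame. A minimal self-contained sketch of the same pattern (the file path here is a placeholder, not the catalogue used in the notebook):

```python
from pathlib import Path

from astropy.io import fits
from astropy.table import Table

# Placeholder path; the notebook points at the Euclid DC2 catalogue instead.
path_file = Path("catalogue.fits")

# HDU 0 is the primary header; tabular data conventionally sits in HDU 1.
with fits.open(path_file) as hdu_list:
    cat = Table(hdu_list[1].data).to_pandas()

print(cat.columns)
```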
pyproject.toml
CHANGED

The only change is a new `typing` entry in the dependency list:

```diff
@@ -30,6 +30,7 @@ dependencies = [
     "gradio",
     "jupytext",
     "mkdocs",
+    "typing"
 ]
 
 classifiers = [
```
temps/archive.py
CHANGED

The module is rewritten as a typed `@dataclass`: configuration that previously went through `__init__` becomes declared fields, and the loading/cleaning pipeline runs in `__post_init__`. New version, reconstructed from the diff; regions the viewer left collapsed are marked `# …`, and one bug visible in the diff (`_select_L15_sample` declared as a `@staticmethod` with a `self` parameter) is fixed with a comment:

```python
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Tuple, Union

import numpy as np
import pandas as pd
from astropy.io import fits
from astropy.table import Table
from scipy.spatial import KDTree
from loguru import logger
from matplotlib import pyplot as plt
from matplotlib import rcParams

# Set matplotlib configuration
rcParams["mathtext.fontset"] = "stix"
rcParams["font.family"] = "STIXGeneral"


@dataclass
class Archive:
    path_calib: Path
    path_valid: Optional[Path] = None
    drop_stars: bool = True
    clean_photometry: bool = True
    convert_colors: bool = True
    extinction_corr: bool = True
    only_zspec: bool = True
    columns_photometry: List[str] = field(default_factory=lambda: [
        "FLUX_G_2",
        "FLUX_R_2",
        "FLUX_I_2",
        "FLUX_Z_2",
        "FLUX_Y_2",
        "FLUX_J_2",
        "FLUX_H_2",
    ])
    columns_ebv: List[str] = field(default_factory=lambda: [
        "EB_V_corr_FLUX_G",
        "EB_V_corr_FLUX_R",
        "EB_V_corr_FLUX_I",
        "EB_V_corr_FLUX_Z",
        "EB_V_corr_FLUX_Y",
        "EB_V_corr_FLUX_J",
        "EB_V_corr_FLUX_H",
    ])
    photoz_name: str = "photo_z_L15"
    specz_name: str = "z_spec_S15"
    target_test: str = "specz"
    flags_kept: List[float] = field(default_factory=lambda: [3, 3.1, 3.4, 3.5, 4])

    def __post_init__(self):
        logger.info("Starting archive")

        # Load data based on the file format
        if self.path_calib.suffix == ".fits":
            with fits.open(self.path_calib) as hdu_list:
                self.cat = Table(hdu_list[1].data).to_pandas()
            if self.path_valid is not None:
                with fits.open(self.path_valid) as hdu_list:
                    self.cat_test = Table(hdu_list[1].data).to_pandas()

        elif self.path_calib.suffix == ".csv":
            self.cat = pd.read_csv(self.path_calib)
            if self.path_valid is not None:
                self.cat_test = pd.read_csv(self.path_valid)
        else:
            raise ValueError("Unsupported file format. Please provide a .fits or .csv file.")

        self.cat = self.cat.rename(
            columns={f"{self.specz_name}": "specz", f"{self.photoz_name}": "photo_z"}
        )
        self.cat_test = self.cat_test.rename(
            columns={f"{self.specz_name}": "specz", f"{self.photoz_name}": "photo_z"}
        )

        self.cat = self.cat[(self.cat["specz"] > 0) | (self.cat["photo_z"] > 0)]

        # Apply operations based on the initialized parameters
        if self.drop_stars:
            logger.info("Dropping stars...")
            self.cat = self.cat[self.cat.mu_class_L07 == 1]
            self.cat_test = self.cat_test[self.cat_test.mu_class_L07 == 1]

        if self.clean_photometry:
            logger.info("Cleaning photometry...")
            self.cat = self._clean_photometry(catalogue=self.cat)
            self.cat_test = self._clean_photometry(catalogue=self.cat_test)

        self.cat = self._set_combined_target(self.cat)
        self.cat_test = self._set_combined_target(self.cat_test)

        # Apply magnitude and redshift cuts
        self.cat = self.cat[self.cat.MAG_VIS < 25]
        self.cat_test = self.cat_test[self.cat_test.MAG_VIS < 25]

        self.cat = self.cat[self.cat.target_z < 5]
        self.cat_test = self.cat_test[self.cat_test.target_z < 5]

        self._set_training_data(
            self.cat,
            self.cat_test,
            only_zspec=self.only_zspec,
            extinction_corr=self.extinction_corr,
            convert_colors=self.convert_colors,
        )
        self._set_testing_data(
            self.cat_test,
            target=self.target_test,
            extinction_corr=self.extinction_corr,
            convert_colors=self.convert_colors,
        )

    def _extract_fluxes(self, catalogue: pd.DataFrame) -> np.ndarray:
        """Extract fluxes from the given catalogue.

        Args:
            catalogue (pd.DataFrame): The input catalogue.

        Returns:
            np.ndarray: An array of fluxes.
        """
        f = catalogue[self.columns_photometry].values
        return f

    @staticmethod
    def _to_colors(flux: np.ndarray) -> np.ndarray:
        """Convert fluxes to colors.

        Args:
            flux (np.ndarray): The input fluxes.

        Returns:
            np.ndarray: An array of colors.
        """
        color = flux[:, :-1] / flux[:, 1:]
        return color

    @staticmethod
    def _set_combined_target(catalogue: pd.DataFrame) -> pd.DataFrame:
        """Set the combined target redshift based on available data.

        Args:
            catalogue (pd.DataFrame): The input catalogue.

        Returns:
            pd.DataFrame: Updated catalogue with the combined target redshift.
        """
        catalogue["target_z"] = catalogue.apply(
            lambda row: row["specz"] if row["specz"] > 0 else row["photo_z"], axis=1
        )
        return catalogue

    @staticmethod
    def _clean_photometry(catalogue: pd.DataFrame) -> pd.DataFrame:
        """Drops all objects with FLAG_PHOT != 0.

        Args:
            catalogue (pd.DataFrame): The input catalogue.

        Returns:
            pd.DataFrame: Cleaned catalogue.
        """
        catalogue = catalogue[catalogue["FLAG_PHOT"] == 0]
        return catalogue

    def _correct_extinction(
        self, catalogue: pd.DataFrame, f: np.ndarray, return_ext_corr: bool = False
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """Corrects for extinction based on the provided catalogue.

        Args:
            catalogue (pd.DataFrame): The input catalogue.
            f (np.ndarray): The flux values to correct.
            return_ext_corr (bool): Whether to return the extinction correction values.

        Returns:
            Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: Corrected fluxes, and optionally the extinction corrections.
        """
        ext_correction = catalogue[self.columns_ebv].values
        f = f * ext_correction
        if return_ext_corr:
            return f, ext_correction
        else:
            return f

    @staticmethod
    def _select_only_zspec(
        catalogue: pd.DataFrame, cat_flag: Optional[str] = None
    ) -> pd.DataFrame:
        """Selects only galaxies with spectroscopic redshift.

        Args:
            catalogue (pd.DataFrame): The input catalogue.
            cat_flag (Optional[str]): Indicates the catalogue type ('Calib' or 'Valid').

        Returns:
            pd.DataFrame: Filtered catalogue.
        """
        if cat_flag == "Calib":
            catalogue = catalogue[catalogue.specz > 0]
        elif cat_flag == "Valid":
            catalogue = catalogue[catalogue.specz > 0]
        return catalogue

    @staticmethod
    def take_zspec_and_photoz(
        catalogue: pd.DataFrame, cat_flag: Optional[str] = None
    ) -> pd.DataFrame:
        """Selects galaxies with a valid target redshift ('Calib') or spectroscopic redshift ('Valid')."""
        if cat_flag == "Calib":
            catalogue = catalogue[catalogue.target_z > 0]
        elif cat_flag == "Valid":
            catalogue = catalogue[catalogue.specz > 0]
        return catalogue

    @staticmethod
    def exclude_only_zspec(catalogue: pd.DataFrame) -> pd.DataFrame:
        """Selects only galaxies without spectroscopic redshift.

        Args:
            catalogue (pd.DataFrame): The input catalogue.

        Returns:
            pd.DataFrame: Filtered catalogue.
        """
        catalogue = catalogue[
            (catalogue.specz < 0) & (catalogue.photo_z > 0) & (catalogue.photo_z < 4)
        ]
        return catalogue

    @staticmethod
    def _clean_zspec_sample(catalogue, flags_kept=[3, 3.1, 3.4, 3.5, 4]):
        catalogue = catalogue[catalogue.Q_f_S15.isin(flags_kept)]
        return catalogue

    @staticmethod
    def _select_L15_sample(catalogue: pd.DataFrame) -> pd.DataFrame:
        # Note: the diff declares this as `_select_L15_sample(self, catalogue)` under
        # @staticmethod, which would fail when called as `self._select_L15_sample(cat_test)`;
        # the stray `self` parameter is dropped here.
        """Selects only galaxies within a specific redshift range.

        Args:
            catalogue (pd.DataFrame): The input catalogue.

        Returns:
            pd.DataFrame: Filtered catalogue.
        """
        catalogue = catalogue[(catalogue.target_z > 0) & (catalogue.target_z < 3)]
        return catalogue

    def _set_training_data(
        self,
        catalogue: pd.DataFrame,
        catalogue_da: pd.DataFrame,
        only_zspec: bool = True,
        extinction_corr: bool = True,
        convert_colors: bool = True,
    ) -> None:
        cat_da = Archive.exclude_only_zspec(catalogue_da)
        target_z_train_DA = cat_da['photo_z'].values

        if only_zspec:
            logger.info("Selecting only galaxies with spectroscopic redshift")
            catalogue = Archive._select_only_zspec(catalogue, cat_flag='Calib')
            catalogue = Archive._clean_zspec_sample(catalogue, flags_kept=self.flags_kept)
        else:
            logger.info("Selecting galaxies with spectroscopic redshift and high-precision photo-z")
            catalogue = Archive.take_zspec_and_photoz(catalogue, cat_flag='Calib')

        self.cat_train = catalogue
        # … (flux extraction / colour conversion lines collapsed in the diff view)
        self.target_z_train = catalogue['target_z'].values

        self.VIS_mag_train = catalogue['MAG_VIS'].values

    def _set_testing_data(
        self,
        cat_test: pd.DataFrame,
        target: str = "specz",
        extinction_corr: bool = True,
        convert_colors: bool = True,
    ) -> None:
        if target == 'specz':
            cat_test = Archive._select_only_zspec(cat_test, cat_flag='Valid')
            cat_test = Archive._clean_zspec_sample(cat_test)
            self.target_z_test = cat_test['specz'].values

        elif target == 'L15':
            cat_test = self._select_L15_sample(cat_test)
            self.target_z_test = cat_test['target_z'].values

        self.cat_test = cat_test

        f = self._extract_fluxes(cat_test)

        if extinction_corr == True:
            f = self._correct_extinction(cat_test, f)

        if convert_colors == True:
            col = self._to_colors(f)
            # … (lines collapsed in the diff view)

        self.phot_test = f

        self.VIS_mag_test = cat_test['MAG_VIS'].values

    def get_training_data(self):
        return self.phot_train, self.target_z_train, self.VIS_mag_train, self.phot_train_DA, self.target_z_train_DA

    def get_testing_data(self):  # method name assumed; its `def` line is collapsed in the diff view
        return self.phot_test, self.target_z_test, self.VIS_mag_test

    def get_VIS_mag(self, catalogue):
        return catalogue[['MAG_VIS']].values
```

Removed along the way (old-side lines still visible in the diff) is an untyped helper that cross-matched the validation sample against a "gold" spectroscopic catalogue with a KD-tree:

```python
gold_sample_radec = np.c_[catalogue_gold.RIGHT_ASCENSION, catalogue_gold.DECLINATION]
valid_sample_radec = np.c_[catalogue_valid['RA'], catalogue_valid['DEC']]

kdtree = KDTree(gold_sample_radec)
distances, indices = kdtree.query(valid_sample_radec, k=1)

specz_match_gold = catalogue_gold.FINAL_SPEC_Z.values[indices]

zs = [specz_match_gold[i] if distance < max_distance_deg else -99
      for i, distance in enumerate(distances)]

catalogue_valid['z_spec_gold'] = zs
```
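A minimal usage sketch of the new dataclass interface (the catalogue paths are placeholders; whether your files carry the default Euclid-style columns — FLUX_*_2, EB_V_corr_FLUX_*, MAG_VIS, z_spec_S15, photo_z_L15 — is an assumption, and `get_testing_data` is the assumed accessor name noted above):

```python
from pathlib import Path

from temps.archive import Archive

# Hypothetical file names, matching the catalogues referenced in the notebook diff.
archive = Archive(
    path_calib=Path("euclid_cosmos_DC2_S1_v2.1_calib_clean.fits"),
    path_valid=Path("euclid_cosmos_DC2_S1_v2.1_valid_matched.fits"),
    only_zspec=True,
    target_test="specz",
)

phot_train, z_train, vis_train, phot_da, z_da = archive.get_training_data()
phot_test, z_test, vis_test = archive.get_testing_data()
```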
temps/plots.py
CHANGED

The plotting helpers gain type hints and docstrings. New version, reconstructed from the diff (`# …` marks regions the viewer left collapsed). Three small fixes are applied here: `plot_crps` declared its label list as `label` but uses `labels` in the body; `plot_pz` annotated `specz` as `float` although the body indexes it as an array; and `sigma68` is used but never imported, so it is assumed to live in `temps.utils` alongside `nmad`:

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from temps.utils import nmad, sigma68  # sigma68 assumed to live in temps.utils; the diff imports only nmad
from scipy import stats
from typing import List, Optional, Dict


def plot_photoz(
    df_list: List[pd.DataFrame],
    nbins: int,
    xvariable: str,
    metric: str,
    type_bin: str = "bin",
    label_list: Optional[List[str]] = None,
    samp: str = "zs",
    save: bool = False,
) -> None:
    """
    Plot photo-z metrics for multiple dataframes.

    Parameters:
    - df_list (List[pd.DataFrame]): List of dataframes containing data for plotting.
    - nbins (int): Number of bins for the histogram.
    - xvariable (str): Variable to plot on the x-axis.
    - metric (str): Metric to plot (e.g., 'sig68', 'bias', 'nmad', 'outliers').
    - type_bin (str, optional): Type of binning ('bin' or 'cum'). Default is 'bin'.
    - label_list (Optional[List[str]], optional): List of labels for each dataframe. Default is None.
    - samp (str, optional): Sample label for saving. Default is 'zs'.
    - save (bool, optional): If True, save the plot to a file. Default is False.

    Returns:
    None
    """
    # Plot properties
    plt.rcParams["font.family"] = "serif"
    plt.rcParams["font.size"] = 12

    # Set x-axis label based on variable
    xvariable_lab = "VIS" if xvariable == "VISmag" else r"$z_{\rm s}$"

    # Calculate bin edges
    bin_edges = stats.mstats.mquantiles(
        df_list[0][xvariable].values, np.linspace(0.05, 1, nbins)
    )
    cmap = plt.get_cmap("Dark2")

    # Create subplots
    fig, (ax1, ax2) = plt.subplots(
        2, 1, figsize=(8, 8), gridspec_kw={"height_ratios": [3, 1]}
    )
    ydata_dict: Dict[str, List[float]] = {}

    # Loop through dataframes and calculate metrics
    for i, df in enumerate(df_list):
        ydata, xlab = [], []

        label = label_list[i]
        label_lab = {
            "zs": r"$z_{\rm s}$",
            "zs+L15": r"$z_{\rm s}$+L15",
            "TEMPS": "TEMPS",
        }.get(label, label)

        for k in range(len(bin_edges) - 1):
            edge_min = bin_edges[k]
            edge_max = bin_edges[k + 1]
            mean_mag = (edge_max + edge_min) / 2

            df_plot = (
                df[(df[xvariable] > edge_min) & (df[xvariable] < edge_max)]
                if type_bin == "bin"
                else df[(df[xvariable] < edge_max)]
            )

            xlab.append(mean_mag)
            if metric == "sig68":
                ydata.append(sigma68(df_plot.zwerr))
            elif metric == "bias":
                ydata.append(np.mean(df_plot.zwerr))
            elif metric == "nmad":
                ydata.append(nmad(df_plot.zwerr))
            elif metric == "outliers":
                ydata.append(
                    len(df_plot[np.abs(df_plot.zwerr) > 0.15]) / len(df_plot) * 100
                )

        ydata_dict[f"{i}"] = ydata
        color = cmap(i)
        ax1.plot(
            xlab,
            ydata,
            marker=".",
            lw=1,
            label=label_lab,
            color=color,
            ls=["--", ":", "-"][i],
        )

    ax1.set_ylabel(f"{metric} $[\Delta z]$", fontsize=18)
    ax1.grid(False)
    ax1.legend()

    # Plot ratios
    ax2.plot(
        xlab,
        np.array(ydata_dict["1"]) / np.array(ydata_dict["0"]),
        marker=".",
        color=cmap(1),
    )
    ax2.plot(
        xlab,
        np.array(ydata_dict["2"]) / np.array(ydata_dict["0"]),
        marker=".",
        color=cmap(2),
    )
    ax2.set_ylabel(r"Method $X$ / $z_{\rm z}$", fontsize=14)
    ax2.set_xlabel(f"{xvariable_lab}", fontsize=16)
    ax2.grid(True)

    if save:
        plt.savefig(f"{metric}_{xvariable}_{samp}.pdf", dpi=300, bbox_inches="tight")
    plt.show()


def plot_pz(m: int, pz: np.ndarray, specz: np.ndarray) -> None:
    """
    Plot the Probability Density Function (PDF) for a given object and compare it with the spectroscopic redshift.

    Parameters:
    - m (int): Index of the object to plot.
    - pz (np.ndarray): Probability density function values.
    - specz (np.ndarray): Spectroscopic redshift values (indexed by m).

    Returns:
    None
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(np.linspace(0, 4, 1000), pz[m], label="PDF", color="navy")
    ax.axvline(specz[m], color="black", linestyle="--", label=r"$z_{\rm s}$")
    ax.set_xlabel(r"$z$", fontsize=18)
    ax.set_ylabel("Probability Density", fontsize=16)
    ax.legend(fontsize=18)
    plt.show()


def plot_zdistribution(archive, plot_test: bool = False, bins: int = 50) -> None:
    """
    Plot the distribution of redshifts for training and optionally test samples.

    Parameters:
    - archive: Data archive object containing the training data.
    - plot_test (bool, optional): If True, plot test sample distribution. Default is False.
    - bins (int, optional): Number of histogram bins. Default is 50.

    Returns:
    None
    """
    _, _, specz = archive.get_training_data()
    plt.hist(specz, bins=bins, histtype="step", color="navy", label=r"Training sample")

    if plot_test:
        _, _, specz_test = archive.get_training_data()
        plt.hist(
            specz_test,
            bins=bins,
            histtype="step",
            color="goldenrod",
            label=r"Test sample",
            linestyle="--",
        )

    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.xlabel(r"Redshift", fontsize=14)
    plt.ylabel("Counts", fontsize=14)
    plt.legend()
    plt.show()


def plot_som_map(
    som_data: np.ndarray, plot_arg: str = "z", vmin: float = 0, vmax: float = 1
) -> None:
    """
    Plot the Self-Organizing Map (SOM) data.

    … (parameter descriptions collapsed in the diff view)

    Returns:
    None
    """
    plt.imshow(som_data, vmin=vmin, vmax=vmax, cmap="viridis")
    plt.colorbar(label=f"{plot_arg}")
    plt.xlabel(r"$x$ [pixel]", fontsize=14)
    plt.ylabel(r"$y$ [pixel]", fontsize=14)
    plt.show()


def plot_PIT(
    pit_list_1: List[float],
    pit_list_2: Optional[List[float]] = None,
    pit_list_3: Optional[List[float]] = None,
    sample: str = "specz",
    labels: Optional[List[str]] = None,
    save: bool = True,
) -> None:
    """
    Plot Probability Integral Transform (PIT) values for given lists.

    Parameters:
    - pit_list_1 (List[float]): First list of PIT values.
    - pit_list_2 (Optional[List[float]], optional): Second list of PIT values. Default is None.
    - pit_list_3 (Optional[List[float]], optional): Third list of PIT values. Default is None.
    - sample (str, optional): Sample label for saving. Default is 'specz'.
    - labels (Optional[List[str]], optional): List of labels for each PIT list. Default is None.
    - save (bool, optional): If True, save the plot to a file. Default is True.

    Returns:
    None
    """
    plt.rcParams["font.family"] = "serif"
    plt.rcParams["font.size"] = 12
    fig, ax = plt.subplots(figsize=(8, 6))
    kwargs = dict(bins=30, histtype="step", density=True, range=(0, 1))
    cmap = plt.get_cmap("Dark2")

    # Create a histogram
    ax.hist(pit_list_1, color=cmap(0), linestyle="--", **kwargs, label=labels[0])
    if pit_list_2 is not None:
        ax.hist(pit_list_2, color=cmap(1), linestyle="--", **kwargs, label=labels[1])
    if pit_list_3 is not None:
        ax.hist(pit_list_3, color=cmap(2), linestyle="--", **kwargs, label=labels[2])

    ax.set_xlabel("PIT values", fontsize=14)
    ax.set_ylabel("Normalized Counts", fontsize=14)
    ax.legend(fontsize=12)

    if save:
        plt.savefig(f"PIT_{sample}.pdf", dpi=300, bbox_inches="tight")
    plt.show()


def plot_outlier_ratio(
    outliers: np.ndarray, num_samp: int = 100, plot_mean: bool = True
) -> None:
    """
    Plot the outlier ratio as a function of the number of samples.

    Parameters:
    - outliers (np.ndarray): Outlier ratio data.
    - num_samp (int, optional): Number of samples for plotting. Default is 100.
    - plot_mean (bool, optional): If True, plot the mean of outliers. Default is True.

    Returns:
    None
    """
    plt.figure(figsize=(10, 6))
    plt.plot(np.arange(1, num_samp + 1), outliers[:num_samp], label="Outlier Ratio")

    if plot_mean:
        plt.axhline(
            np.mean(outliers), color="red", linestyle="--", label="Mean Outlier Ratio"
        )

    plt.xlabel("Number of Samples", fontsize=14)
    plt.ylabel("Outlier Ratio", fontsize=14)
    plt.legend()
    plt.grid()
    plt.show()


def plot_crps(
    crps_list_1: List[float],
    crps_list_2: Optional[List[float]] = None,
    crps_list_3: Optional[List[float]] = None,
    labels: Optional[List[str]] = None,  # renamed from `label`: the body below uses `labels`
    sample: str = "specz",
    save: bool = True,
) -> None:
    # Create a figure and axis
    # plot properties
    plt.rcParams["font.family"] = "serif"
    plt.rcParams["font.size"] = 12
    fig, ax = plt.subplots(figsize=(8, 6))
    cmap = plt.get_cmap("Dark2")

    kwargs = dict(bins=50, histtype="step", density=True, range=(0, 1))

    # Create a histogram
    hist, bins, _ = ax.hist(
        crps_list_1, color=cmap(0), ls="--", **kwargs, label=labels[0]
    )
    if crps_list_2 is not None:
        hist, bins, _ = ax.hist(
            crps_list_2, color=cmap(1), ls=":", **kwargs, label=labels[1]
        )
    if crps_list_3 is not None:
        hist, bins, _ = ax.hist(
            crps_list_3, color=cmap(2), ls="-", **kwargs, label=labels[2]
        )

    # Add labels and a title
    ax.set_xlabel("CRPS Scores", fontsize=18)
    ax.set_ylabel("Frequency", fontsize=18)

    # Add grid lines
    ax.grid(True, linestyle="--", alpha=0.7)

    # Customize the x-axis
    ax.set_xlim(0, 0.5)

    # Make ticks larger
    ax.tick_params(axis="both", which="major", labelsize=14)

    # Calculate the mean CRPS value
    mean_crps_1 = round(np.nanmean(crps_list_1), 2)
    mean_crps_2 = round(np.nanmean(crps_list_2), 2)
    mean_crps_3 = round(np.nanmean(crps_list_3), 2)

    # Add the mean CRPS value at the top-left corner
    ax.annotate(
        f"Mean CRPS {labels[0]}: {mean_crps_1}",
        xy=(0.57, 0.9),
        xycoords="axes fraction",
        fontsize=14,
        color=cmap(0),
    )
    ax.annotate(
        f"Mean CRPS {labels[1]}: {mean_crps_2}",
        xy=(0.57, 0.85),
        xycoords="axes fraction",
        fontsize=14,
        color=cmap(1),
    )
    ax.annotate(
        f"Mean CRPS {labels[2]}: {mean_crps_3}",
        xy=(0.57, 0.8),
        xycoords="axes fraction",
        fontsize=14,
        color=cmap(2),
    )

    if save == True:
        plt.savefig(f"{sample}_CRPS.pdf", bbox_inches="tight")

    # Show the plot
    plt.show()
```

Removed in this PR (old-side lines still visible in the diff): the untyped `plot_nz` and `plot_scatter` helpers:

```python
def plot_nz(df, bins=np.arange(0, 5, 0.2)):
    kwargs = dict(bins=bins, alpha=0.5)
    plt.hist(df.zs.values, color='grey', ls='-', **kwargs)
    counts, _ = np.histogram(df.z.values, bins=bins)
    plt.plot((bins[:-1] + bins[1:]) * 0.5, counts, color='purple')
    plt.xlabel(r'Redshift', fontsize=14)
    plt.ylabel(r'Counts', fontsize=14)
    plt.yscale('log')
    plt.show()
    return


def plot_scatter(df, sample='specz', save=True):
    # Calculate the point density
    xy = np.vstack([df.zs.values, df.z.values])
    zd = gaussian_kde(xy)(xy)  # gaussian_kde was used without a visible import (scipy.stats)
    fig, ax = plt.subplots()
    plt.scatter(df.zs.values, df.z.values, c=zd, s=1)
    plt.xlim(0, 5)
    plt.ylim(0, 5)
    plt.xlabel(r'$z_{\rm s}$', fontsize=14)
    plt.ylabel('$z$', fontsize=14)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    if save == True:
        plt.savefig(f'{sample}_scatter.pdf', dpi=300, bbox_inches='tight')
    plt.show()
```
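A hedged usage sketch of `plot_photoz` (the column names `zwerr` and `VISmag` follow the function body above; the three toy dataframes and their labels are placeholders, and exactly three are passed because the ratio panel indexes methods 0–2):

```python
import numpy as np
import pandas as pd

from temps.plots import plot_photoz

# Toy dataframes standing in for three methods; each needs the binning
# variable plus a `zwerr` redshift-error column.
rng = np.random.default_rng(0)
dfs = []
for scale in (0.05, 0.04, 0.03):
    dfs.append(pd.DataFrame({
        "VISmag": rng.uniform(20, 25, 5000),
        "zwerr": rng.normal(0, scale, 5000),
    }))

plot_photoz(
    dfs,
    nbins=10,
    xvariable="VISmag",
    metric="nmad",
    type_bin="bin",
    label_list=["zs", "zs+L15", "TEMPS"],
    save=False,
)
```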
temps/temps.py
CHANGED
@@ -6,38 +6,63 @@ from torch.optim import lr_scheduler
|
|
6 |
from loguru import logger
|
7 |
import pandas as pd
|
8 |
from scipy.stats import norm
|
9 |
-
|
10 |
-
from tqdm import tqdm
|
|
|
11 |
|
12 |
from temps.utils import maximum_mean_discrepancy
|
13 |
|
14 |
|
|
|
15 |
class TempsModule:
|
16 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
-
def
|
19 |
self,
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
|
37 |
-
def _get_dataloaders(
|
38 |
-
self, input_data, target_data, input_data_da=None, val_fraction=0.1
|
39 |
-
):
|
40 |
-
"""Create training and validation dataloaders."""
|
41 |
input_data = torch.Tensor(input_data)
|
42 |
target_data = torch.Tensor(target_data)
|
43 |
input_data_da = (
|
@@ -47,10 +72,17 @@ class TempsModule:
|
|
47 |
)
|
48 |
|
49 |
dataset = TensorDataset(input_data, input_data_da, target_data)
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
train_dataset, val_dataset = torch.utils.data.random_split(
|
51 |
dataset,
|
52 |
-
[
|
53 |
)
|
|
|
54 |
loader_train = DataLoader(
|
55 |
train_dataset, batch_size=self.batch_size, shuffle=True
|
56 |
)
|
@@ -58,8 +90,24 @@ class TempsModule:
|
|
58 |
|
59 |
return loader_train, loader_val
|
60 |
|
61 |
-
def _loss_function(
|
62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
log_prob = (
|
64 |
logmix - 0.5 * (mean - true[:, None]).pow(2) / std.pow(2) - torch.log(std)
|
65 |
)
|
@@ -67,28 +115,55 @@ class TempsModule:
|
|
67 |
loss = -log_prob.mean()
|
68 |
return loss
|
69 |
|
70 |
-
def _loss_function_da(self, f1, f2):
|
71 |
-
"""Compute the KL divergence loss for domain adaptation.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=True)
|
73 |
loss = kl_loss(f1, f2)
|
74 |
return torch.log(loss)
|
75 |
|
76 |
-
def _to_numpy(self, x):
|
77 |
-
"""Convert a tensor to a NumPy array.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
return x.detach().cpu().numpy()
|
79 |
|
80 |
def train(
|
81 |
self,
|
82 |
-
input_data,
|
83 |
-
input_data_da,
|
84 |
-
target_data,
|
85 |
-
nepochs=10,
|
86 |
-
step_size=100,
|
87 |
-
val_fraction=0.1,
|
88 |
-
lr=1e-3,
|
89 |
-
weight_decay=0,
|
90 |
-
):
|
91 |
-
"""Train the models using provided data.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
self.model_z.train()
|
93 |
self.model_f.train()
|
94 |
|
@@ -157,8 +232,11 @@ class TempsModule:
|
|
157 |
f"Epoch {epoch + 1}: Training Loss: {np.mean(_loss_train):.4f}, Validation Loss: {np.mean(_loss_validation):.4f}"
|
158 |
)
|
159 |
|
160 |
-
def _validate(
|
|
|
|
|
161 |
"""Validate the model on the validation dataset."""
|
|
|
162 |
self.model_z.eval()
|
163 |
self.model_f.eval()
|
164 |
_loss_validation = []
|
@@ -180,15 +258,49 @@ class TempsModule:
|
|
180 |
|
181 |
return _loss_validation
|
182 |
|
183 |
-
def get_features(self, input_data):
|
184 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
185 |
self.model_f.eval()
|
186 |
input_data = input_data.to(self.device)
|
187 |
features = self.model_f(input_data)
|
188 |
return self._to_numpy(features)
|
189 |
|
190 |
-
def get_pz(
|
191 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
logger.info("Predicting photo-z for the input galaxies...")
|
193 |
self.model_z.eval().to(self.device)
|
194 |
self.model_f.eval().to(self.device)
|
@@ -206,6 +318,7 @@ class TempsModule:
|
|
206 |
+ (mix_coeff * (mu - mu.mean(dim=1, keepdim=True)) ** 2).sum(dim=1)
|
207 |
)
|
208 |
|
|
|
209 |
mu, mix_coeff, sig = map(self._to_numpy, (mu, mix_coeff, sig))
|
210 |
|
211 |
if return_pz:
|
@@ -214,118 +327,61 @@ class TempsModule:
|
|
214 |
else:
|
215 |
return self._to_numpy(z), self._to_numpy(zerr)
|
216 |
|
217 |
-
def _calculate_pdf(
|
218 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
219 |
zgrid = np.linspace(0, 5, 1000)
|
220 |
pz = np.zeros((len(z), len(zgrid)))
|
221 |
|
222 |
for ii in range(len(z)):
|
223 |
for i in range(self.ngaussians):
|
224 |
-
pz[ii] += mix_coeff[ii, i] * norm.pdf(
|
225 |
-
zgrid, mu[ii, i], sig[ii, i]
|
226 |
-
)
|
227 |
|
228 |
if return_flag:
|
229 |
logger.info("Calculating and returning ODDS")
|
230 |
pz /= pz.sum(axis=1, keepdims=True)
|
231 |
return self._calculate_odds(z, pz, zgrid)
|
232 |
-
return
|
233 |
-
|
234 |
-
def _calculate_odds(
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
self.model_f = self.model_f.eval()
|
256 |
-
self.model_f = self.model_f.to(self.device)
|
257 |
-
self.model_z = self.model_z.eval()
|
258 |
-
self.model_z = self.model_z.to(self.device)
|
259 |
-
|
260 |
-
input_data = input_data.to(self.device)
|
261 |
-
|
262 |
-
|
263 |
-
features = self.model_f(input_data)
|
264 |
-
mu, logsig, logmix_coeff = self.model_z(features)
|
265 |
-
|
266 |
-
logsig = torch.clamp(logsig,-6,2)
|
267 |
-
sig = torch.exp(logsig)
|
268 |
-
|
269 |
-
mix_coeff = torch.exp(logmix_coeff)
|
270 |
-
|
271 |
-
mu, mix_coeff, sig = mu.detach().cpu().numpy(), mix_coeff.detach().cpu().numpy(), sig.detach().cpu().numpy()
|
272 |
-
|
273 |
-
for ii in range(len(input_data)):
|
274 |
-
pit = (mix_coeff[ii] * norm.cdf(target_data[ii]*np.ones(mu[ii].shape),mu[ii], sig[ii])).sum()
|
275 |
-
pit_list.append(pit)
|
276 |
-
|
277 |
-
|
278 |
-
return pit_list
|
279 |
-
|
280 |
-
def calculate_crps(self, input_data, target_data):
|
281 |
-
logger.info('Calculating CRPS values')
|
282 |
-
|
283 |
-
def measure_crps(cdf, t):
|
284 |
-
zgrid = np.linspace(0,4,1000)
|
285 |
-
Deltaz = zgrid[None,:] - t[:,None]
|
286 |
-
DeltaZ_heaviside = np.where(Deltaz < 0,0,1)
|
287 |
-
integral = (cdf-DeltaZ_heaviside)**2
|
288 |
-
crps_value = integral.sum(1) / 1000
|
289 |
-
|
290 |
-
return crps_value
|
291 |
-
|
292 |
-
|
293 |
-
crps_list = []
|
294 |
-
|
295 |
-
self.model_f = self.model_f.eval()
|
296 |
-
self.model_f = self.model_f.to(self.device)
|
297 |
-
self.model_z = self.model_z.eval()
|
298 |
-
self.model_z = self.model_z.to(self.device)
|
299 |
-
|
300 |
-
input_data = input_data.to(self.device)
|
301 |
-
|
302 |
-
|
303 |
-
features = self.model_f(input_data)
|
304 |
-
mu, logsig, logmix_coeff = self.model_z(features)
|
305 |
-
logsig = torch.clamp(logsig,-6,2)
|
306 |
-
sig = torch.exp(logsig)
|
307 |
-
|
308 |
-
mix_coeff = torch.exp(logmix_coeff)
|
309 |
-
|
310 |
-
|
311 |
-
mu, mix_coeff, sig = mu.detach().cpu().numpy(), mix_coeff.detach().cpu().numpy(), sig.detach().cpu().numpy()
|
312 |
-
|
313 |
-
z = (mix_coeff * mu).sum(1)
|
314 |
-
|
315 |
-
x = np.linspace(0, 4, 1000)
|
316 |
-
pz = np.zeros(shape=(len(target_data), len(x)))
|
317 |
-
for ii in range(len(input_data)):
|
318 |
-
for i in range(6):
|
319 |
-
pz[ii] += mix_coeff[ii,i] * norm.pdf(x, mu[ii,i], sig[ii,i])
|
320 |
-
|
321 |
-
pz = pz / pz.sum(1)[:,None]
|
322 |
-
|
323 |
-
|
324 |
-
cdf_z = np.cumsum(pz,1)
|
325 |
-
|
326 |
-
crps_value = measure_crps(cdf_z, target_data)
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
return crps_value
|
331 |
-
|
|
|
6 |
from loguru import logger
|
7 |
import pandas as pd
|
8 |
from scipy.stats import norm
|
9 |
+
from dataclasses import dataclass, field
|
10 |
+
from tqdm import tqdm
|
11 |
+
from typing import Optional, Tuple, List, Union
|
12 |
|
13 |
from temps.utils import maximum_mean_discrepancy
|
14 |
|
15 |
|
16 |
+
@dataclass
|
17 |
class TempsModule:
|
18 |
+
"""Attributes:
|
19 |
+
model_f (nn.Module): The feature extraction model.
|
20 |
+
model_z (nn.Module): The model for predicting z values.
|
21 |
+
batch_size (int): Size of each batch for training. Default is 100.
|
22 |
+
rejection_param (int): Parameter for rejection sampling. Default is 1.
|
23 |
+
da (bool): Flag for enabling domain adaptation. Default is True.
|
24 |
+
verbose (bool): Flag for verbose logging. Default is False.
|
25 |
+
device (torch.device): Device to run the model on (CPU or GPU).
|
26 |
+
ngaussians (int): Number of Gaussian components in the mixture model.
|
27 |
+
"""
|
28 |
+
|
29 |
+
model_f: nn.Module
|
30 |
+
model_z: nn.Module
|
31 |
+
batch_size: int = 100
|
32 |
+
rejection_param: int = 1
|
33 |
+
da: bool = True
|
34 |
+
verbose: bool = False
|
35 |
+
device: torch.device = field(init=False)
|
36 |
+
ngaussians: int = field(init=False)
|
37 |
+
|
38 |
+
def __post_init__(self) -> None:
|
39 |
+
"""Post-initialization for setting up additional attributes."""
|
40 |
+
self.device: torch.device = torch.device(
|
41 |
+
"cuda" if torch.cuda.is_available() else "cpu"
|
42 |
+
)
|
43 |
+
self.ngaussians: int = (
|
44 |
+
self.model_z.ngaussians
|
45 |
+
) # Assuming ngaussians is an integer
|
46 |
|
47 |
+
def _get_dataloaders(
|
48 |
self,
|
49 |
+
input_data: np.ndarray,
|
50 |
+
target_data: np.ndarray,
|
51 |
+
input_data_da: Optional[np.ndarray] = None,
|
52 |
+
val_fraction: float = 0.1,
|
53 |
+
) -> Tuple[DataLoader, DataLoader]:
|
54 |
+
"""Create training and validation dataloaders.
|
55 |
+
|
56 |
+
Args:
|
57 |
+
input_data (np.ndarray): The input features for training.
|
58 |
+
target_data (np.ndarray): The target outputs for training.
|
59 |
+
input_data_da (Optional[np.ndarray]): Input data for domain adaptation (if any).
|
60 |
+
val_fraction (float): Fraction of data to use for validation. Default is 0.1.
|
61 |
+
|
62 |
+
Returns:
|
63 |
+
Tuple[DataLoader, DataLoader]: Training and validation data loaders.
|
64 |
+
"""
|
65 |
|
|
|
|
|
|
|
|
|
66 |
input_data = torch.Tensor(input_data)
|
67 |
target_data = torch.Tensor(target_data)
|
68 |
input_data_da = (
|
|
|
72 |
)
|
73 |
|
74 |
dataset = TensorDataset(input_data, input_data_da, target_data)
|
75 |
+
|
76 |
+
# Calculate sizes for training and validation sets
|
77 |
+
total_size = len(dataset)
|
78 |
+
val_size = int(total_size * val_fraction)
|
79 |
+
train_size = total_size - val_size
|
80 |
+
|
81 |
train_dataset, val_dataset = torch.utils.data.random_split(
|
82 |
dataset,
|
83 |
+
[train_size, val_size],
|
84 |
)
|
85 |
+
|
86 |
loader_train = DataLoader(
|
87 |
train_dataset, batch_size=self.batch_size, shuffle=True
|
88 |
)
|
|
|
90 |
|
91 |
return loader_train, loader_val
|
92 |
|
93 |
+
def _loss_function(
|
94 |
+
self,
|
95 |
+
mean: torch.Tensor,
|
96 |
+
std: torch.Tensor,
|
97 |
+
logmix: torch.Tensor,
|
98 |
+
true: torch.Tensor,
|
99 |
+
) -> torch.Tensor:
|
100 |
+
"""Compute the loss function for the model.
|
101 |
+
|
102 |
+
Args:
|
103 |
+
mean (torch.Tensor): Mean values predicted by the model.
|
104 |
+
std (torch.Tensor): Standard deviation values predicted by the model.
|
105 |
+
logmix (torch.Tensor): Logarithm of the mixture coefficients.
|
106 |
+
true (torch.Tensor): True target values.
|
107 |
+
|
108 |
+
Returns:
|
109 |
+
torch.Tensor: The computed loss value.
|
110 |
+
"""
|
111 |
log_prob = (
|
112 |
logmix - 0.5 * (mean - true[:, None]).pow(2) / std.pow(2) - torch.log(std)
|
113 |
)
|
|
|
115 |
loss = -log_prob.mean()
|
116 |
return loss
|
117 |
|
118 |
+
def _loss_function_da(self, f1: torch.Tensor, f2: torch.Tensor) -> torch.Tensor:
|
119 |
+
"""Compute the KL divergence loss for domain adaptation.
|
120 |
+
|
121 |
+
Args:
|
122 |
+
f1 (torch.Tensor): Features from the primary domain.
|
123 |
+
f2 (torch.Tensor): Features from the domain for adaptation.
|
124 |
+
|
125 |
+
Returns:
|
126 |
+
torch.Tensor: The KL divergence loss value.
|
127 |
+
"""
|
128 |
+
|
129 |
kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=True)
|
130 |
loss = kl_loss(f1, f2)
|
131 |
return torch.log(loss)
|
132 |
|
133 |
+
def _to_numpy(self, x: torch.Tensor) -> np.ndarray:
|
134 |
+
"""Convert a tensor to a NumPy array.
|
135 |
+
|
136 |
+
Args:
|
137 |
+
x (torch.Tensor): The input tensor to convert.
|
138 |
+
|
139 |
+
Returns:
|
140 |
+
np.ndarray: The converted NumPy array.
|
141 |
+
"""
|
142 |
return x.detach().cpu().numpy()
|
143 |
|
144 |
def train(
|
145 |
self,
|
146 |
+
input_data: np.ndarray,
|
147 |
+
input_data_da: np.ndarray,
|
148 |
+
target_data: np.ndarray,
|
149 |
+
nepochs: int = 10,
|
150 |
+
step_size: int = 100,
|
151 |
+
val_fraction: float = 0.1,
|
152 |
+
lr: float = 1e-3,
|
153 |
+
weight_decay: float = 0,
|
154 |
+
) -> None:
|
155 |
+
"""Train the models using provided data.
|
156 |
+
|
157 |
+
Args:
|
158 |
+
input_data (np.ndarray): The input features for training.
|
159 |
+
input_data_da (np.ndarray): Input data for domain adaptation.
|
160 |
+
target_data (np.ndarray): The target outputs for training.
|
161 |
+
nepochs (int): Number of training epochs. Default is 10.
|
162 |
+
step_size (int): Step size for learning rate scheduling. Default is 100.
|
163 |
+
val_fraction (float): Fraction of data to use for validation. Default is 0.1.
|
164 |
+
lr (float): Learning rate for the optimizer. Default is 1e-3.
|
165 |
+
weight_decay (float): Weight decay for regularization. Default is 0.
|
166 |
+
"""
|
167 |
self.model_z.train()
|
168 |
self.model_f.train()
|
169 |
|
|
|
232 |
f"Epoch {epoch + 1}: Training Loss: {np.mean(_loss_train):.4f}, Validation Loss: {np.mean(_loss_validation):.4f}"
|
233 |
)
|
234 |
|
235 |
+
def _validate(
|
236 |
+
self, loader_val: DataLoader, target_data: torch.Tensor
|
237 |
+
) -> List[float]:
|
238 |
"""Validate the model on the validation dataset."""
|
239 |
+
|
240 |
self.model_z.eval()
|
241 |
self.model_f.eval()
|
242 |
_loss_validation = []
|
|
|
258 |
|
259 |
return _loss_validation
|
260 |
|
261 |
+
def get_features(self, input_data: torch.Tensor) -> np.ndarray:
|
262 |
+
"""Extract features from the model for the given input data.
|
263 |
+
|
264 |
+
Args:
|
265 |
+
input_data (torch.Tensor): Input tensor containing the data for which features are to be extracted.
|
266 |
+
|
267 |
+
Returns:
|
268 |
+
np.ndarray: Numpy array of extracted features from the model.
|
269 |
+
"""
|
270 |
+
|
271 |
self.model_f.eval()
|
272 |
input_data = input_data.to(self.device)
|
273 |
features = self.model_f(input_data)
|
274 |
return self._to_numpy(features)
|
275 |
|
276 |
+
def get_pz(
|
277 |
+
self,
|
278 |
+
input_data: torch.Tensor,
|
279 |
+
return_pz: bool = True,
|
280 |
+
return_flag: bool = True,
|
281 |
+
return_odds: bool = False,
|
282 |
+
) -> Union[
|
283 |
+
Tuple[np.ndarray, np.ndarray], # Return z and zerr
|
284 |
+
Tuple[np.ndarray, np.ndarray], # Return z, pz
|
285 |
+
Tuple[np.ndarray, np.ndarray, np.ndarray] # Return z, pz, odds
|
286 |
+
]:
|
287 |
+
"""Get the predicted redshift (z) values and their uncertainties from the model.
|
288 |
+
|
289 |
+
This function predicts the photo-z for the input galaxies, computes the mean and standard
|
290 |
+
deviation for the predicted redshifts, and optionally calculates the probability density function (PDF).
|
291 |
+
|
292 |
+
Args:
|
293 |
+
input_data (torch.Tensor): Input tensor containing galaxy data for which to predict redshifts.
|
294 |
+
return_pz (bool, optional): Flag indicating whether to return the probability density function. Defaults to True.
|
295 |
+
return_flag (bool, optional): Flag indicating whether to return additional information. Defaults to True.
|
296 |
+
return_odds (bool, optional): Flag indicating whether to return the odds. Defaults to False.
|
297 |
+
|
298 |
+
Returns:
|
299 |
+
Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
|
300 |
+
- If return_pz is True, returns the PDF and possibly additional metrics.
|
301 |
+
- If return_pz is False, returns a tuple containing the predicted redshifts and their uncertainties.
|
302 |
+
"""
|
303 |
+
|
304 |
logger.info("Predicting photo-z for the input galaxies...")
|
305 |
self.model_z.eval().to(self.device)
|
306 |
self.model_f.eval().to(self.device)
|
|
|
318 |
+ (mix_coeff * (mu - mu.mean(dim=1, keepdim=True)) ** 2).sum(dim=1)
|
319 |
)
|
320 |
|
321 |
+
z = self._to_numpy(z)
|
322 |
mu, mix_coeff, sig = map(self._to_numpy, (mu, mix_coeff, sig))
|
323 |
|
324 |
if return_pz:
|
|
|
327 |
else:
|
328 |
return self._to_numpy(z), self._to_numpy(zerr)
|
329 |
|
    def _calculate_pdf(
        self,
        z: np.ndarray,
        mu: np.ndarray,
        sig: np.ndarray,
        mix_coeff: np.ndarray,
        return_flag: bool,
    ) -> Union[
        Tuple[np.ndarray, np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]
    ]:
        """Calculate the probability density function (PDF) for the predicted redshifts.

        Args:
            z (np.ndarray): Predicted redshift values.
            mu (np.ndarray): Mean values for the Gaussian components.
            sig (np.ndarray): Standard deviations for the Gaussian components.
            mix_coeff (np.ndarray): Mixture coefficients for the Gaussian components.
            return_flag (bool): Whether to also calculate and return the ODDS.

        Returns:
            Union[Tuple[np.ndarray, np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]:
                - If return_flag is True, a tuple of the redshift values, the PDFs, and the ODDS.
                - If return_flag is False, a tuple of the redshift values and the PDFs.
        """
        zgrid = np.linspace(0, 5, 1000)
        pz = np.zeros((len(z), len(zgrid)))

        # Evaluate each object's Gaussian mixture on the redshift grid
        for ii in range(len(z)):
            for i in range(self.ngaussians):
                pz[ii] += mix_coeff[ii, i] * norm.pdf(zgrid, mu[ii, i], sig[ii, i])

        if return_flag:
            logger.info("Calculating and returning ODDS")
            pz /= pz.sum(axis=1, keepdims=True)
            return self._calculate_odds(z, pz, zgrid)
        return z, pz
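The double loop above evaluates the mixture p(z) = sum_i pi_i N(z; mu_i, sigma_i) one object and one component at a time. An equivalent broadcast form (a sketch, not part of the PR, using the same variable names and shapes as the method) computes the whole grid in a single call:

import numpy as np
from scipy.stats import norm

# mu, sig, mix_coeff: shape (n_obj, n_gauss); zgrid: shape (n_z,)
# norm.pdf broadcasts to (n_obj, n_gauss, n_z); summing over the component
# axis yields the mixture PDFs with shape (n_obj, n_z).
pz = (
    mix_coeff[:, :, None]
    * norm.pdf(zgrid[None, None, :], mu[:, :, None], sig[:, :, None])
).sum(axis=1)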
    def _calculate_odds(
        self, z: np.ndarray, pz: np.ndarray, zgrid: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Calculate the ODDS for the estimated redshifts based on the cumulative distribution.

        Args:
            z (np.ndarray): Predicted redshift values.
            pz (np.ndarray): Probability density function values.
            zgrid (np.ndarray): Grid of redshift values for evaluation.

        Returns:
            Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing the predicted
                redshift values, the PDF values, and the calculated odds.
        """
        cumulative = np.cumsum(pz, axis=1)
        # Per object: the maximum absolute deviation of the cumulative PDF from 0.68
        odds = np.array(
            [np.max(np.abs(cumulative[i] - 0.68)) for i in range(cumulative.shape[0])]
        )
        return z, pz, odds
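As a point of comparison (explicitly not what the method above computes), the conventional photo-z ODDS is the PDF mass inside a window around the point estimate. A hedged sketch of that definition, with the half-width `delta` an illustrative free parameter:

import numpy as np

def odds_window(z, pz, zgrid, delta=0.06):
    """PDF mass within +/- delta * (1 + z) of the point estimate (conventional ODDS).

    Assumes each row of pz is normalised so that pz.sum(axis=1) == 1.
    """
    lo = z[:, None] - delta * (1 + z[:, None])
    hi = z[:, None] + delta * (1 + z[:, None])
    inside = (zgrid[None, :] >= lo) & (zgrid[None, :] <= hi)
    return (pz * inside).sum(axis=1)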
temps/temps_arch.py
CHANGED
import torch
from torch import nn
import torch.nn.functional as F


class EncoderPhotometry(nn.Module):
    """Encoder for photometric data.

    This neural network encodes photometric features into a lower-dimensional representation.

    Attributes:
        features (nn.Sequential): A sequential container of layers used for encoding.
    """

    def __init__(self, input_dim: int = 6, dropout_prob: float = 0) -> None:
        """Initializes the EncoderPhotometry module.

        Args:
            input_dim (int): Number of input features (default is 6).
            dropout_prob (float): Probability of dropout (default is 0).
        """
        super(EncoderPhotometry, self).__init__()

        self.features = nn.Sequential(
            # ... (hidden layers unchanged and not shown in the diff) ...
            nn.Linear(20, 10),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the encoder.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, input_dim).

        Returns:
            torch.Tensor: Log softmax output of shape (batch_size, 10).
        """
        f = self.features(x)
        f = F.log_softmax(f, dim=1)
        return f


class MeasureZ(nn.Module):
    """Model to measure redshift parameters.

    This model estimates the parameters of a mixture of Gaussians used for measuring redshift.

    Attributes:
        ngaussians (int): Number of Gaussian components in the mixture.
        measure_mu (nn.Sequential): Sequential model to measure the mean (mu).
        measure_coeffs (nn.Sequential): Sequential model to measure the mixing coefficients.
        measure_sigma (nn.Sequential): Sequential model to measure the standard deviation (sigma).
    """

    def __init__(self, num_gauss: int = 10, dropout_prob: float = 0) -> None:
        """Initializes the MeasureZ module.

        Args:
            num_gauss (int): Number of Gaussian components (default is 10).
            dropout_prob (float): Probability of dropout (default is 0).
        """
        super(MeasureZ, self).__init__()

        self.ngaussians = num_gauss
        # ... (head definitions unchanged and not shown in the diff, ending with) ...
            nn.Linear(20, num_gauss),
        )

    def forward(
        self, f: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward pass to measure redshift parameters.

        Args:
            f (torch.Tensor): Input tensor of shape (batch_size, 10).

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing:
                - mu (torch.Tensor): Mean parameters of shape (batch_size, num_gauss).
                - sigma (torch.Tensor): Standard deviation parameters of shape (batch_size, num_gauss).
                - logmix_coeff (torch.Tensor): Log mixing coefficients of shape (batch_size, num_gauss).
        """
        mu = self.measure_mu(f)
        sigma = self.measure_sigma(f)
        logmix_coeff = self.measure_coeffs(f)

        # Normalize logmix_coeff to get valid mixture coefficients
        logmix_coeff = logmix_coeff - torch.logsumexp(logmix_coeff, dim=1, keepdim=True)

        return mu, sigma, logmix_coeff
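A quick smoke test for the two modules (a sketch; shapes follow the docstrings above):

import torch

encoder = EncoderPhotometry(input_dim=6)
measure = MeasureZ(num_gauss=10)

x = torch.rand(32, 6)           # batch of 32 galaxies with 6 photometric features
f = encoder(x)                  # (32, 10) log-softmax features
mu, sigma, logmix = measure(f)  # each of shape (32, 10)

# The normalised log coefficients exponentiate to a valid mixture
assert torch.allclose(torch.exp(logmix).sum(dim=1), torch.ones(32), atol=1e-5)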
temps/utils.py
CHANGED
from scipy import stats
from scipy.stats import norm  # needed by calculate_pit / calculate_crps below
import torch
from torch import nn, Tensor  # needed for the type hints of calculate_pit / calculate_crps
from loguru import logger
from typing import List, Optional, Tuple, Union
def calculate_eta(df: pd.DataFrame) -> float:
    """Calculate the percentage of outliers in the DataFrame, based on the zwerr column."""
    return len(df[np.abs(df.zwerr) > 0.15]) / len(df) * 100
def nmad(data: Union[np.ndarray, pd.Series]) -> float:
    """Calculate the normalized median absolute deviation (NMAD) of the data."""
    return 1.4826 * np.median(np.abs(data - np.median(data)))
def sigma68(data: Union[np.ndarray, pd.Series]) -> float:
    """Calculate the sigma68 metric, a robust measure of dispersion."""
    return 0.5 * (pd.Series(data).quantile(q=0.84) - pd.Series(data).quantile(q=0.16))
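For intuition: on Gaussian residuals both robust estimators recover the standard deviation, so a quick synthetic check (a sketch only) looks like this:

import numpy as np

rng = np.random.default_rng(0)
residuals = rng.normal(0, 0.02, size=100_000)  # synthetic (z_pred - z_true) / (1 + z_true)

print(nmad(residuals))     # ~0.02
print(sigma68(residuals))  # ~0.02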
def maximum_mean_discrepancy(
    x: torch.Tensor,
    y: torch.Tensor,
    kernel_type: str = "rbf",
    kernel_mul: float = 2.0,
    kernel_num: int = 5,
) -> torch.Tensor:
    """
    Compute the Maximum Mean Discrepancy (MMD) between two sets of samples.
    """
    # ... (rest of the docstring and the body unchanged, not shown in the diff) ...
    return mmd_loss
def compute_kernel(
    x: torch.Tensor,
    y: torch.Tensor,
    kernel_type: str = "rbf",
    kernel_mul: float = 2.0,
    kernel_num: int = 5,
) -> torch.Tensor:
    """
    Compute the kernel matrix based on the chosen kernel type.
    """
    # ... (docstring and setup unchanged, not shown in the diff) ...
    x = x.unsqueeze(1).expand(x_size, y_size, dim)
    y = y.unsqueeze(0).expand(x_size, y_size, dim)

    # Pairwise mean squared distance between all x/y sample pairs
    kernel_input = (x - y).pow(2).mean(2)

    if kernel_type == "linear":
        kernel_matrix = kernel_input
    # ... (remaining kernel branches and return unchanged, not shown in the diff) ...
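Usage sketch for the MMD pair above (the omitted body is presumed to combine `compute_kernel` evaluations of the x-x, y-y, and x-y blocks into the standard MMD estimator; only the public signature shown in the diff is relied on):

import torch

# Two batches of 10-dimensional features, e.g. source vs. target photometry embeddings
x = torch.randn(256, 10)
y = torch.randn(256, 10) + 0.5  # shifted distribution

mmd = maximum_mean_discrepancy(x, y, kernel_type="rbf")
print(float(mmd))  # larger values indicate a larger distribution mismatch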
def select_cut(
    df: pd.DataFrame,
    completenss_lim: Optional[float] = None,
    nmad_lim: Optional[float] = None,
    outliers_lim: Optional[float] = None,
    return_df: bool = False,
) -> Union[Tuple[pd.DataFrame, float, pd.DataFrame], Tuple[float, pd.DataFrame]]:
    """
    Select an ODDS cut based on exactly one of the provided limits
    (completeness, NMAD, or outlier fraction).

    Args:
        df: DataFrame containing the data.
        completenss_lim: Optional limit on completeness.
        nmad_lim: Optional limit on NMAD.
        outliers_lim: Optional limit on the outlier fraction (eta).
        return_df: Whether to also return the filtered DataFrame.

    Returns:
        If return_df is False, the cut value and a DataFrame of cuts;
        if return_df is True, the filtered DataFrame, the cut value, and the cuts DataFrame.
    """
    if (completenss_lim is None) and (nmad_lim is None) and (outliers_lim is None):
        raise ValueError("Select at least one cut")
    elif sum(c is not None for c in [completenss_lim, nmad_lim, outliers_lim]) > 1:
        raise ValueError("Select only one cut at a time")

    # Candidate cuts: deciles of the ODDS distribution
    bin_edges = stats.mstats.mquantiles(df.odds, np.arange(0, 1.01, 0.1))
    scatter, eta, cmptnss, nobj = [], [], [], []

    for k in range(len(bin_edges) - 1):
        edge_min = bin_edges[k]
        edge_max = bin_edges[k + 1]

        df_bin = df[(df.odds > edge_min)]
        cmptnss.append(np.round(len(df_bin) / len(df), 2) * 100)
        scatter.append(nmad(df_bin.zwerr))
        eta.append(len(df_bin[np.abs(df_bin.zwerr) > 0.15]) / len(df_bin) * 100)
        nobj.append(len(df_bin))

    dfcuts = pd.DataFrame(
        data=np.c_[
            np.round(bin_edges[:-1], 5),
            np.round(nobj, 1),
            np.round(cmptnss, 1),
            np.round(scatter, 3),
            np.round(eta, 2),
        ],
        columns=["flagcut", "Nobj", "completeness", "nmad", "eta"],
    )

    if completenss_lim is not None:
        logger.info("Selecting cut based on completeness")
        selected_cut = dfcuts[dfcuts["completeness"] <= completenss_lim].iloc[0]

    elif nmad_lim is not None:
        logger.info("Selecting cut based on NMAD")
        selected_cut = dfcuts[dfcuts["nmad"] <= nmad_lim].iloc[0]

    elif outliers_lim is not None:
        # (line 158, a log message, is not shown in the diff)
        selected_cut = dfcuts[dfcuts["eta"] <= outliers_lim].iloc[0]

    logger.info(
        f"This cut provides completeness of {selected_cut['completeness']}, "
        f"nmad={selected_cut['nmad']} and eta={selected_cut['eta']}"
    )

    df_cut = df[(df.odds > selected_cut["flagcut"])]

    if return_df:
        return df_cut, selected_cut["flagcut"], dfcuts
    else:
        return selected_cut["flagcut"], dfcuts
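Usage sketch for `select_cut` on a toy catalogue (synthetic data; the column names `odds` and `zwerr` are the ones the function expects):

import numpy as np
import pandas as pd

rng = np.random.default_rng(1)
df = pd.DataFrame({
    "odds": rng.uniform(0, 1, 10_000),
    "zwerr": rng.normal(0, 0.03, 10_000),
})

# Loosest decile cut whose outlier fraction eta stays below 5%
df_cut, flagcut, dfcuts = select_cut(df, outliers_lim=5, return_df=True)
print(dfcuts[["flagcut", "completeness", "nmad", "eta"]])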
def calculate_pit(
    model_f: nn.Module,
    model_z: nn.Module,
    input_data: Tensor,
    target_data: Tensor,
    device: str = "cpu",
) -> List[float]:
    """Calculate the Probability Integral Transform (PIT) of each target redshift
    under the predicted mixture PDF."""
    logger.info("Calculating PIT values")

    # `device` is taken as an argument: these are module-level functions,
    # so there is no `self.device` to fall back on.
    model_f = model_f.eval().to(device)
    model_z = model_z.eval().to(device)
    input_data = input_data.to(device)

    features = model_f(input_data)
    mu, logsig, logmix_coeff = model_z(features)

    logsig = torch.clamp(logsig, -6, 2)
    sig = torch.exp(logsig)
    mix_coeff = torch.exp(logmix_coeff)

    mu, mix_coeff, sig = (
        mu.detach().cpu().numpy(),
        mix_coeff.detach().cpu().numpy(),
        sig.detach().cpu().numpy(),
    )
    target = target_data.detach().cpu().numpy()  # work in NumPy from here on

    pit_list = []
    for ii in range(len(input_data)):
        # PIT = CDF of the Gaussian mixture evaluated at the true redshift
        pit = (
            mix_coeff[ii]
            * norm.cdf(target[ii] * np.ones(mu[ii].shape), mu[ii], sig[ii])
        ).sum()
        pit_list.append(pit)

    return pit_list
def calculate_crps(
    model_f: nn.Module,
    model_z: nn.Module,
    input_data: Tensor,
    target_data: Tensor,
    device: str = "cpu",
) -> np.ndarray:
    """Calculate the Continuous Ranked Probability Score (CRPS) of each target
    redshift under the predicted mixture PDF."""
    logger.info("Calculating CRPS values")

    def measure_crps(cdf, t):
        zgrid = np.linspace(0, 4, 1000)
        Deltaz = zgrid[None, :] - t[:, None]
        DeltaZ_heaviside = np.where(Deltaz < 0, 0, 1)
        integral = (cdf - DeltaZ_heaviside) ** 2
        crps_value = integral.sum(1) / 1000
        return crps_value

    # As in calculate_pit, `device` is an argument (no `self` here)
    model_f = model_f.eval().to(device)
    model_z = model_z.eval().to(device)
    input_data = input_data.to(device)

    features = model_f(input_data)
    mu, logsig, logmix_coeff = model_z(features)
    logsig = torch.clamp(logsig, -6, 2)
    sig = torch.exp(logsig)
    mix_coeff = torch.exp(logmix_coeff)

    mu, mix_coeff, sig = (
        mu.detach().cpu().numpy(),
        mix_coeff.detach().cpu().numpy(),
        sig.detach().cpu().numpy(),
    )
    target = target_data.detach().cpu().numpy()

    z = (mix_coeff * mu).sum(1)

    x = np.linspace(0, 4, 1000)
    pz = np.zeros(shape=(len(target), len(x)))
    for ii in range(len(input_data)):
        # Iterate over the actual number of mixture components
        # (this was hardcoded to 6, which breaks for num_gauss != 6)
        for i in range(mu.shape[1]):
            pz[ii] += mix_coeff[ii, i] * norm.pdf(x, mu[ii, i], sig[ii, i])

    pz = pz / pz.sum(1)[:, None]

    cdf_z = np.cumsum(pz, 1)

    crps_value = measure_crps(cdf_z, target)

    return crps_value
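Finally, a usage sketch for the two calibration metrics (hypothetical: `model_f` and `model_z` stand for trained networks; for a well-calibrated model the PIT histogram is close to uniform, and lower mean CRPS is better):

import numpy as np
import torch

x_val = torch.rand(1000, 6)    # validation photometry
z_true = torch.rand(1000) * 2  # matching spectroscopic redshifts

pit = calculate_pit(model_f, model_z, x_val, z_true, device="cpu")
crps = calculate_crps(model_f, model_z, x_val, z_true, device="cpu")

print(np.histogram(pit, bins=10, range=(0, 1))[0])  # roughly flat if calibrated
print(crps.mean())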