File size: 11,723 Bytes
c9354dd
c212435
 
 
a57776c
c212435
57fa8fc
 
c9354dd
57fa8fc
c9354dd
c212435
c9354dd
c212435
 
 
c9354dd
57fa8fc
c9354dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57fa8fc
c9354dd
57fa8fc
a57776c
c9354dd
 
 
 
 
 
 
 
 
 
 
 
a57776c
 
 
c9354dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a57776c
 
692f707
c9354dd
 
 
 
 
 
 
 
 
 
 
a57776c
c9354dd
 
 
 
 
 
 
 
 
 
 
 
 
 
fc92339
 
c9354dd
 
 
 
 
 
 
 
 
 
 
c212435
c9354dd
 
 
 
 
 
 
 
 
 
 
 
 
 
a57776c
c212435
692f707
 
 
 
d307831
c9354dd
 
 
 
 
d307831
c9354dd
 
 
 
 
 
 
 
 
 
 
d307831
c9354dd
 
 
 
fc92339
 
 
 
a57776c
fc92339
c9354dd
 
 
 
fc92339
c9354dd
 
c212435
c9354dd
 
 
 
 
 
696a020
fc92339
c9354dd
 
 
 
c212435
c9354dd
 
 
c212435
c9354dd
 
c212435
c9354dd
 
 
 
 
fc92339
c9354dd
 
 
 
 
 
 
 
 
a57776c
d307831
c212435
 
57fa8fc
c9354dd
 
fc92339
57fa8fc
c9354dd
fc92339
c212435
 
a57776c
 
d307831
a57776c
d307831
 
c212435
 
 
57fa8fc
c212435
692f707
c212435
57fa8fc
a57776c
 
d307831
c212435
d307831
c212435
 
d307831
c212435
fc92339
a57776c
fc92339
 
 
 
c9354dd
 
 
 
 
 
 
 
 
d307831
 
c9354dd
 
 
d307831
 
c9354dd
 
adcaa7a
d307831
c9354dd
c212435
c9354dd
c212435
 
c9354dd
c212435
 
a57776c
c212435
 
 
 
d307831
c9354dd
 
 
c212435
a57776c
c212435
 
a57776c
c212435
 
c9354dd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
from dataclasses import dataclass, field
import numpy as np
import pandas as pd
from astropy.io import fits
from astropy.table import Table
from scipy.spatial import KDTree
from matplotlib import pyplot as plt
from matplotlib import rcParams
from pathlib import Path
from loguru import logger
from typing import Optional, Tuple, Union, List

# Matplotlib configuration: render both plain text and mathtext with STIX fonts.
rcParams.update({"mathtext.fontset": "stix", "font.family": "STIXGeneral"})

@dataclass
class Archive:
    """Load, clean and split calibration/validation catalogues into model arrays.

    On construction the calibration catalogue (``path_calib``) and the validation
    catalogue (``path_valid``) are read from ``.fits`` or ``.csv`` files. The
    configured redshift columns are renamed to ``specz`` / ``photo_z``, and the
    optional cleaning steps and the fixed cuts (``MAG_VIS < 25``, ``target_z < 5``)
    are applied. Finally the arrays returned by :meth:`get_training_data` and
    :meth:`get_testing_data` are built.

    The validation catalogue supplies both the test sample and the
    domain-adaptation (DA) sample. The DA sample is the validation galaxies with
    ``specz < 0`` and ``0 < photo_z < 4``.

    Raises:
        ValueError: If ``path_valid`` is not given, a file has an unsupported
            extension, ``target_test`` is not 'specz' or 'L15', or the DA sample
            is empty.
    """

    # Calibration (training) catalogue; a .fits or .csv file.
    path_calib: Path
    # Validation catalogue; required in practice (test + DA samples come from it).
    path_valid: Optional[Path] = None
    # Keep only rows with mu_class_L07 == 1.
    drop_stars: bool = True
    # Keep only rows with FLAG_PHOT == 0.
    clean_photometry: bool = True
    # Replace fluxes by ratios of adjacent bands.
    convert_colors: bool = True
    # Multiply fluxes by the columns listed in ``columns_ebv``.
    extinction_corr: bool = True
    # Train only on galaxies with specz > 0 and an accepted quality flag.
    only_zspec: bool = True
    # Flux columns, in band order (the colour conversion uses adjacent columns).
    columns_photometry: List[str] = field(default_factory=lambda: [
        "FLUX_G_2",
        "FLUX_R_2",
        "FLUX_I_2",
        "FLUX_Z_2",
        "FLUX_Y_2",
        "FLUX_J_2",
        "FLUX_H_2",
    ])
    # Per-band extinction-correction factors, aligned with ``columns_photometry``.
    columns_ebv: List[str] = field(default_factory=lambda: [
        "EB_V_corr_FLUX_G",
        "EB_V_corr_FLUX_R",
        "EB_V_corr_FLUX_I",
        "EB_V_corr_FLUX_Z",
        "EB_V_corr_FLUX_Y",
        "EB_V_corr_FLUX_J",
        "EB_V_corr_FLUX_H",
    ])
    # Source column names renamed to ``photo_z`` / ``specz`` after loading.
    photoz_name: str = "photo_z_L15"
    specz_name: str = "z_spec_S15"
    # Test-sample selection: 'specz' or 'L15'.
    target_test: str = "specz"
    # Accepted values of the Q_f_S15 quality flag for spectroscopic redshifts.
    flags_kept: List[float] = field(default_factory=lambda: [3, 3.1, 3.4, 3.5, 4])

    def __post_init__(self):
        """Load both catalogues, apply the cleaning cuts and build the data arrays."""
        # Every later step (cleaning, DA sampling, test set) needs the validation
        # catalogue, so fail fast with a clear message instead of an AttributeError.
        if self.path_valid is None:
            raise ValueError(
                "path_valid is required: the validation catalogue provides the "
                "test and domain-adaptation samples."
            )
        # Accept plain strings as well as Path objects.
        self.path_calib = Path(self.path_calib)
        self.path_valid = Path(self.path_valid)

        logger.info("Starting archive")

        # Each file is read according to its own extension.
        self.cat = self._load_catalogue(self.path_calib)
        self.cat_test = self._load_catalogue(self.path_valid)

        rename_map = {f"{self.specz_name}": "specz", f"{self.photoz_name}": "photo_z"}
        self.cat = self.cat.rename(columns=rename_map)
        self.cat_test = self.cat_test.rename(columns=rename_map)

        # Calibration rows need at least one positive redshift estimate.
        self.cat = self.cat[(self.cat["specz"] > 0) | (self.cat["photo_z"] > 0)]

        # Apply operations based on the initialized parameters
        if self.drop_stars:
            logger.info("Dropping stars...")
            self.cat = self.cat[self.cat.mu_class_L07 == 1]
            self.cat_test = self.cat_test[self.cat_test.mu_class_L07 == 1]

        if self.clean_photometry:
            logger.info("Cleaning photometry...")
            self.cat = self._clean_photometry(catalogue=self.cat)
            self.cat_test = self._clean_photometry(catalogue=self.cat_test)

        self.cat = self._set_combined_target(self.cat)
        self.cat_test = self._set_combined_target(self.cat_test)

        # Apply magnitude and redshift cuts
        self.cat = self.cat[self.cat.MAG_VIS < 25]
        self.cat_test = self.cat_test[self.cat_test.MAG_VIS < 25]

        self.cat = self.cat[self.cat.target_z < 5]
        self.cat_test = self.cat_test[self.cat_test.target_z < 5]

        self._set_training_data(
            self.cat,
            self.cat_test,
            only_zspec=self.only_zspec,
            extinction_corr=self.extinction_corr,
            convert_colors=self.convert_colors,
        )
        self._set_testing_data(
            self.cat_test,
            target=self.target_test,
            extinction_corr=self.extinction_corr,
            convert_colors=self.convert_colors,
        )

    @staticmethod
    def _load_catalogue(path: Path) -> pd.DataFrame:
        """Read a catalogue into a DataFrame, choosing the reader from the extension.

        Args:
            path (Path): A ``.fits`` file (table in HDU 1) or a ``.csv`` file.

        Returns:
            pd.DataFrame: The loaded catalogue.

        Raises:
            ValueError: If the extension is neither ``.fits`` nor ``.csv``.
        """
        suffix = path.suffix.lower()
        if suffix == ".fits":
            with fits.open(path) as hdu_list:
                return Table(hdu_list[1].data).to_pandas()
        if suffix == ".csv":
            return pd.read_csv(path)
        raise ValueError("Unsupported file format. Please provide a .fits or .csv file.")

    def _extract_fluxes(self, catalogue: pd.DataFrame) -> np.ndarray:
        """Extract fluxes from the given catalogue.

        Args:
            catalogue (pd.DataFrame): The input catalogue.

        Returns:
            np.ndarray: Array of shape (n_rows, len(columns_photometry)).
        """
        return catalogue[self.columns_photometry].values

    @staticmethod
    def _to_colors(flux: np.ndarray) -> np.ndarray:
        """Convert fluxes to colours as ratios of adjacent bands.

        Args:
            flux (np.ndarray): 2-D array of fluxes, one column per band.

        Returns:
            np.ndarray: Array with one fewer column; column ``i`` is
            ``flux[:, i] / flux[:, i + 1]``.
        """
        return flux[:, :-1] / flux[:, 1:]

    @staticmethod
    def _set_combined_target(catalogue: pd.DataFrame) -> pd.DataFrame:
        """Add ``target_z``: ``specz`` where it is positive, otherwise ``photo_z``.

        Args:
            catalogue (pd.DataFrame): Catalogue with ``specz`` and ``photo_z`` columns.

        Returns:
            pd.DataFrame: A new catalogue with the ``target_z`` column added.
        """
        # Vectorised instead of a row-wise apply. ``assign`` returns a new frame, which
        # avoids SettingWithCopyWarning on filtered slices and also handles empty input.
        # NaN in ``specz`` compares False and therefore falls back to ``photo_z``.
        target_z = np.where(
            catalogue["specz"] > 0, catalogue["specz"], catalogue["photo_z"]
        )
        return catalogue.assign(target_z=target_z)

    @staticmethod
    def _clean_photometry(catalogue: pd.DataFrame) -> pd.DataFrame:
        """Drops all objects with FLAG_PHOT != 0.

        Args:
            catalogue (pd.DataFrame): The input catalogue.

        Returns:
            pd.DataFrame: Cleaned catalogue.
        """
        return catalogue[catalogue["FLAG_PHOT"] == 0]

    def _correct_extinction(
        self, catalogue: pd.DataFrame, f: np.ndarray, return_ext_corr: bool = False
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """Multiply fluxes by the per-band correction factors in ``columns_ebv``.

        Args:
            catalogue (pd.DataFrame): Catalogue holding the correction columns,
                row-aligned with ``f``.
            f (np.ndarray): The flux values to correct.
            return_ext_corr (bool): Whether to also return the correction factors.

        Returns:
            Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: Corrected fluxes, and
            optionally the correction factors.
        """
        ext_correction = catalogue[self.columns_ebv].values
        f = f * ext_correction
        if return_ext_corr:
            return f, ext_correction
        return f

    @staticmethod
    def _select_only_zspec(
        catalogue: pd.DataFrame, cat_flag: Optional[str] = None
    ) -> pd.DataFrame:
        """Selects only galaxies with spectroscopic redshift (``specz > 0``).

        Args:
            catalogue (pd.DataFrame): The input catalogue.
            cat_flag (Optional[str]): 'Calib' or 'Valid'. Any other value
                (including None) returns the catalogue unfiltered.

        Returns:
            pd.DataFrame: Filtered catalogue.
        """
        if cat_flag in ("Calib", "Valid"):
            catalogue = catalogue[catalogue.specz > 0]
        return catalogue

    @staticmethod
    def take_zspec_and_photoz(
        catalogue: pd.DataFrame, cat_flag: Optional[str] = None
    ) -> pd.DataFrame:
        """Select galaxies with a usable redshift.

        For 'Calib' this keeps rows with ``target_z > 0`` (spectroscopic or
        photometric redshift). For 'Valid' it keeps rows with ``specz > 0``. Any
        other ``cat_flag`` returns the catalogue unfiltered.

        Args:
            catalogue (pd.DataFrame): The input catalogue.
            cat_flag (Optional[str]): 'Calib' or 'Valid'.

        Returns:
            pd.DataFrame: Filtered catalogue.
        """
        if cat_flag == "Calib":
            catalogue = catalogue[catalogue.target_z > 0]
        elif cat_flag == "Valid":
            catalogue = catalogue[catalogue.specz > 0]
        return catalogue

    @staticmethod
    def exclude_only_zspec(catalogue: pd.DataFrame) -> pd.DataFrame:
        """Selects galaxies without spectroscopic redshift and with 0 < photo_z < 4.

        Args:
            catalogue (pd.DataFrame): The input catalogue.

        Returns:
            pd.DataFrame: Filtered catalogue.
        """
        return catalogue[
            (catalogue.specz < 0) & (catalogue.photo_z > 0) & (catalogue.photo_z < 4)
        ]

    @staticmethod
    def _clean_zspec_sample(
        catalogue: pd.DataFrame,
        flags_kept: Union[List[float], Tuple[float, ...]] = (3, 3.1, 3.4, 3.5, 4),
    ) -> pd.DataFrame:
        """Keep rows whose ``Q_f_S15`` quality flag is in ``flags_kept``.

        Args:
            catalogue (pd.DataFrame): The input catalogue.
            flags_kept: Accepted flag values. The default is an immutable tuple,
                to avoid a shared mutable default argument.

        Returns:
            pd.DataFrame: Filtered catalogue.
        """
        return catalogue[catalogue.Q_f_S15.isin(flags_kept)]

    @staticmethod
    def _select_L15_sample(catalogue: pd.DataFrame) -> pd.DataFrame:
        """Selects only galaxies with 0 < target_z < 3.

        Args:
            catalogue (pd.DataFrame): The input catalogue.

        Returns:
            pd.DataFrame: Filtered catalogue.
        """
        # No ``self`` here: this is a staticmethod, and the stray ``self`` parameter
        # made ``self._select_L15_sample(cat)`` raise TypeError.
        return catalogue[(catalogue.target_z > 0) & (catalogue.target_z < 3)]

    def _set_training_data(
        self,
        catalogue: pd.DataFrame,
        catalogue_da: pd.DataFrame,
        only_zspec: bool = True,
        extinction_corr: bool = True,
        convert_colors: bool = True,
    ) -> None:
        """Build the training arrays and a size-matched domain-adaptation sample.

        Sets ``cat_train``, ``phot_train``, ``target_z_train``, ``VIS_mag_train``,
        ``phot_train_DA`` and ``target_z_train_DA``.

        Args:
            catalogue (pd.DataFrame): Cleaned calibration catalogue.
            catalogue_da (pd.DataFrame): Catalogue the DA sample is drawn from.
            only_zspec (bool): Train on spectroscopic redshifts only.
            extinction_corr (bool): Apply the extinction correction.
            convert_colors (bool): Convert fluxes to colours.

        Raises:
            ValueError: If ``catalogue_da`` contains no eligible DA galaxies.
        """
        # DA pool: galaxies with only a photometric redshift.
        cat_da = Archive.exclude_only_zspec(catalogue_da)

        if only_zspec:
            logger.info("Selecting only galaxies with spectroscopic redshift")
            catalogue = Archive._select_only_zspec(catalogue, cat_flag='Calib')
            catalogue = Archive._clean_zspec_sample(catalogue, flags_kept=self.flags_kept)
        else:
            logger.info("Selecting galaxies with spectroscopic redshift and high-precision photo-z")
            catalogue = Archive.take_zspec_and_photoz(catalogue, cat_flag='Calib')

        self.cat_train = catalogue

        if cat_da.empty:
            raise ValueError(
                "No domain-adaptation galaxies (specz < 0, 0 < photo_z < 4) found "
                "in the validation catalogue."
            )
        # Draw one DA galaxy (with replacement) per training galaxy so that both
        # sets have the same length.
        idx = np.random.randint(0, len(cat_da), len(catalogue))
        cat_da = cat_da.iloc[idx]
        self.target_z_train_DA = cat_da['photo_z'].values

        f = self._extract_fluxes(catalogue)
        f_da = self._extract_fluxes(cat_da)

        if extinction_corr:
            logger.info("Correcting MW extinction")
            f = self._correct_extinction(catalogue, f)
            # The DA sample comes from the validation catalogue, whose test fluxes
            # are corrected too; process it the same way as the domain it represents.
            f_da = self._correct_extinction(cat_da, f_da)

        if convert_colors:
            logger.info("Converting to colors")
            self.phot_train = self._to_colors(f)
            self.phot_train_DA = self._to_colors(f_da)
        else:
            self.phot_train = f
            self.phot_train_DA = f_da

        # Spectroscopic-only training uses specz; mixed training uses the combined target.
        self.target_z_train = catalogue['specz' if only_zspec else 'target_z'].values
        self.VIS_mag_train = catalogue['MAG_VIS'].values

    def _set_testing_data(
        self,
        cat_test: pd.DataFrame,
        target: str = "specz",
        extinction_corr: bool = True,
        convert_colors: bool = True,
    ) -> None:
        """Build the test arrays.

        Sets ``cat_test``, ``target_z_test``, ``phot_test`` and ``VIS_mag_test``.

        Args:
            cat_test (pd.DataFrame): Cleaned validation catalogue.
            target (str): 'specz' selects spectroscopic galaxies with accepted
                quality flags; 'L15' selects 0 < target_z < 3.
            extinction_corr (bool): Apply the extinction correction.
            convert_colors (bool): Convert fluxes to colours.

        Raises:
            ValueError: If ``target`` is not 'specz' or 'L15'.
        """
        if target == 'specz':
            cat_test = Archive._select_only_zspec(cat_test, cat_flag='Valid')
            # Use the same accepted quality flags as the training sample.
            cat_test = Archive._clean_zspec_sample(cat_test, flags_kept=self.flags_kept)
            self.target_z_test = cat_test['specz'].values
        elif target == 'L15':
            cat_test = self._select_L15_sample(cat_test)
            self.target_z_test = cat_test['target_z'].values
        else:
            raise ValueError(f"Unknown target {target!r}; expected 'specz' or 'L15'.")

        self.cat_test = cat_test

        f = self._extract_fluxes(cat_test)
        if extinction_corr:
            f = self._correct_extinction(cat_test, f)

        self.phot_test = self._to_colors(f) if convert_colors else f
        self.VIS_mag_test = cat_test['MAG_VIS'].values

    def get_training_data(self):
        """Return (phot_train, target_z_train, VIS_mag_train, phot_train_DA, target_z_train_DA)."""
        return self.phot_train, self.target_z_train, self.VIS_mag_train, self.phot_train_DA, self.target_z_train_DA

    def get_testing_data(self):
        """Return (phot_test, target_z_test, VIS_mag_test)."""
        return self.phot_test, self.target_z_test, self.VIS_mag_test

    def get_VIS_mag(self, catalogue):
        """Return the ``MAG_VIS`` column of ``catalogue`` as a 2-D (n, 1) array."""
        return catalogue[['MAG_VIS']].values