Spaces:
Runtime error
Runtime error
from dataclasses import dataclass, field | |
import numpy as np | |
import pandas as pd | |
from astropy.io import fits | |
from astropy.table import Table | |
from scipy.spatial import KDTree | |
from matplotlib import pyplot as plt | |
from matplotlib import rcParams | |
from pathlib import Path | |
from loguru import logger | |
from typing import Optional, Tuple, Union, List | |
# Render both regular text and math text with the STIX fonts.
rcParams.update(
    {
        "mathtext.fontset": "stix",
        "font.family": "STIXGeneral",
    }
)
@dataclass
class Archive:
    """Load photometric catalogues and prepare training/testing arrays.

    On construction the calibration and validation catalogues are read
    (``.fits`` or ``.csv``), the redshift columns are renamed to ``specz`` and
    ``photo_z``, the configured selections are applied, and the training and
    testing arrays are built. Retrieve them with :meth:`get_training_data`
    and :meth:`get_testing_data`.

    Attributes:
        path_calib: Calibration catalogue file. Its suffix selects the reader
            used for both catalogues.
        path_valid: Validation catalogue file. Although it defaults to
            ``None``, it is required: the validation catalogue feeds both the
            testing arrays and the domain-adaptation training sample.
        drop_stars: Keep only rows with ``mu_class_L07 == 1``.
        clean_photometry: Keep only rows with ``FLAG_PHOT == 0``.
        convert_colors: Use ratios of adjacent flux columns instead of fluxes.
        extinction_corr: Multiply fluxes by the ``columns_ebv`` factors.
        only_zspec: Train only on rows with ``specz > 0`` and an accepted
            ``Q_f_S15`` flag; otherwise train on rows with ``target_z > 0``.
        columns_photometry: Flux columns, in the order used for colours.
        columns_ebv: Per-band correction-factor columns, aligned with
            ``columns_photometry``.
        photoz_name: Input column renamed to ``photo_z``.
        specz_name: Input column renamed to ``specz``.
        target_test: ``"specz"`` or ``"L15"``; selects the testing sample.
        flags_kept: ``Q_f_S15`` values accepted for the training sample.
    """

    # Plain class attribute (no annotation, so it is not a dataclass field).
    _DEFAULT_FLAGS_KEPT = (3, 3.1, 3.4, 3.5, 4)

    path_calib: Path
    path_valid: Optional[Path] = None
    drop_stars: bool = True
    clean_photometry: bool = True
    convert_colors: bool = True
    extinction_corr: bool = True
    only_zspec: bool = True
    columns_photometry: List[str] = field(default_factory=lambda: [
        "FLUX_G_2",
        "FLUX_R_2",
        "FLUX_I_2",
        "FLUX_Z_2",
        "FLUX_Y_2",
        "FLUX_J_2",
        "FLUX_H_2",
    ])
    columns_ebv: List[str] = field(default_factory=lambda: [
        "EB_V_corr_FLUX_G",
        "EB_V_corr_FLUX_R",
        "EB_V_corr_FLUX_I",
        "EB_V_corr_FLUX_Z",
        "EB_V_corr_FLUX_Y",
        "EB_V_corr_FLUX_J",
        "EB_V_corr_FLUX_H",
    ])
    photoz_name: str = "photo_z_L15"
    specz_name: str = "z_spec_S15"
    target_test: str = "specz"
    flags_kept: List[float] = field(
        default_factory=lambda: list(Archive._DEFAULT_FLAGS_KEPT)
    )

    def __post_init__(self):
        """Read both catalogues, apply the selections and build the arrays.

        Raises:
            ValueError: If ``path_valid`` is missing or the suffix of
                ``path_calib`` is neither ``.fits`` nor ``.csv``.
        """
        # Accept plain strings as well as Path objects.
        self.path_calib = Path(self.path_calib)
        if self.path_valid is None:
            # Everything below reads self.cat_test; fail early with a clear
            # message instead of an AttributeError further down.
            raise ValueError(
                "path_valid is required: the validation catalogue is used to "
                "build the testing data and the domain-adaptation sample."
            )
        self.path_valid = Path(self.path_valid)

        logger.info("Starting archive")

        # The calibration file's suffix decides the reader for both files.
        suffix = self.path_calib.suffix
        self.cat = self._read_catalogue(self.path_calib, suffix)
        self.cat_test = self._read_catalogue(self.path_valid, suffix)

        renames = {self.specz_name: "specz", self.photoz_name: "photo_z"}
        self.cat = self.cat.rename(columns=renames)
        self.cat_test = self.cat_test.rename(columns=renames)

        # NOTE(review): this cut is applied to the calibration catalogue only;
        # the validation catalogue is filtered later by the sample selectors.
        self.cat = self.cat[(self.cat["specz"] > 0) | (self.cat["photo_z"] > 0)]

        # Apply operations based on the initialized parameters
        if self.drop_stars:
            logger.info("Dropping stars...")
            self.cat = self.cat[self.cat.mu_class_L07 == 1]
            self.cat_test = self.cat_test[self.cat_test.mu_class_L07 == 1]

        if self.clean_photometry:
            logger.info("Cleaning photometry...")
            self.cat = self._clean_photometry(catalogue=self.cat)
            self.cat_test = self._clean_photometry(catalogue=self.cat_test)

        self.cat = self._set_combined_target(self.cat)
        self.cat_test = self._set_combined_target(self.cat_test)

        # Apply magnitude and redshift cuts
        self.cat = self.cat[self.cat.MAG_VIS < 25]
        self.cat_test = self.cat_test[self.cat_test.MAG_VIS < 25]
        self.cat = self.cat[self.cat.target_z < 5]
        self.cat_test = self.cat_test[self.cat_test.target_z < 5]

        self._set_training_data(
            self.cat,
            self.cat_test,
            only_zspec=self.only_zspec,
            extinction_corr=self.extinction_corr,
            convert_colors=self.convert_colors,
        )
        self._set_testing_data(
            self.cat_test,
            target=self.target_test,
            extinction_corr=self.extinction_corr,
            convert_colors=self.convert_colors,
        )

    @staticmethod
    def _read_catalogue(path: Path, suffix: str) -> pd.DataFrame:
        """Read one catalogue file with the reader selected by *suffix*.

        Args:
            path (Path): File to read.
            suffix (str): ``".fits"`` (first extension HDU is read) or ``".csv"``.

        Returns:
            pd.DataFrame: The catalogue contents.

        Raises:
            ValueError: If *suffix* is not supported.
        """
        if suffix == ".fits":
            with fits.open(path) as hdu_list:
                return Table(hdu_list[1].data).to_pandas()
        if suffix == ".csv":
            return pd.read_csv(path)
        raise ValueError("Unsupported file format. Please provide a .fits or .csv file.")

    def _extract_fluxes(self, catalogue: pd.DataFrame) -> np.ndarray:
        """Extract fluxes from the given catalogue.

        Args:
            catalogue (pd.DataFrame): The input catalogue.

        Returns:
            np.ndarray: The ``columns_photometry`` columns as an array.
        """
        return catalogue[self.columns_photometry].values

    @staticmethod
    def _to_colors(flux: np.ndarray) -> np.ndarray:
        """Convert fluxes to colors.

        Args:
            flux (np.ndarray): 2-D array of fluxes, one column per band.

        Returns:
            np.ndarray: Ratio of each column to the next one, so the result
            has one column fewer than the input.
        """
        return flux[:, :-1] / flux[:, 1:]

    @staticmethod
    def _set_combined_target(catalogue: pd.DataFrame) -> pd.DataFrame:
        """Set the combined target redshift based on available data.

        Adds a ``target_z`` column equal to ``specz`` where ``specz > 0`` and
        to ``photo_z`` otherwise (including rows where ``specz`` is NaN).

        Args:
            catalogue (pd.DataFrame): The input catalogue; modified in place.

        Returns:
            pd.DataFrame: The same catalogue with the ``target_z`` column.
        """
        # Vectorised selection; also well-defined for an empty catalogue,
        # where a row-wise apply would not return a Series.
        catalogue["target_z"] = np.where(
            catalogue["specz"] > 0, catalogue["specz"], catalogue["photo_z"]
        )
        return catalogue

    @staticmethod
    def _clean_photometry(catalogue: pd.DataFrame) -> pd.DataFrame:
        """Drop all objects with ``FLAG_PHOT != 0``.

        Args:
            catalogue (pd.DataFrame): The input catalogue.

        Returns:
            pd.DataFrame: Cleaned catalogue.
        """
        return catalogue[catalogue["FLAG_PHOT"] == 0]

    def _correct_extinction(
        self, catalogue: pd.DataFrame, f: np.ndarray, return_ext_corr: bool = False
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """Correct fluxes for extinction using the catalogue's factors.

        Args:
            catalogue (pd.DataFrame): The input catalogue; must contain the
                ``columns_ebv`` columns.
            f (np.ndarray): The flux values to correct.
            return_ext_corr (bool): Whether to also return the factors.

        Returns:
            Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: The fluxes
            multiplied element-wise by the factors, and optionally the
            factors themselves.
        """
        ext_correction = catalogue[self.columns_ebv].values
        f = f * ext_correction
        if return_ext_corr:
            return f, ext_correction
        return f

    @staticmethod
    def _select_only_zspec(
        catalogue: pd.DataFrame, cat_flag: Optional[str] = None
    ) -> pd.DataFrame:
        """Select only galaxies with spectroscopic redshift.

        Args:
            catalogue (pd.DataFrame): The input catalogue.
            cat_flag (Optional[str]): ``'Calib'`` or ``'Valid'``. Any other
                value leaves the catalogue unfiltered.

        Returns:
            pd.DataFrame: Rows with ``specz > 0`` for a recognised flag.
        """
        # Both catalogue types currently use the same cut.
        if cat_flag in ("Calib", "Valid"):
            catalogue = catalogue[catalogue.specz > 0]
        return catalogue

    @staticmethod
    def take_zspec_and_photoz(
        catalogue: pd.DataFrame, cat_flag: Optional[str] = None
    ) -> pd.DataFrame:
        """Select galaxies with a usable redshift for the given catalogue type.

        Args:
            catalogue (pd.DataFrame): The input catalogue.
            cat_flag (Optional[str]): ``'Calib'`` keeps rows with
                ``target_z > 0``; ``'Valid'`` keeps rows with ``specz > 0``.
                Any other value leaves the catalogue unfiltered.

        Returns:
            pd.DataFrame: Filtered catalogue.
        """
        if cat_flag == "Calib":
            catalogue = catalogue[catalogue.target_z > 0]
        elif cat_flag == "Valid":
            catalogue = catalogue[catalogue.specz > 0]
        return catalogue

    @staticmethod
    def exclude_only_zspec(catalogue: pd.DataFrame) -> pd.DataFrame:
        """Select only galaxies without spectroscopic redshift.

        Args:
            catalogue (pd.DataFrame): The input catalogue.

        Returns:
            pd.DataFrame: Rows with ``specz < 0`` and ``0 < photo_z < 4``.
        """
        keep = (catalogue.specz < 0) & (catalogue.photo_z > 0) & (catalogue.photo_z < 4)
        return catalogue[keep]

    @staticmethod
    def _clean_zspec_sample(
        catalogue: pd.DataFrame, flags_kept: Optional[List[float]] = None
    ) -> pd.DataFrame:
        """Keep only rows whose ``Q_f_S15`` flag is in *flags_kept*.

        Args:
            catalogue (pd.DataFrame): The input catalogue.
            flags_kept (Optional[List[float]]): Accepted flag values;
                defaults to ``[3, 3.1, 3.4, 3.5, 4]``.

        Returns:
            pd.DataFrame: Filtered catalogue.
        """
        # None sentinel instead of a mutable list default.
        if flags_kept is None:
            flags_kept = Archive._DEFAULT_FLAGS_KEPT
        return catalogue[catalogue.Q_f_S15.isin(list(flags_kept))]

    def _select_L15_sample(self, catalogue: pd.DataFrame) -> pd.DataFrame:
        """Select only galaxies within a specific redshift range.

        Args:
            catalogue (pd.DataFrame): The input catalogue.

        Returns:
            pd.DataFrame: Rows with ``0 < target_z < 3``.
        """
        return catalogue[(catalogue.target_z > 0) & (catalogue.target_z < 3)]

    def _set_training_data(
        self,
        catalogue: pd.DataFrame,
        catalogue_da: pd.DataFrame,
        only_zspec: bool = True,
        extinction_corr: bool = True,
        convert_colors: bool = True,
    ) -> None:
        """Build the training arrays and store them on the instance.

        Sets ``cat_train``, ``phot_train``, ``phot_train_DA``,
        ``target_z_train``, ``target_z_train_DA`` and ``VIS_mag_train``.

        Args:
            catalogue (pd.DataFrame): Calibration catalogue.
            catalogue_da (pd.DataFrame): Catalogue from which the
                domain-adaptation sample (rows without ``specz``) is drawn.
            only_zspec (bool): Train on spectroscopic redshifts only.
            extinction_corr (bool): Apply the extinction correction.
            convert_colors (bool): Store colours instead of fluxes.
        """
        cat_da = Archive.exclude_only_zspec(catalogue_da)
        target_z_train_DA = cat_da["photo_z"].values

        if only_zspec:
            logger.info("Selecting only galaxies with spectroscopic redshift")
            catalogue = Archive._select_only_zspec(catalogue, cat_flag="Calib")
            catalogue = Archive._clean_zspec_sample(catalogue, flags_kept=self.flags_kept)
        else:
            logger.info("Selecting galaxies with spectroscopic redshift and high-precision photo-z")
            catalogue = Archive.take_zspec_and_photoz(catalogue, cat_flag="Calib")

        self.cat_train = catalogue
        f = self._extract_fluxes(catalogue)
        f_DA = self._extract_fluxes(cat_da)

        # Resample the domain-adaptation rows (with replacement) so that the
        # sample has exactly as many rows as the training set.
        idx = np.random.randint(0, len(f_DA), len(f))
        f_DA = f_DA[idx]
        self.target_z_train_DA = target_z_train_DA[idx]

        if extinction_corr:
            logger.info("Correcting MW extinction")
            # NOTE(review): only the training fluxes are corrected; f_DA is
            # left uncorrected, as in the original code. Confirm intent.
            f = self._correct_extinction(catalogue, f)

        if convert_colors:
            logger.info("Converting to colors")
            self.phot_train = self._to_colors(f)
            self.phot_train_DA = self._to_colors(f_DA)
        else:
            self.phot_train = f
            self.phot_train_DA = f_DA

        target_column = "specz" if only_zspec else "target_z"
        self.target_z_train = catalogue[target_column].values
        self.VIS_mag_train = catalogue["MAG_VIS"].values

    def _set_testing_data(
        self,
        cat_test: pd.DataFrame,
        target: str = "specz",
        extinction_corr: bool = True,
        convert_colors: bool = True,
    ) -> None:
        """Build the testing arrays and store them on the instance.

        Sets ``cat_test``, ``phot_test``, ``target_z_test`` and
        ``VIS_mag_test``.

        Args:
            cat_test (pd.DataFrame): Validation catalogue.
            target (str): ``"specz"`` tests against ``specz`` on rows with an
                accepted quality flag; ``"L15"`` tests against ``target_z``
                on rows with ``0 < target_z < 3``.
            extinction_corr (bool): Apply the extinction correction.
            convert_colors (bool): Store colours instead of fluxes.

        Raises:
            ValueError: If *target* is neither ``"specz"`` nor ``"L15"``.
        """
        if target == "specz":
            cat_test = Archive._select_only_zspec(cat_test, cat_flag="Valid")
            # NOTE(review): uses the default quality flags, not
            # self.flags_kept, as in the original code. Confirm intent.
            cat_test = Archive._clean_zspec_sample(cat_test)
            self.target_z_test = cat_test["specz"].values
        elif target == "L15":
            cat_test = self._select_L15_sample(cat_test)
            self.target_z_test = cat_test["target_z"].values
        else:
            # Without this, target_z_test would silently stay unset and
            # get_testing_data() would fail later with an AttributeError.
            raise ValueError(
                f"Unknown target {target!r}; expected 'specz' or 'L15'."
            )

        self.cat_test = cat_test
        f = self._extract_fluxes(cat_test)

        if extinction_corr:
            f = self._correct_extinction(cat_test, f)

        self.phot_test = self._to_colors(f) if convert_colors else f
        self.VIS_mag_test = cat_test["MAG_VIS"].values

    def get_training_data(self):
        """Return the training arrays.

        Returns:
            tuple: ``(phot_train, target_z_train, VIS_mag_train,
            phot_train_DA, target_z_train_DA)``.
        """
        return (
            self.phot_train,
            self.target_z_train,
            self.VIS_mag_train,
            self.phot_train_DA,
            self.target_z_train_DA,
        )

    def get_testing_data(self):
        """Return the testing arrays.

        Returns:
            tuple: ``(phot_test, target_z_test, VIS_mag_test)``.
        """
        return self.phot_test, self.target_z_test, self.VIS_mag_test

    def get_VIS_mag(self, catalogue):
        """Return the ``MAG_VIS`` column of *catalogue* as a 2-D array.

        Args:
            catalogue (pd.DataFrame): Catalogue containing ``MAG_VIS``.

        Returns:
            np.ndarray: Array with one column (double-bracket selection).
        """
        return catalogue[["MAG_VIS"]].values