TEMPS / temps /archive.py
lauracabayol's picture
improve code with dataclass and typing
c9354dd
raw
history blame
11.7 kB
from dataclasses import dataclass, field
import numpy as np
import pandas as pd
from astropy.io import fits
from astropy.table import Table
from scipy.spatial import KDTree
from matplotlib import pyplot as plt
from matplotlib import rcParams
from pathlib import Path
from loguru import logger
from typing import Optional, Tuple, Union, List
# Set matplotlib configuration
rcParams["mathtext.fontset"] = "stix"
rcParams["font.family"] = "STIXGeneral"
@dataclass
class Archive:
path_calib: Path
path_valid: Optional[Path] = None
drop_stars: bool = True
clean_photometry: bool = True
convert_colors: bool = True
extinction_corr: bool = True
only_zspec: bool = True
columns_photometry: List[str] = field(default_factory=lambda: [
"FLUX_G_2",
"FLUX_R_2",
"FLUX_I_2",
"FLUX_Z_2",
"FLUX_Y_2",
"FLUX_J_2",
"FLUX_H_2",
])
columns_ebv: List[str] = field(default_factory=lambda: [
"EB_V_corr_FLUX_G",
"EB_V_corr_FLUX_R",
"EB_V_corr_FLUX_I",
"EB_V_corr_FLUX_Z",
"EB_V_corr_FLUX_Y",
"EB_V_corr_FLUX_J",
"EB_V_corr_FLUX_H",
])
photoz_name: str = "photo_z_L15"
specz_name: str = "z_spec_S15"
target_test: str = "specz"
flags_kept: List[float] = field(default_factory=lambda: [3, 3.1, 3.4, 3.5, 4])
def __post_init__(self):
logger.info("Starting archive")
# Load data based on the file format
if self.path_calib.suffix == ".fits":
with fits.open(self.path_calib) as hdu_list:
self.cat = Table(hdu_list[1].data).to_pandas()
if self.path_valid is not None:
with fits.open(self.path_valid) as hdu_list:
self.cat_test = Table(hdu_list[1].data).to_pandas()
elif self.path_calib.suffix == ".csv":
self.cat = pd.read_csv(self.path_calib)
if self.path_valid is not None:
self.cat_test = pd.read_csv(self.path_valid)
else:
raise ValueError("Unsupported file format. Please provide a .fits or .csv file.")
self.cat = self.cat.rename(
columns={f"{self.specz_name}": "specz", f"{self.photoz_name}": "photo_z"}
)
self.cat_test = self.cat_test.rename(
columns={f"{self.specz_name}": "specz", f"{self.photoz_name}": "photo_z"}
)
self.cat = self.cat[(self.cat["specz"] > 0) | (self.cat["photo_z"] > 0)]
# Apply operations based on the initialized parameters
if self.drop_stars:
logger.info("Dropping stars...")
self.cat = self.cat[self.cat.mu_class_L07 == 1]
self.cat_test = self.cat_test[self.cat_test.mu_class_L07 == 1]
if self.clean_photometry:
logger.info("Cleaning photometry...")
self.cat = self._clean_photometry(catalogue=self.cat)
self.cat_test = self._clean_photometry(catalogue=self.cat_test)
self.cat = self._set_combined_target(self.cat)
self.cat_test = self._set_combined_target(self.cat_test)
# Apply magnitude and redshift cuts
self.cat = self.cat[self.cat.MAG_VIS < 25]
self.cat_test = self.cat_test[self.cat_test.MAG_VIS < 25]
self.cat = self.cat[self.cat.target_z < 5]
self.cat_test = self.cat_test[self.cat_test.target_z < 5]
self._set_training_data(
self.cat,
self.cat_test,
only_zspec=self.only_zspec,
extinction_corr=self.extinction_corr,
convert_colors=self.convert_colors,
)
self._set_testing_data(
self.cat_test,
target=self.target_test,
extinction_corr=self.extinction_corr,
convert_colors=self.convert_colors,
)
def _extract_fluxes(self, catalogue: pd.DataFrame) -> np.ndarray:
"""Extract fluxes from the given catalogue.
Args:
catalogue (pd.DataFrame): The input catalogue.
Returns:
np.ndarray: An array of fluxes.
"""
f = catalogue[self.columns_photometry].values
return f
@staticmethod
def _to_colors(flux: np.ndarray) -> np.ndarray:
"""Convert fluxes to colors.
Args:
flux (np.ndarray): The input fluxes.
Returns:
np.ndarray: An array of colors.
"""
color = flux[:, :-1] / flux[:, 1:]
return color
@staticmethod
def _set_combined_target(catalogue: pd.DataFrame) -> pd.DataFrame:
"""Set the combined target redshift based on available data.
Args:
catalogue (pd.DataFrame): The input catalogue.
Returns:
pd.DataFrame: Updated catalogue with the combined target redshift.
"""
catalogue["target_z"] = catalogue.apply(
lambda row: row["specz"] if row["specz"] > 0 else row["photo_z"], axis=1
)
return catalogue
@staticmethod
def _clean_photometry(catalogue: pd.DataFrame) -> pd.DataFrame:
"""Drops all objects with FLAG_PHOT != 0.
Args:
catalogue (pd.DataFrame): The input catalogue.
Returns:
pd.DataFrame: Cleaned catalogue.
"""
catalogue = catalogue[catalogue["FLAG_PHOT"] == 0]
return catalogue
def _correct_extinction(
self, catalogue: pd.DataFrame, f: np.ndarray, return_ext_corr: bool = False
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
"""Corrects for extinction based on the provided catalogue.
Args:
catalogue (pd.DataFrame): The input catalogue.
f (np.ndarray): The flux values to correct.
return_ext_corr (bool): Whether to return the extinction correction values.
Returns:
Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: Corrected fluxes, and optionally the extinction corrections.
"""
ext_correction = catalogue[self.columns_ebv].values
f = f * ext_correction
if return_ext_corr:
return f, ext_correction
else:
return f
@staticmethod
def _select_only_zspec(
catalogue: pd.DataFrame, cat_flag: Optional[str] = None
) -> pd.DataFrame:
"""Selects only galaxies with spectroscopic redshift.
Args:
catalogue (pd.DataFrame): The input catalogue.
cat_flag (Optional[str]): Indicates the catalogue type ('Calib' or 'Valid').
Returns:
pd.DataFrame: Filtered catalogue.
"""
if cat_flag == "Calib":
catalogue = catalogue[catalogue.specz > 0]
elif cat_flag == "Valid":
catalogue = catalogue[catalogue.specz > 0]
return catalogue
@staticmethod
def take_zspec_and_photoz(catalogue: pd.DataFrame, cat_flag: Optional[str] = None
) -> pd.DataFrame:
"""Selects only galaxies with spectroscopic redshift"""
if cat_flag=='Calib':
catalogue = catalogue[catalogue.target_z>0]
elif cat_flag=='Valid':
catalogue = catalogue[catalogue.specz>0]
return catalogue
@staticmethod
def exclude_only_zspec(catalogue: pd.DataFrame) -> pd.DataFrame:
"""Selects only galaxies without spectroscopic redshift.
Args:
catalogue (pd.DataFrame): The input catalogue.
Returns:
pd.DataFrame: Filtered catalogue.
"""
catalogue = catalogue[
(catalogue.specz < 0) & (catalogue.photo_z > 0) & (catalogue.photo_z < 4)
]
return catalogue
@staticmethod
def _clean_zspec_sample(catalogue ,flags_kept=[3,3.1,3.4,3.5,4]):
catalogue = catalogue[catalogue.Q_f_S15.isin(flags_kept)]
return catalogue
@staticmethod
def _select_L15_sample(self, catalogue: pd.DataFrame) -> pd.DataFrame:
"""Selects only galaxies within a specific redshift range.
Args:
catalogue (pd.DataFrame): The input catalogue.
Returns:
pd.DataFrame: Filtered catalogue.
"""
catalogue = catalogue[(catalogue.target_z > 0) & (catalogue.target_z < 3)]
return catalogue
def _set_training_data(self,
catalogue: pd.DataFrame,
catalogue_da: pd.DataFrame,
only_zspec: bool = True,
extinction_corr: bool = True,
convert_colors: bool = True
)-> None:
cat_da = Archive.exclude_only_zspec(catalogue_da)
target_z_train_DA = cat_da['photo_z'].values
if only_zspec:
logger.info("Selecting only galaxies with spectroscopic redshift")
catalogue = Archive._select_only_zspec(catalogue, cat_flag='Calib')
catalogue = Archive._clean_zspec_sample(catalogue, flags_kept=self.flags_kept)
else:
logger.info("Selecting galaxies with spectroscopic redshift and high-precision photo-z")
catalogue = Archive.take_zspec_and_photoz(catalogue, cat_flag='Calib')
self.cat_train=catalogue
f = self._extract_fluxes(catalogue)
f_DA = self._extract_fluxes(cat_da)
idx = np.random.randint(0, len(f_DA), len(f))
f_DA = f_DA[idx]
target_z_train_DA = target_z_train_DA[idx]
self.target_z_train_DA = target_z_train_DA
if extinction_corr==True:
logger.info("Correcting MW extinction")
f = self._correct_extinction(catalogue,f)
if convert_colors==True:
logger.info("Converting to colors")
col = self._to_colors(f)
col_DA = self._to_colors(f_DA)
self.phot_train = col
self.phot_train_DA = col_DA
else:
self.phot_train = f
self.phot_train_DA = f_DA
if only_zspec==True:
self.target_z_train = catalogue['specz'].values
else:
self.target_z_train = catalogue['target_z'].values
self.VIS_mag_train = catalogue['MAG_VIS'].values
def _set_testing_data(
self,
cat_test: pd.DataFrame,
target: str = "specz",
extinction_corr: bool = True,
convert_colors: bool = True,
) -> None:
if target=='specz':
cat_test = Archive._select_only_zspec(cat_test, cat_flag='Valid')
cat_test = Archive._clean_zspec_sample(cat_test)
self.target_z_test = cat_test['specz'].values
elif target=='L15':
cat_test = self._select_L15_sample(cat_test)
self.target_z_test = cat_test['target_z'].values
self.cat_test=cat_test
f = self._extract_fluxes(cat_test)
if extinction_corr==True:
f = self._correct_extinction(cat_test,f)
if convert_colors==True:
col = self._to_colors(f)
self.phot_test = col
else:
self.phot_test = f
self.VIS_mag_test = cat_test['MAG_VIS'].values
def get_training_data(self):
return self.phot_train, self.target_z_train, self.VIS_mag_train, self.phot_train_DA, self.target_z_train_DA
def get_testing_data(self):
return self.phot_test, self.target_z_test, self.VIS_mag_test
def get_VIS_mag(self, catalogue):
return catalogue[['MAG_VIS']].values