# TEMPS/insight/archive.py
# Author: lauracabayol — commit 696a020, "optimized version working at low z" (7.41 kB)
import numpy as np
import pandas as pd
from astropy.io import fits
import os
from astropy.table import Table
from scipy.spatial import KDTree
import matplotlib.pyplot as plt
from matplotlib import rcParams
# Use STIX fonts for both regular text and mathtext in every figure this module produces.
rcParams.update({
    "mathtext.fontset": "stix",
    "font.family": "STIXGeneral",
})
class archive():
    """Load and prepare the Euclid/COSMOS DC2 calibration and validation catalogues.

    On construction the calibration ("training") and validation ("testing")
    FITS catalogues are read from ``path``, optionally filtered, and converted
    into photometry / redshift arrays available through
    :meth:`get_training_data` and :meth:`get_testing_data`.

    Parameters
    ----------
    path : str
        Directory holding the calibration/validation FITS files and the
        gold-sample CSV.
    aperture : int
        Aperture suffix used to build the flux column names
        (``FLUX_<band>_<aperture>`` / ``FLUXERR_<band>_<aperture>``).
    drop_stars : bool
        Keep only rows with ``mu_class_L07 == 1``.
    clean_photometry : bool
        Keep only rows with ``FLAG_PHOT == 0``.
    convert_colors : bool
        Use adjacent-band flux ratios instead of raw fluxes as features.
    extinction_corr : bool
        Multiply fluxes by the ``EB_V_corr_FLUX_<band>`` columns.
    only_zspec : bool
        Keep only objects with ``z_spec_S15 > 0`` and apply the quality cut.
    Qz_cut : float
        Minimum ``w_Q_f_S15`` weight kept in the training sample
        (the test sample always uses 1).
    """

    # Photometric bands, in the order used for flux columns and colours.
    BANDS = ('G', 'R', 'I', 'Z', 'Y', 'J', 'H')

    def __init__(self, path, aperture=2, drop_stars=True, clean_photometry=True, convert_colors=True, extinction_corr=True, only_zspec=True, Qz_cut=1):
        self.aperture = aperture
        # Inclusive (low, high) ranges of the spectroscopic quality flag
        # Q_f_S15 mapped to a loss weight. Flags falling outside every range
        # get no weight (None / missing value), see _map_weight.
        self.weight_dict = {(-99, 0.99): 0,
                            (1, 1.99): 0.5,
                            (2, 2.99): 0.75,
                            (3, 4): 1,
                            (9, 9.99): 0.25,
                            (10, 10.99): 0,
                            (11, 11.99): 0.5,
                            (12, 12.99): 0.75,
                            (13, 14): 1,
                            (14.01, 40): 0,
                            }
        filename_calib = 'euclid_cosmos_DC2_S1_v2.1_calib_clean.fits'
        filename_valid = 'euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
        filename_gold = 'Export_Gold_2023_07_03.csv'

        cat = self._read_fits_table(os.path.join(path, filename_calib))
        cat_test = self._read_fits_table(os.path.join(path, filename_valid))
        # Weights are computed on the full catalogues, before any row filtering,
        # so they are written to real DataFrames rather than slices.
        self._get_loss_weights(cat)
        self._get_loss_weights(cat_test)

        # The gold sample is loaded and kept on the instance. The cross-match
        # with the validation catalogue (_match_gold_sample) is currently not
        # applied.
        self.gold_sample = pd.read_csv(os.path.join(path, filename_gold))

        if drop_stars:
            cat = cat[cat.mu_class_L07 == 1]
            cat_test = cat_test[cat_test.mu_class_L07 == 1]
        if clean_photometry:
            cat = self._clean_photometry(cat)
            cat_test = self._clean_photometry(cat_test)
        # Drop zero-weight (and unweighted) training objects; presumably they
        # would contribute nothing to the weighted loss.
        cat = cat[cat.w_Q_f_S15 > 0]
        self._set_training_data(cat, only_zspec=only_zspec, extinction_corr=extinction_corr, convert_colors=convert_colors, Qz_cut=Qz_cut)
        self._set_testing_data(cat_test, only_zspec=only_zspec, extinction_corr=extinction_corr, convert_colors=convert_colors)

    @staticmethod
    def _read_fits_table(filename):
        """Read HDU 1 of a FITS file into a pandas DataFrame.

        The file is opened in a context manager so the handle is closed once
        the table has been copied into the DataFrame.
        """
        with fits.open(filename) as hdu_list:
            return Table(hdu_list[1].data).to_pandas()

    def _extract_fluxes(self, catalogue):
        """Return (flux, flux_err) arrays of shape (n_objects, len(BANDS)) for the configured aperture."""
        columns_f = [f'FLUX_{band}_{self.aperture}' for band in self.BANDS]
        columns_ferr = [f'FLUXERR_{band}_{self.aperture}' for band in self.BANDS]
        return catalogue[columns_f].values, catalogue[columns_ferr].values

    def _to_colors(self, flux, fluxerr):
        """Convert fluxes to adjacent-band flux ratios ("colours").

        For bands b and b+1 the colour is c = f_b / f_{b+1}. Its variance is
        propagated to first order:

            var(c) = sigma_b**2 / f_{b+1}**2 + f_b**2 * sigma_{b+1}**2 / f_{b+1}**4

        Returns
        -------
        color : ndarray, shape (n, n_bands - 1)
        color_err : ndarray, shape (n, n_bands - 1)
            The propagated *variance* (squared uncertainty), not the standard
            deviation. NOTE(review): confirm downstream code expects a variance.
        """
        numerator, denominator = flux[:, :-1], flux[:, 1:]
        numerator_err, denominator_err = fluxerr[:, :-1], fluxerr[:, 1:]
        color = numerator / denominator
        # The second term must use the denominator's error; the previous code
        # reused the numerator's error there.
        color_err = (numerator_err**2 / denominator**2
                     + numerator**2 / denominator**4 * denominator_err**2)
        return color, color_err

    def _clean_photometry(self, catalogue):
        """Drop all objects with FLAG_PHOT != 0."""
        return catalogue[catalogue['FLAG_PHOT'] == 0]

    def _correct_extinction(self, catalogue, f):
        """Multiply fluxes by the per-band EB_V_corr_FLUX_<band> correction factors.

        NOTE(review): only the fluxes are scaled; the flux errors are left
        untouched by the callers, which changes the S/N. Confirm this is intended.
        """
        ext_correction_cols = [f'EB_V_corr_FLUX_{band}' for band in self.BANDS]
        return f * catalogue[ext_correction_cols].values

    def _take_only_zspec(self, catalogue, cat_flag=None):
        """Keep only objects with a spectroscopic redshift (z_spec_S15 > 0).

        The cut is applied when cat_flag is 'Calib' or 'Valid'; for any other
        value the catalogue is returned unchanged.
        """
        if cat_flag in ('Calib', 'Valid'):
            catalogue = catalogue[catalogue.z_spec_S15 > 0]
        return catalogue

    def _clean_zspec_sample(self, catalogue, Qz_cut):
        """Keep only objects whose quality weight w_Q_f_S15 is >= Qz_cut."""
        return catalogue[catalogue.w_Q_f_S15 >= Qz_cut]

    def _map_weight(self, Qz):
        """Return the loss weight for quality flag Qz, or None if no range in weight_dict contains it."""
        for (low, high), weight in self.weight_dict.items():
            if low <= Qz <= high:
                return weight
        return None

    def _get_loss_weights(self, catalogue):
        """Add a 'w_Q_f_S15' column (in place) mapping Q_f_S15 to loss weights."""
        catalogue['w_Q_f_S15'] = catalogue['Q_f_S15'].apply(self._map_weight)

    def _match_gold_sample(self, catalogue_valid, catalogue_gold, max_distance_arcsec=2):
        """Attach the nearest gold-sample spectroscopic redshift to each validation object.

        Adds a 'z_spec_gold' column (in place) holding FINAL_SPEC_Z of the nearest
        gold object when it lies within max_distance_arcsec, and -99 otherwise.

        NOTE(review): distances are Euclidean in (RA, Dec) degrees, i.e. a
        flat-sky approximation without the cos(Dec) factor.
        """
        max_distance_deg = max_distance_arcsec / 3600.0
        gold_sample_radec = np.c_[catalogue_gold.RIGHT_ASCENSION, catalogue_gold.DECLINATION]
        valid_sample_radec = np.c_[catalogue_valid['RA'], catalogue_valid['DEC']]
        kdtree = KDTree(gold_sample_radec)
        distances, indices = kdtree.query(valid_sample_radec, k=1)
        specz_match_gold = catalogue_gold.FINAL_SPEC_Z.values[indices]
        catalogue_valid['z_spec_gold'] = np.where(distances < max_distance_deg, specz_match_gold, -99)
        return catalogue_valid

    def _build_features(self, catalogue, extinction_corr=True, convert_colors=True):
        """Return (phot, photerr): fluxes or colours (with errors) for catalogue."""
        f, ferr = self._extract_fluxes(catalogue)
        if extinction_corr:
            f = self._correct_extinction(catalogue, f)
        if convert_colors:
            return self._to_colors(f, ferr)
        return f, ferr

    def _set_training_data(self, catalogue, only_zspec=True, extinction_corr=True, convert_colors=True, Qz_cut=1):
        """Filter the calibration catalogue and store the training arrays on the instance."""
        if only_zspec:
            catalogue = self._take_only_zspec(catalogue, cat_flag='Calib')
            catalogue = self._clean_zspec_sample(catalogue, Qz_cut=Qz_cut)
        self.cat_train = catalogue
        self.phot_train, self.photerr_train = self._build_features(
            catalogue, extinction_corr=extinction_corr, convert_colors=convert_colors)
        self.target_z_train = catalogue['z_spec_S15'].values
        self.target_qz_train = catalogue['w_Q_f_S15'].values

    def _set_testing_data(self, catalogue, only_zspec=True, extinction_corr=True, convert_colors=True):
        """Filter the validation catalogue (quality cut fixed at 1) and store the test arrays."""
        if only_zspec:
            catalogue = self._take_only_zspec(catalogue, cat_flag='Valid')
            catalogue = self._clean_zspec_sample(catalogue, Qz_cut=1)
        self.cat_test = catalogue
        self.phot_test, self.photerr_test = self._build_features(
            catalogue, extinction_corr=extinction_corr, convert_colors=convert_colors)
        self.target_z_test = catalogue['z_spec_S15'].values
        # Stored for symmetry with the training set; not part of get_testing_data().
        self.target_qz_test = catalogue['w_Q_f_S15'].values

    def get_training_data(self):
        """Return (phot, photerr, z_spec, quality_weight) for the training sample."""
        return self.phot_train, self.photerr_train, self.target_z_train, self.target_qz_train

    def get_testing_data(self):
        """Return (phot, photerr, z_spec) for the test sample."""
        return self.phot_test, self.photerr_test, self.target_z_test

    def get_VIS_mag(self, catalogue):
        """Return the MAG_VIS column as a 2-D array of shape (n_objects, 1)."""
        return catalogue[['MAG_VIS']].values

    def plot_zdistribution(self, plot_test=False, bins=50):
        """Plot histograms of the training (and optionally test) spectroscopic redshifts."""
        _, _, specz, _ = self.get_training_data()
        plt.hist(specz, bins=bins, histtype='step', color='navy', label=r'Training sample')
        if plot_test:
            _, _, specz_test = self.get_testing_data()
            plt.hist(specz_test, bins=bins, histtype='step', color='goldenrod', label=r'Test sample', ls='--')
        plt.xticks(fontsize=12)
        plt.yticks(fontsize=12)
        plt.xlabel(r'Redshift', fontsize=14)
        plt.ylabel('Counts', fontsize=14)
        plt.legend()
        plt.show()