# Page banner captured together with this file (not part of the module):
# "Spaces: Runtime error / Runtime error"
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from astropy.io import fits
from astropy.table import Table
from matplotlib import rcParams
from scipy.spatial import KDTree
# Use the STIX fonts for both regular figure text and mathtext.
rcParams.update({
    "mathtext.fontset": "stix",
    "font.family": "STIXGeneral",
})
class archive():
    """Build training and testing samples from two FITS photometric catalogues.

    On construction the calibration catalogue
    (``euclid_cosmos_DC2_S1_v2.1_calib_clean.fits``) and the validation
    catalogue (``euclid_cosmos_DC2_S1_v2.1_valid_matched.fits``) are read from
    ``path``, filtered according to the constructor flags and converted to
    NumPy arrays, which are returned by :meth:`get_training_data` and
    :meth:`get_testing_data`.

    Attributes set by ``__init__``:
        aperture: suffix of the ``FLUX_<band>_<aperture>`` columns to read.
        weight_dict: inclusive ``(low, high)`` ranges of ``Q_f_S15`` mapped
            to the weight stored in the ``w_Q_f_S15`` column.
        cat_train, cat_test: the filtered catalogues (pandas DataFrames).
        phot_train, photerr_train: training photometry and its uncertainty.
        phot_test, photerr_test: testing photometry and its uncertainty.
        target_z_train, target_z_test: ``z_spec_S15`` of each sample.
        target_qz_train: ``w_Q_f_S15`` of the training sample.
    """

    # Bands in the order used for the flux columns and for adjacent-band ratios.
    _BANDS = ('G', 'R', 'I', 'Z', 'Y', 'J', 'H')

    # Default content of ``weight_dict``; copied per instance so that editing
    # one instance's table never leaks into another.
    _DEFAULT_WEIGHTS = {
        (-99, 0.99): 0,
        (1, 1.99): 0.5,
        (2, 2.99): 0.75,
        (3, 4): 1,
        (9, 9.99): 0.25,
        (10, 10.99): 0,
        (11, 11.99): 0.5,
        (12, 12.99): 0.75,
        (13, 14): 1,
        (14.01, 40): 0,
    }

    def __init__(self, path, aperture=2, drop_stars=True, clean_photometry=True,
                 convert_colors=True, extinction_corr=True, only_zspec=True,
                 Qz_cut=1, match_gold=False):
        """Read both catalogues from ``path`` and prepare the data arrays.

        Args:
            path: directory containing the catalogue files.
            aperture: suffix selecting the ``FLUX_*``/``FLUXERR_*`` columns.
            drop_stars: keep only rows with ``mu_class_L07 == 1``.
            clean_photometry: keep only rows with ``FLAG_PHOT == 0``.
            convert_colors: store adjacent-band flux ratios instead of fluxes.
            extinction_corr: multiply fluxes by the ``EB_V_corr_FLUX_*`` columns.
            only_zspec: keep only rows with ``z_spec_S15 > 0`` that also pass
                the weight cut.
            Qz_cut: minimum ``w_Q_f_S15`` kept in the training sample when
                ``only_zspec`` is set.
            match_gold: when true, also read ``Export_Gold_2023_07_03.csv``
                and add a ``z_spec_gold`` column to the validation catalogue
                (see :meth:`_match_gold_sample`). Off by default, so the CSV
                is only required when it is actually used.
        """
        self.aperture = aperture
        self.weight_dict = dict(self._DEFAULT_WEIGHTS)

        filename_calib = 'euclid_cosmos_DC2_S1_v2.1_calib_clean.fits'
        filename_valid = 'euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
        filename_gold = 'Export_Gold_2023_07_03.csv'

        cat = self._read_fits_catalogue(os.path.join(path, filename_calib))
        cat_test = self._read_fits_catalogue(os.path.join(path, filename_valid))

        # Weights are attached before any filtering so the column is written
        # on the full frames rather than on filtered copies.
        self._get_loss_weights(cat)
        self._get_loss_weights(cat_test)

        if match_gold:
            gold_sample = pd.read_csv(os.path.join(path, filename_gold))
            cat_test = self._match_gold_sample(cat_test, gold_sample)

        if drop_stars:
            cat = cat[cat.mu_class_L07 == 1]
            cat_test = cat_test[cat_test.mu_class_L07 == 1]

        if clean_photometry:
            cat = self._clean_photometry(cat)
            cat_test = self._clean_photometry(cat_test)

        # Rows with zero (or undefined) weight never enter the training set.
        cat = cat[cat.w_Q_f_S15 > 0]

        self._set_training_data(cat, only_zspec=only_zspec, extinction_corr=extinction_corr,
                                convert_colors=convert_colors, Qz_cut=Qz_cut)
        self._set_testing_data(cat_test, only_zspec=only_zspec, extinction_corr=extinction_corr,
                               convert_colors=convert_colors)

    @staticmethod
    def _read_fits_catalogue(filepath):
        """Return HDU 1 of the FITS file at ``filepath`` as a pandas DataFrame."""
        # The context manager closes the file handle; the conversion happens
        # inside the block, while the table data is still accessible.
        with fits.open(filepath) as hdu_list:
            return Table(hdu_list[1].data).to_pandas()

    def _extract_fluxes(self, catalogue):
        """Return ``(flux, flux_error)`` arrays with one column per band."""
        columns_f = [f'FLUX_{x}_{self.aperture}' for x in self._BANDS]
        columns_ferr = [f'FLUXERR_{x}_{self.aperture}' for x in self._BANDS]
        return catalogue[columns_f].values, catalogue[columns_ferr].values

    def _to_colors(self, flux, fluxerr):
        """Convert fluxes to adjacent-band flux ratios.

        Args:
            flux: array of shape ``(n_objects, n_bands)``.
            fluxerr: array of the same shape holding the flux uncertainties.

        Returns:
            ``(color, color_err)`` where ``color[:, i] = flux[:, i] / flux[:, i + 1]``
            and ``color_err`` is the first-order propagated *variance* of that
            ratio (no square root is taken).
        """
        numerator, denominator = flux[:, :-1], flux[:, 1:]
        numerator_err, denominator_err = fluxerr[:, :-1], fluxerr[:, 1:]

        color = numerator / denominator
        # var(a/b) = var(a)/b**2 + a**2 * var(b)/b**4. The second term must
        # use the uncertainty of the denominator band, not the numerator's.
        color_err = (numerator_err**2 / denominator**2
                     + numerator**2 / denominator**4 * denominator_err**2)
        return color, color_err

    def _clean_photometry(self, catalogue):
        """Drop all objects with ``FLAG_PHOT != 0``."""
        return catalogue[catalogue['FLAG_PHOT'] == 0]

    def _correct_extinction(self, catalogue, f):
        """Multiply the fluxes ``f`` by the per-band ``EB_V_corr_FLUX_*`` columns."""
        ext_correction_cols = [f'EB_V_corr_FLUX_{x}' for x in self._BANDS]
        ext_correction = catalogue[ext_correction_cols].values
        return f * ext_correction

    def _take_only_zspec(self, catalogue, cat_flag=None):
        """Select only galaxies with spectroscopic redshift (``z_spec_S15 > 0``).

        The selection is applied when ``cat_flag`` is ``'Calib'`` or
        ``'Valid'``; for any other value the catalogue is returned unchanged.
        """
        if cat_flag in ('Calib', 'Valid'):
            catalogue = catalogue[catalogue.z_spec_S15 > 0]
        return catalogue

    def _clean_zspec_sample(self, catalogue, Qz_cut):
        """Keep only rows whose ``w_Q_f_S15`` weight is at least ``Qz_cut``."""
        return catalogue[catalogue.w_Q_f_S15 >= Qz_cut]

    def _map_weight(self, Qz):
        """Return the weight of the ``weight_dict`` range that contains ``Qz``.

        Range bounds are inclusive. A value that falls in none of the ranges
        (including the gaps between them, or NaN) yields NaN, which keeps the
        resulting weight column numeric and fails every ``>`` / ``>=`` cut.
        """
        for (low, high), weight in self.weight_dict.items():
            if low <= Qz <= high:
                return weight
        return np.nan

    def _get_loss_weights(self, catalogue):
        """Add the ``w_Q_f_S15`` column to ``catalogue`` in place."""
        catalogue['w_Q_f_S15'] = catalogue['Q_f_S15'].apply(self._map_weight)

    def _match_gold_sample(self, catalogue_valid, catalogue_gold, max_distance_arcsec=2):
        """Attach the nearest gold-sample redshift to each validation object.

        Each ``(RA, DEC)`` of ``catalogue_valid`` is matched to the nearest
        ``(RIGHT_ASCENSION, DECLINATION)`` of ``catalogue_gold``. The new
        ``z_spec_gold`` column holds the matched ``FINAL_SPEC_Z``, or ``-99``
        when the nearest neighbour is not closer than ``max_distance_arcsec``.
        ``catalogue_valid`` is modified in place and also returned.
        """
        # NOTE(review): the separation is a plain Euclidean distance on the raw
        # coordinate values (assumed to be degrees — confirm), with no
        # spherical correction.
        max_distance_deg = max_distance_arcsec / 3600.0

        gold_sample_radec = np.c_[catalogue_gold.RIGHT_ASCENSION, catalogue_gold.DECLINATION]
        valid_sample_radec = np.c_[catalogue_valid['RA'], catalogue_valid['DEC']]

        kdtree = KDTree(gold_sample_radec)
        distances, indices = kdtree.query(valid_sample_radec, k=1)

        specz_match_gold = catalogue_gold.FINAL_SPEC_Z.values[indices]
        catalogue_valid['z_spec_gold'] = np.where(distances < max_distance_deg,
                                                  specz_match_gold, -99)
        return catalogue_valid

    def _prepare_photometry(self, catalogue, extinction_corr, convert_colors):
        """Return the ``(photometry, uncertainty)`` arrays for ``catalogue``."""
        f, ferr = self._extract_fluxes(catalogue)
        if extinction_corr:
            # NOTE(review): only the fluxes are rescaled, the flux errors are
            # left untouched — confirm this is intended.
            f = self._correct_extinction(catalogue, f)
        if convert_colors:
            return self._to_colors(f, ferr)
        return f, ferr

    def _set_training_data(self, catalogue, only_zspec=True, extinction_corr=True,
                           convert_colors=True, Qz_cut=1):
        """Filter ``catalogue`` and store it as the training sample."""
        if only_zspec:
            # NOTE(review): the weight cut is applied together with the z_spec
            # selection; the original nesting could not be confirmed.
            catalogue = self._take_only_zspec(catalogue, cat_flag='Calib')
            catalogue = self._clean_zspec_sample(catalogue, Qz_cut=Qz_cut)

        self.cat_train = catalogue
        self.phot_train, self.photerr_train = self._prepare_photometry(
            catalogue, extinction_corr, convert_colors)
        self.target_z_train = catalogue['z_spec_S15'].values
        self.target_qz_train = catalogue['w_Q_f_S15'].values

    def _set_testing_data(self, catalogue, only_zspec=True, extinction_corr=True,
                          convert_colors=True):
        """Filter ``catalogue`` and store it as the testing sample.

        The weight cut for the testing sample is fixed at ``w_Q_f_S15 >= 1``.
        """
        if only_zspec:
            catalogue = self._take_only_zspec(catalogue, cat_flag='Valid')
            catalogue = self._clean_zspec_sample(catalogue, Qz_cut=1)

        self.cat_test = catalogue
        self.phot_test, self.photerr_test = self._prepare_photometry(
            catalogue, extinction_corr, convert_colors)
        self.target_z_test = catalogue['z_spec_S15'].values

    def get_training_data(self):
        """Return ``(phot_train, photerr_train, target_z_train, target_qz_train)``."""
        return self.phot_train, self.photerr_train, self.target_z_train, self.target_qz_train

    def get_testing_data(self):
        """Return ``(phot_test, photerr_test, target_z_test)``."""
        return self.phot_test, self.photerr_test, self.target_z_test

    def get_VIS_mag(self, catalogue):
        """Return the ``MAG_VIS`` column of ``catalogue`` as an ``(n, 1)`` array."""
        return catalogue[['MAG_VIS']].values

    def plot_zdistribution(self, plot_test=False, bins=50):
        """Plot a histogram of the training redshifts (and optionally the testing ones).

        Args:
            plot_test: also overlay the testing-sample redshifts.
            bins: forwarded to ``plt.hist``.
        """
        # get_training_data returns four arrays; the redshift is the third.
        _, _, specz, _ = self.get_training_data()
        plt.hist(specz, bins=bins, histtype='step', color='navy', label=r'Training sample')

        if plot_test:
            _, _, specz_test = self.get_testing_data()
            plt.hist(specz_test, bins=bins, histtype='step', color='goldenrod',
                     label=r'Test sample', ls='--')

        plt.xticks(fontsize=12)
        plt.yticks(fontsize=12)
        plt.xlabel(r'Redshift', fontsize=14)
        plt.ylabel('Counts', fontsize=14)
        # The histograms carry labels, so show them.
        plt.legend(fontsize=12)
        plt.show()