TEMPS / notebooks /Colourspace.py
lauracabayol's picture
update notebooks
11f8c36
# ---
# jupyter:
# jupytext:
# text_representation:
# extension: .py
# format_name: light
# format_version: '1.5'
# jupytext_version: 1.16.2
# kernelspec:
# display_name: temps
# language: python
# name: temps
# ---
# # FIGURE COLOURSPACE IN THE PAPER
# %load_ext autoreload
# %autoreload 2
import pandas as pd
import numpy as np
import os
from astropy.io import fits
from astropy.table import Table
import torch
from pathlib import Path
#matplotlib settings
from matplotlib import rcParams
import matplotlib.pyplot as plt
rcParams["mathtext.fontset"] = "stix"
rcParams["font.family"] = "STIXGeneral"
from temps.archive import Archive
from temps.utils import nmad
from temps.temps_arch import EncoderPhotometry, MeasureZ
from temps.temps import TempsModule
def estimate_som_map(df, plot_arg='z', nx=40, ny=40):
    """
    Aggregate a catalogue onto the cells of a Self-Organizing Map (SOM) grid.

    Rows of ``df`` are grouped by their ``cell`` column, reduced to one value
    per cell, and laid out on the full ``nx`` x ``ny`` grid.

    Parameters:
    - df (pandas.DataFrame): Input catalogue. Must contain a ``cell`` column and
      the column named by ``plot_arg`` (``z`` when ``plot_arg='count'``).
    - plot_arg (str, optional): Column to average per cell, or ``'count'`` to
      count the non-null ``z`` entries per cell. Default is 'z'.
    - nx (int, optional): Number of cells along the X-axis. Default is 40.
    - ny (int, optional): Number of cells along the Y-axis. Default is 40.

    Returns:
    - som_data (pandas.DataFrame): Table of shape ``(nx, ny)`` indexed by
      ``x_cell`` with one column per ``y_cell``. Cells without any object
      are NaN.
    """
    x_cells = np.arange(nx)
    y_cells = np.arange(ny)

    # Cell ids are laid out x-major: cell = x_cell * ny + y_cell.
    cells = pd.DataFrame({
        'x_cell': np.repeat(x_cells, ny),
        'y_cell': np.tile(y_cells, nx),
        'cell': np.arange(nx * ny),
    })

    grouped = df.groupby('cell')
    if plot_arg == 'count':
        per_cell = grouped['z'].count()
    else:
        per_cell = grouped[f'{plot_arg}'].mean()
    som_vis = per_cell.rename('plot_som').reset_index()

    # Inner merge: cell ids outside the nx * ny grid are discarded.
    som_data = som_vis.merge(cells, on='cell')
    som_data = som_data.pivot(index='x_cell', columns='y_cell', values='plot_som')

    # pivot() only keeps rows/columns that contain data. Re-index onto the
    # full grid so that empty rows/columns show up as NaN instead of being
    # dropped, which would shift every other cell when the map is displayed.
    return som_data.reindex(index=x_cells, columns=y_cells)
def plot_som_map(som_data, plot_arg = 'z', vmin=0, vmax=1):
    """
    Show a SOM grid as a colour-coded image on the current matplotlib figure.

    Parameters:
    - som_data (array-like): 2D grid of per-cell values, handed to ``plt.imshow``.
    - plot_arg (str, optional): Text used as the colorbar label. Default is 'z'.
    - vmin (float, optional): Lower end of the colour scale. Default is 0.
    - vmax (float, optional): Upper end of the colour scale. Default is 1.

    Returns:
    None. The figure is displayed with ``plt.show()``.
    """
    colour_scale = {'vmin': vmin, 'vmax': vmax, 'cmap': 'viridis'}
    plt.imshow(som_data, **colour_scale)
    plt.colorbar(label=f'{plot_arg}')

    # Both axes are labelled in SOM pixel units with the same font size.
    axis_labels = (
        (plt.xlabel, r'$x$ [pixel]'),
        (plt.ylabel, r'$y$ [pixel]'),
    )
    for set_label, text in axis_labels:
        set_label(text, fontsize=14)

    plt.show()
# ### LOAD DATA

# Define here the directory containing the photometric catalogues
parent_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')
modules_dir = Path('../data/models/')
filename_calib = 'euclid_cosmos_DC2_S1_v2.1_calib_clean.fits'
filename_valid = 'euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'

# +
# Read the validation catalogue. The table is converted to pandas inside the
# `with` block so the FITS file handle is released as soon as the data is copied.
with fits.open(parent_dir / filename_valid) as hdu_list:
    cat = Table(hdu_list[1].data).to_pandas()

# Keep sources with FLAG_PHOT == 0, mu_class_L07 == 1, at least one positive
# reference redshift (z_spec_S15 or photo_z_L15) and MAG_VIS < 25.
selection = (
    (cat['FLAG_PHOT'] == 0)
    & (cat['mu_class_L07'] == 1)
    & ((cat['z_spec_S15'] > 0) | (cat['photo_z_L15'] > 0))
    & (cat['MAG_VIS'] < 25)
)
cat = cat[selection]
# -

# Reference redshift: z_spec_S15 where it is positive, photo_z_L15 otherwise.
# specz_or_photo records which one was used (0: z_spec_S15, 1: photo_z_L15).
has_specz = cat['z_spec_S15'].values > 0
ztarget = np.where(has_specz, cat['z_spec_S15'].values, cat['photo_z_L15'].values)
specz_or_photo = np.where(has_specz, 0, 1)

ID = cat['ID']
VISmag = cat['MAG_VIS']
zsflag = cat['reliable_S15']

photoz_archive = Archive(path_calib=parent_dir / filename_calib,
                         path_valid=parent_dir / filename_valid,
                         only_zspec=False)

# NOTE(review): presumably fluxes are extracted from the filtered catalogue and
# turned into colours here; confirm against temps.archive.Archive.
f = photoz_archive._extract_fluxes(catalogue=cat)
col = photoz_archive._to_colors(f)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# +
# Evaluate the three trained models ('z', 'L15', 'DA') on the same colours and
# store one results table per model in `dfs`, keyed by the model label.
cpu = torch.device('cpu')

dfs = {}
for lab in ['z', 'L15', 'DA']:
    # Weights are always mapped to CPU when loading.
    nn_features = EncoderPhotometry()
    nn_features.load_state_dict(torch.load(modules_dir / f'modelF_{lab}.pt', map_location=cpu))
    nn_z = MeasureZ(num_gauss=6)
    nn_z.load_state_dict(torch.load(modules_dir / f'modelZ_{lab}.pt', map_location=cpu))

    temps_module = TempsModule(nn_features, nn_z)

    # Inference only, so autograd tracking is switched off
    # (harmless if get_pz already does this internally).
    with torch.no_grad():
        z, pz, odds = temps_module.get_pz(input_data=torch.Tensor(col),
                                          return_pz=True)

    # One row per source: catalogue quantities next to this model's prediction.
    df = pd.DataFrame(np.c_[ID, VISmag, z, odds, ztarget, zsflag, specz_or_photo],
                      columns=['ID', 'VISmag', 'z', 'odds', 'ztarget', 'zsflag', 'S15_L15_flag'])

    # Redshift error scaled by (1 + ztarget).
    df['zwerr'] = (df.z - df.ztarget) / (1 + df.ztarget)

    # Drop any rows with NaN values
    df = df.dropna()

    dfs[lab] = df
# -
# ### SELECT MODEL PREDICTIONS AND BUILD THE SAMPLES

# Predictions of the 'z' model and of the 'DA' model evaluated above.
df_z = dfs['z']
df_z_DA = dfs['DA']

# ##### LOAD THE SOM TRAINED ON THE TRAINING DATA
# The SOM table is matched to the predictions on the source ID.
df_som = pd.read_csv(parent_dir / 'som_dataframe.csv', header=0, sep=',')
df_z = df_z.merge(df_som, on='ID')
df_z_DA = df_z_DA.merge(df_som, on='ID')


# ##### APPLY CUTS FOR DIFFERENT SAMPLES
def _euclid_selection(df):
    """Keep sources with VISmag < 24.5 and predicted redshift 0.2 < z < 2.6."""
    return df[(df.VISmag < 24.5) & (df.z > 0.2) & (df.z < 2.6)]


def _odds_quality_cut(df, quantile=0.2):
    """Keep sources whose odds exceed the given quantile of the odds in `df` itself."""
    return df[df.odds > df['odds'].quantile(quantile)]


df_zspec = df_z[(df_z.S15_L15_flag == 0) & (df_z.zsflag == 1)]
df_l15 = df_z[(df_z.ztarget > 0)]
df_l15_DA = df_z_DA[(df_z_DA.ztarget > 0)]

df_l15_euclid = _euclid_selection(df_z)
df_l15_euclid_cut = _odds_quality_cut(df_l15_euclid)

df_l15_euclid_da = _euclid_selection(df_z_DA)
# The threshold is the 20% quantile of the DA sample's own odds. It was
# previously taken from the non-DA sample (df_l15_euclid), which does not
# remove the lowest 20% of the DA odds.
df_l15_euclid_cut_da = _odds_quality_cut(df_l15_euclid_da)
# ## MAKE SOM PLOT

from mpl_toolkits.axes_grid1 import make_axes_locatable

# +
fig, axs = plt.subplots(6, 4, figsize=(13, 15), sharex=True, sharey=True, gridspec_kw={'hspace': 0.05, 'wspace': 0.06})

# One figure column per mapped quantity.
columns = ['ztarget','z','zwerr','count']
titles = [r'$z_{true}$ (A)',r'$z$ (B)',r'$z_{\rm error}$ (C)','Counts']
limits = [[0,4],[0,4],[-0.5,0.5],[0,50]]

# One figure row per sample: (sample, y position of the row label in figure
# coordinates, row label, label font size).
row_specs = [
    (df_zspec, 0.815, r'$z_{\rm s}$ samp. (1)', 16),
    (df_l15, 0.69, r'L15 samp. (2)', 16),
    (df_l15_DA, 0.56, r'L15 samp. + DA (3)', 14),
    (df_l15_euclid_da, 0.44, r'$Euclid$ samp. + DA (4)', 14),
    (df_l15_euclid_cut, 0.3, r'$Euclid$ samp. + QC (5)', 14),
    (df_l15_euclid_cut_da, 0.17, r'(5) + DA ', 13),
]

# Left edge (figure coordinates) of the colorbar drawn for columns 1-3.
# Column 0 has none: it uses the same limits as column 1.
colorbar_left = {1: 0.49, 2: 0.685, 3: 0.885}

last_row = len(row_specs) - 1
for irow, (sample, label_y, label, label_size) in enumerate(row_specs):
    for icol, (column, (vmin, vmax)) in enumerate(zip(columns, limits)):
        som_data = estimate_som_map(sample, plot_arg=column, nx=40, ny=40)
        im = axs[irow, icol].imshow(som_data, vmin=vmin, vmax=vmax, cmap='viridis')

        if irow == 0:
            # Titles and colorbars are taken from the top row; every row in a
            # column is drawn with the same vmin/vmax.
            axs[0, icol].set_title(titles[icol], fontsize=18)
            if icol in colorbar_left:
                cbar_ax = fig.add_axes([colorbar_left[icol], 0.11, 0.01, 0.77])
                fig.colorbar(im, cax=cbar_ax)

        if irow == last_row:
            # Only the bottom row gets an x label: with hspace=0.05 a label on
            # an inner row falls into the narrow gap above the next row.
            axs[irow, icol].set_xlabel(r'$x$', fontsize=14)

    axs[irow, 0].set_ylabel(r'$y$', fontsize=14)
    fig.text(0.09, label_y, label, va='center', rotation='vertical', fontsize=label_size)

plt.savefig('SOM_colourspace.pdf', format='pdf', bbox_inches='tight', dpi=300)
# -