Spaces:
Runtime error
Runtime error
File size: 10,832 Bytes
d307831 21a7d1b d307831 c9354dd d307831 c9354dd d307831 c9354dd d307831 c9354dd d307831 c9354dd d307831 c9354dd d307831 c9354dd d307831 c9354dd d307831 c9354dd d307831 c9354dd d307831 c9354dd d307831 c9354dd d307831 c9354dd d307831 c9354dd d307831 c9354dd d307831 c9354dd d307831 c9354dd d307831 c9354dd d307831 c9354dd d307831 c9354dd d307831 c9354dd d307831 c9354dd d307831 c9354dd d307831 c9354dd d307831 c9354dd d307831 c9354dd d307831 c9354dd d307831 c9354dd d307831 c9354dd d307831 c9354dd d307831 c9354dd d307831 c9354dd d307831 e52c4aa c9354dd 21a7d1b c9354dd e52c4aa c9354dd e52c4aa c9354dd e52c4aa c9354dd e52c4aa c9354dd e52c4aa c9354dd e52c4aa c9354dd e52c4aa c9354dd e52c4aa c9354dd e52c4aa c9354dd e52c4aa c9354dd e52c4aa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 |
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from temps.utils import nmad, sigma68
from scipy import stats
from typing import List, Optional, Dict
_PHOTOZ_METRICS = ("sig68", "bias", "nmad", "outliers")


def _photoz_metric(df_bin: pd.DataFrame, metric: str) -> float:
    """
    Evaluate one photo-z metric on the ``zwerr`` column of a binned dataframe.

    Parameters:
    - df_bin (pd.DataFrame): Rows selected for the current bin.
    - metric (str): One of 'sig68', 'bias', 'nmad', 'outliers'.

    Returns:
    float: The metric value, or NaN when the bin holds no rows.

    Raises:
    ValueError: If ``metric`` is not one of the supported names.
    """
    if metric not in _PHOTOZ_METRICS:
        raise ValueError(
            f"Unknown metric {metric!r}; expected one of {_PHOTOZ_METRICS}"
        )
    if len(df_bin) == 0:
        # An empty bin has no defined statistic. NaN leaves a gap in the curve
        # instead of raising (e.g. ZeroDivisionError for 'outliers').
        return float("nan")
    zwerr = df_bin.zwerr
    if metric == "sig68":
        return sigma68(zwerr)
    if metric == "bias":
        return np.mean(zwerr)
    if metric == "nmad":
        return nmad(zwerr)
    # 'outliers': percentage of rows with |zwerr| above 0.15.
    return float(np.mean(np.abs(zwerr) > 0.15) * 100)


def plot_photoz(
    df_list: List[pd.DataFrame],
    nbins: int,
    xvariable: str,
    metric: str,
    type_bin: str = "bin",
    label_list: Optional[List[str]] = None,
    samp: str = "zs",
    save: bool = False,
) -> None:
    """
    Plot a photo-z metric versus a binned variable for several dataframes.

    The upper panel shows the metric per bin for every dataframe; the lower
    panel shows the ratio of every dataframe after the first to the first one.

    Parameters:
    - df_list (List[pd.DataFrame]): Dataframes to compare. Each must provide
      the ``xvariable`` column and a ``zwerr`` column. The first dataframe
      defines the bin edges and is the reference of the ratio panel.
    - nbins (int): Number of quantile levels used to build the bin edges
      (this yields ``nbins - 1`` bins).
    - xvariable (str): Column to bin on and to plot on the x-axis.
    - metric (str): Metric to plot ('sig68', 'bias', 'nmad' or 'outliers').
    - type_bin (str, optional): 'bin' selects rows strictly between the two
      edges of each bin; any other value selects all rows below the upper
      edge (cumulative). Default is 'bin'.
    - label_list (Optional[List[str]], optional): Legend label per dataframe.
      Missing labels fall back to 'Sample <n>'. Default is None.
    - samp (str, optional): Sample label used in the output file name.
      Default is 'zs'.
    - save (bool, optional): If True, save the plot to
      ``{metric}_{xvariable}_{samp}.pdf``. Default is False.

    Returns:
    None

    Raises:
    ValueError: If ``metric`` is not one of the supported names.
    """
    if metric not in _PHOTOZ_METRICS:
        # Fail early with a clear message rather than with a matplotlib
        # shape-mismatch error once plotting starts.
        raise ValueError(
            f"Unknown metric {metric!r}; expected one of {_PHOTOZ_METRICS}"
        )

    # Plot properties (note: this changes matplotlib's global rcParams)
    plt.rcParams["font.family"] = "serif"
    plt.rcParams["font.size"] = 12

    # Set x-axis label based on variable
    xvariable_lab = "VIS" if xvariable == "VISmag" else r"$z_{\rm s}$"

    # Bin edges are quantiles of the first dataframe's x variable
    bin_edges = stats.mstats.mquantiles(
        df_list[0][xvariable].values, np.linspace(0.05, 1, nbins)
    )
    bin_pairs = list(zip(bin_edges[:-1], bin_edges[1:]))
    bin_centres = [(edge_max + edge_min) / 2 for edge_min, edge_max in bin_pairs]

    cmap = plt.get_cmap("Dark2")
    line_styles = ("--", ":", "-")
    display_names = {
        "zs": r"$z_{\rm s}$",
        "zs+L15": r"$z_{\rm s}$+L15",
        "TEMPS": "TEMPS",
    }

    # Create subplots
    fig, (ax1, ax2) = plt.subplots(
        2, 1, figsize=(8, 8), gridspec_kw={"height_ratios": [3, 1]}
    )

    # Loop through dataframes and calculate metrics
    ydata_all: List[List[float]] = []
    for i, df in enumerate(df_list):
        if label_list is not None and i < len(label_list):
            label = label_list[i]
        else:
            label = f"Sample {i + 1}"
        label_lab = display_names.get(label, label)

        xvalues = df[xvariable]
        ydata = []
        for edge_min, edge_max in bin_pairs:
            if type_bin == "bin":
                selection = (xvalues > edge_min) & (xvalues < edge_max)
            else:
                selection = xvalues < edge_max
            ydata.append(_photoz_metric(df[selection], metric))
        ydata_all.append(ydata)

        ax1.plot(
            bin_centres,
            ydata,
            marker=".",
            lw=1,
            label=label_lab,
            color=cmap(i),
            # Cycle the styles so more than three dataframes do not raise.
            ls=line_styles[i % len(line_styles)],
        )

    # Raw f-string: "\D" is not a valid escape in a normal string literal.
    ax1.set_ylabel(rf"{metric} $[\Delta z]$", fontsize=18)
    ax1.grid(False)
    ax1.legend()

    # Plot ratios of every further dataframe to the first one
    reference = np.array(ydata_all[0])
    for i in range(1, len(ydata_all)):
        ax2.plot(
            bin_centres,
            np.array(ydata_all[i]) / reference,
            marker=".",
            color=cmap(i),
        )
    ax2.set_ylabel(r"Method $X$ / $z_{\rm z}$", fontsize=14)
    ax2.set_xlabel(f"{xvariable_lab}", fontsize=16)
    ax2.grid(True)

    if save:
        plt.savefig(f"{metric}_{xvariable}_{samp}.pdf", dpi=300, bbox_inches="tight")
    plt.show()
def plot_pz(
    m: int,
    pz: np.ndarray,
    specz: np.ndarray,
    z_min: float = 0.0,
    z_max: float = 4.0,
) -> None:
    """
    Plot the redshift PDF of one object together with its spectroscopic redshift.

    Parameters:
    - m (int): Index of the object to plot.
    - pz (np.ndarray): PDFs, indexed as ``pz[m]``; the selected PDF is drawn
      on an evenly spaced grid from ``z_min`` to ``z_max``.
    - specz (np.ndarray): Spectroscopic redshifts, indexed as ``specz[m]``.
      A scalar is also accepted and used directly.
    - z_min (float, optional): Lower end of the redshift grid. Default is 0.
    - z_max (float, optional): Upper end of the redshift grid. Default is 4.

    Returns:
    None
    """
    pdf = pz[m]
    # Size the grid from the PDF itself so PDFs that are not sampled on
    # exactly 1000 points can be plotted too.
    z_grid = np.linspace(z_min, z_max, len(pdf))
    # Accept both a per-object array and a single scalar redshift.
    z_spec = specz if np.ndim(specz) == 0 else specz[m]

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(z_grid, pdf, label="PDF", color="navy")
    ax.axvline(z_spec, color="black", linestyle="--", label=r"$z_{\rm s}$")
    ax.set_xlabel(r"$z$", fontsize=18)
    ax.set_ylabel("Probability Density", fontsize=16)
    ax.legend(fontsize=18)
    plt.show()
def plot_zdistribution(archive, plot_test: bool = False, bins: int = 50) -> None:
    """
    Draw step histograms of the redshifts returned by the archive.

    Parameters:
    - archive: Object exposing ``get_training_data()``, which must return
      three values; the third one is histogrammed.
    - plot_test (bool, optional): If True, overlay a second, dashed histogram
      labelled 'Test sample'. Default is False.
    - bins (int, optional): Number of histogram bins. Default is 50.

    Returns:
    None
    """
    shared_style = {"bins": bins, "histtype": "step"}

    _, _, z_train = archive.get_training_data()
    plt.hist(z_train, color="navy", label=r"Training sample", **shared_style)

    if plot_test:
        # NOTE(review): the values drawn as "Test sample" are read from
        # get_training_data() again, so both histograms show the same data.
        # Presumably a test-set accessor on the archive was intended -
        # confirm against the archive class before changing it.
        _, _, z_test = archive.get_training_data()
        plt.hist(
            z_test,
            color="goldenrod",
            label=r"Test sample",
            linestyle="--",
            **shared_style,
        )

    plt.xlabel(r"Redshift", fontsize=14)
    plt.ylabel("Counts", fontsize=14)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.legend()
    plt.show()
def plot_som_map(
    som_data: np.ndarray, plot_arg: str = "z", vmin: float = 0, vmax: float = 1
) -> None:
    """
    Show Self-Organizing Map (SOM) data as a colour image.

    Parameters:
    - som_data (numpy.ndarray): Array passed to ``plt.imshow``.
    - plot_arg (str, optional): Text used as the colorbar label. Default is 'z'.
    - vmin (float, optional): Lower limit of the colour scale. Default is 0.
    - vmax (float, optional): Upper limit of the colour scale. Default is 1.

    Returns:
    None
    """
    image_options = dict(vmin=vmin, vmax=vmax, cmap="viridis")
    plt.imshow(som_data, **image_options)
    plt.colorbar(label=f"{plot_arg}")

    # Both axes are labelled in pixel units with the same font size.
    axis_labels = ((plt.xlabel, r"$x$ [pixel]"), (plt.ylabel, r"$y$ [pixel]"))
    for set_label, text in axis_labels:
        set_label(text, fontsize=14)

    plt.show()
def plot_PIT(
    pit_list_1: List[float],
    pit_list_2: Optional[List[float]] = None,
    pit_list_3: Optional[List[float]] = None,
    sample: str = "specz",
    labels: Optional[List[str]] = None,
    save: bool = True,
) -> None:
    """
    Plot normalised histograms of Probability Integral Transform (PIT) values.

    Parameters:
    - pit_list_1 (List[float]): First list of PIT values (always plotted).
    - pit_list_2 (Optional[List[float]], optional): Second list of PIT values.
      Skipped when None. Default is None.
    - pit_list_3 (Optional[List[float]], optional): Third list of PIT values.
      Skipped when None. Default is None.
    - sample (str, optional): Sample label used in the output file name.
      Default is 'specz'.
    - labels (Optional[List[str]], optional): Legend label for each PIT list,
      in the same order. Missing labels fall back to 'Sample <n>'.
      Default is None.
    - save (bool, optional): If True, save the plot to ``PIT_{sample}.pdf``.
      Default is True.

    Returns:
    None
    """
    # Note: this changes matplotlib's global rcParams.
    plt.rcParams["font.family"] = "serif"
    plt.rcParams["font.size"] = 12

    fig, ax = plt.subplots(figsize=(8, 6))
    kwargs = dict(bins=30, histtype="step", density=True, range=(0, 1))
    cmap = plt.get_cmap("Dark2")

    for idx, pit_values in enumerate((pit_list_1, pit_list_2, pit_list_3)):
        # Only the second and third lists are optional.
        if idx > 0 and pit_values is None:
            continue
        # `labels` defaults to None and may be shorter than the number of
        # lists, so never index it unconditionally.
        if labels is not None and idx < len(labels):
            label = labels[idx]
        else:
            label = f"Sample {idx + 1}"
        ax.hist(pit_values, color=cmap(idx), linestyle="--", label=label, **kwargs)

    ax.set_xlabel("PIT values", fontsize=14)
    ax.set_ylabel("Normalized Counts", fontsize=14)
    ax.legend(fontsize=12)

    if save:
        plt.savefig(f"PIT_{sample}.pdf", dpi=300, bbox_inches="tight")
    plt.show()
def plot_outlier_ratio(
    outliers: np.ndarray, num_samp: int = 100, plot_mean: bool = True
) -> None:
    """
    Plot the outlier ratio as a function of the number of samples.

    Parameters:
    - outliers (np.ndarray): Outlier ratio data; entry ``k`` is drawn at
      x = ``k + 1``.
    - num_samp (int, optional): Maximum number of leading entries to plot.
      If ``outliers`` is shorter, all of it is plotted. Default is 100.
    - plot_mean (bool, optional): If True, draw a horizontal line at the mean
      of the whole ``outliers`` array (not only the plotted part).
      Default is True.

    Returns:
    None
    """
    # Clamp to the available data so x and y always have the same length;
    # an unclamped x axis raises a shape-mismatch error for short inputs.
    n_plot = max(0, min(num_samp, len(outliers)))

    plt.figure(figsize=(10, 6))
    plt.plot(np.arange(1, n_plot + 1), outliers[:n_plot], label="Outlier Ratio")
    if plot_mean:
        plt.axhline(
            np.mean(outliers), color="red", linestyle="--", label="Mean Outlier Ratio"
        )
    plt.xlabel("Number of Samples", fontsize=14)
    plt.ylabel("Outlier Ratio", fontsize=14)
    plt.legend()
    plt.grid()
    plt.show()
def plot_crps(
    crps_list_1: List[float],
    crps_list_2: Optional[List[float]] = None,
    crps_list_3: Optional[List[float]] = None,
    labels: Optional[List[str]] = None,
    sample: str = "specz",
    save: bool = True,
) -> None:
    """
    Plot normalised histograms of CRPS scores and annotate their means.

    Parameters:
    - crps_list_1 (List[float]): First list of CRPS scores (always plotted).
    - crps_list_2 (Optional[List[float]], optional): Second list of CRPS
      scores. Skipped when None. Default is None.
    - crps_list_3 (Optional[List[float]], optional): Third list of CRPS
      scores. Skipped when None. Default is None.
    - labels (Optional[List[str]], optional): Label for each CRPS list, in the
      same order. Missing labels fall back to 'Sample <n>'. Default is None.
    - sample (str, optional): Sample label used in the output file name.
      Default is 'specz'.
    - save (bool, optional): If True, save the plot to ``{sample}_CRPS.pdf``.
      Default is True.

    Returns:
    None
    """
    # Plot properties (note: this changes matplotlib's global rcParams)
    plt.rcParams["font.family"] = "serif"
    plt.rcParams["font.size"] = 12

    fig, ax = plt.subplots(figsize=(8, 6))
    cmap = plt.get_cmap("Dark2")
    kwargs = dict(bins=50, histtype="step", density=True, range=(0, 1))

    # One fixed line style and annotation height per slot, so a given list
    # looks the same whether or not the other optional lists are supplied.
    line_styles = ("--", ":", "-")
    annotation_heights = (0.9, 0.85, 0.8)

    for idx, crps_values in enumerate((crps_list_1, crps_list_2, crps_list_3)):
        # Only the second and third lists are optional; skipping them here
        # also skips their mean, which cannot be computed from None.
        if idx > 0 and crps_values is None:
            continue
        if labels is not None and idx < len(labels):
            label = labels[idx]
        else:
            label = f"Sample {idx + 1}"

        ax.hist(crps_values, color=cmap(idx), ls=line_styles[idx], label=label, **kwargs)

        # Mean CRPS (NaNs ignored), written in the upper-right area of the axes
        mean_crps = round(np.nanmean(crps_values), 2)
        ax.annotate(
            f"Mean CRPS {label}: {mean_crps}",
            xy=(0.57, annotation_heights[idx]),
            xycoords="axes fraction",
            fontsize=14,
            color=cmap(idx),
        )

    # Axis labels
    ax.set_xlabel("CRPS Scores", fontsize=18)
    ax.set_ylabel("Frequency", fontsize=18)
    # Grid lines
    ax.grid(True, linestyle="--", alpha=0.7)
    # Histograms are binned over (0, 1) but only the 0-0.5 range is shown
    ax.set_xlim(0, 0.5)
    # Larger tick labels
    ax.tick_params(axis="both", which="major", labelsize=14)

    if save:
        plt.savefig(f"{sample}_CRPS.pdf", bbox_inches="tight")
    # Show the plot
    plt.show()
|