TEMPS / train.py
Laura Cabayol Garcia
running precommit
668e440
raw
history blame
3.57 kB
import numpy as np
import torch
import argparse
from loguru import logger
from astropy.table import Table
import pandas as pd
from pathlib import Path
from temps.archive import Archive
from temps.temps import TempsModule
from temps.temps_arch import EncoderPhotometry, MeasureZ
def train(config: dict) -> None:
    """
    Trains the TempsModule using photometry data.

    Builds the photometry encoder and the redshift network, loads the
    training sample through ``Archive``, runs ``TempsModule.train`` and
    writes the state dicts of both networks into the output directory.

    Parameters:
    -----------
    config : dict
        Configuration dictionary. Keys read here: ``path_calib``,
        ``path_valid``, ``output_model``, ``only_zs``, ``bands`` and
        ``hyperparams`` (which must hold ``nepochs`` and ``learning_rate``).

    Returns:
    --------
    None

    Raises:
    -------
    KeyError
        If one of the configuration keys listed above is missing.
    """
    # Paths
    path_calib = Path(config["path_calib"])
    path_valid = Path(config["path_valid"])
    output_model = Path(config["output_model"])
    hyperparams = config["hyperparams"]

    # Create the output directory before training: otherwise a missing folder
    # would only surface when torch.save runs, after the whole training run.
    output_model.mkdir(parents=True, exist_ok=True)

    # Initialize neural network modules for photometry features and redshift measurement
    nn_features = EncoderPhotometry()
    nn_z = MeasureZ(num_gauss=6)  # Example for Gaussian mixture model with 6 components

    # Initialize the TempsModule with the defined neural networks
    temps_module = TempsModule(nn_features, nn_z)

    # Retrieve photometry and spectroscopic data for training
    photoz_archive = Archive(
        path_calib=path_calib,
        path_valid=path_valid,
        drop_stars=False,
        clean_photometry=False,
        only_zspec=config["only_zs"],
        columns_photometry=config["bands"],
    )
    # The third and fifth returned items are not used by the training call,
    # hence the underscore-prefixed names.
    f, specz, _vis_mag, f_DA, _z_da = photoz_archive.get_training_data()

    # Train the TempsModule
    logger.info("Starting model training...")
    temps_module.train(
        input_data=f,
        input_data_da=f_DA,
        target_data=specz,
        nepochs=hyperparams["nepochs"],
        # NOTE(review): step_size is deliberately fed the epoch count, as in
        # the original script; confirm against TempsModule.train that this is
        # the intended scheduler setting and not a copy-paste slip.
        step_size=hyperparams["nepochs"],
        val_fraction=0.1,  # Validation fraction of 10%
        lr=hyperparams["learning_rate"],
    )
    logger.info("Model training complete.")

    # Save the trained models
    logger.info("Saving trained models...")
    torch.save(temps_module.modelF.state_dict(), output_model / "modelF_zs_test.pt")
    torch.save(temps_module.modelZ.state_dict(), output_model / "modelZ_zs_test.pt")
    logger.info("Models saved at: {}", output_model)
def main() -> None:
    """
    Script entry point: parse the CLI, load the configuration, run training.

    The configuration path comes from the ``--config-path`` argument, is
    logged, read with `read_config`, and the resulting dictionary is handed
    to `train`.
    """
    # Resolve where the configuration lives from the command line
    config_path = get_args().config_path
    logger.info("Loading configuration from {}", config_path)

    # Parse the configuration file and launch the training run with it
    train(read_config(config_path))
def get_args() -> argparse.Namespace:
    """
    Build the command-line interface and parse the process arguments.

    Returns:
    --------
    argparse.Namespace
        Namespace whose ``config_path`` attribute holds the value of the
        mandatory ``--config-path`` option, converted to a ``Path``.
    """
    cli = argparse.ArgumentParser(description="Training script for TempsModule")

    # Specification of the single, mandatory option of this script
    config_option = {
        "type": Path,
        "required": True,
        "help": "Path to the configuration file (YAML format)",
    }
    cli.add_argument("--config-path", **config_option)

    parsed = cli.parse_args()
    return parsed
def read_config(config_path: Path) -> dict:
    """
    Reads the configuration from a YAML file.

    Parameters:
    -----------
    config_path : Path
        Path to the configuration YAML file.

    Returns:
    --------
    dict
        Parsed configuration dictionary.

    Raises:
    -------
    ValueError
        If the file is empty or its top-level YAML node is not a mapping.
    """
    # yaml is imported locally rather than at module level, as in the original.
    import yaml

    # Decode explicitly as UTF-8: without `encoding`, open() falls back to the
    # locale's preferred encoding, which can garble non-ASCII values.
    with open(config_path, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)

    # safe_load yields None for an empty document and a list/scalar for other
    # top-level nodes; fail here with a clear message instead of letting the
    # caller hit an opaque error when it indexes the result by key.
    if not isinstance(config, dict):
        raise ValueError(
            f"Configuration file {config_path} must contain a YAML mapping, "
            f"got {type(config).__name__}"
        )
    return config
# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()