import numpy as np
import torch
import argparse
from loguru import logger
from astropy.table import Table
import pandas as pd
from pathlib import Path

from temps.archive import Archive
from temps.temps import TempsModule
from temps.temps_arch import EncoderPhotometry, MeasureZ


def train(config: dict) -> None:
    """
    Trains the TempsModule using photometry data.

    Parameters:
    -----------
    config : dict
        Configuration dictionary. Keys read here: ``path_calib``,
        ``path_valid``, ``output_model``, ``only_zs``, ``bands`` and
        ``hyperparams`` (containing ``nepochs``, ``learning_rate`` and,
        optionally, ``step_size``).

    Returns:
    --------
    None
    """
    # Paths
    path_calib = Path(config["path_calib"])
    path_valid = Path(config["path_valid"])
    output_model = Path(config["output_model"])

    hyperparams = config["hyperparams"]
    nepochs = hyperparams["nepochs"]
    # Previously ``nepochs`` was passed as ``step_size`` unconditionally.
    # Honour an explicit ``step_size`` when the config provides one; fall back
    # to ``nepochs`` so existing configs keep their old behaviour.
    step_size = hyperparams.get("step_size", nepochs)

    # Create the output directory up front so an invalid/unwritable path
    # fails immediately instead of after the (long) training run.
    output_model.mkdir(parents=True, exist_ok=True)

    # Initialize neural network modules for photometry features and redshift measurement
    nn_features = EncoderPhotometry()
    nn_z = MeasureZ(num_gauss=6)  # Gaussian mixture with 6 components

    # Initialize the TempsModule with the defined neural networks
    temps_module = TempsModule(nn_features, nn_z)

    # Retrieve photometry and spectroscopic data for training
    photoz_archive = Archive(
        path_calib=path_calib,
        path_valid=path_valid,
        drop_stars=False,
        clean_photometry=False,
        only_zspec=config["only_zs"],
        columns_photometry=config["bands"],
    )
    # The VIS magnitudes and the domain-adaptation redshifts are returned by
    # the archive but not needed for training here.
    f, specz, _vis_mag, f_DA, _z_DA = photoz_archive.get_training_data()

    # Train the TempsModule
    logger.info("Starting model training...")
    temps_module.train(
        input_data=f,
        input_data_da=f_DA,
        target_data=specz,
        nepochs=nepochs,
        step_size=step_size,
        val_fraction=0.1,  # Validation fraction of 10%
        lr=hyperparams["learning_rate"],
    )
    logger.info("Model training complete.")

    # Save the trained models
    logger.info("Saving trained models...")
    torch.save(temps_module.modelF.state_dict(), output_model / "modelF_zs_test.pt")
    torch.save(temps_module.modelZ.state_dict(), output_model / "modelZ_zs_test.pt")
    logger.info("Models saved at: {}", output_model)


def main() -> None:
    """
    Main entry point for the training script.

    Parses the command line, loads the YAML configuration it points to,
    and runs :func:`train` with it.
    """
    # Get command-line arguments
    args = get_args()

    # Load the configuration from the provided path
    config_path = args.config_path
    logger.info("Loading configuration from {}", config_path)

    # Read the configuration file (YAML format)
    config = read_config(config_path)

    # Call the train function
    train(config)


def get_args() -> argparse.Namespace:
    """
    Parses command-line arguments for the script.

    Returns:
    --------
    argparse.Namespace
        Parsed command-line arguments; ``config_path`` is a ``Path``.
    """
    parser = argparse.ArgumentParser(description="Training script for TempsModule")
    parser.add_argument(
        "--config-path",
        type=Path,
        required=True,
        help="Path to the configuration file (YAML format)",
    )
    return parser.parse_args()


def read_config(config_path: Path) -> dict:
    """
    Reads the configuration from a YAML file.

    Parameters:
    -----------
    config_path : Path
        Path to the configuration YAML file.

    Returns:
    --------
    dict
        Parsed configuration dictionary.

    Raises:
    -------
    ValueError
        If the file is empty or its top-level YAML value is not a mapping.
    """
    # Local import: PyYAML is only needed when reading a config file.
    import yaml

    with open(config_path, "r", encoding="utf-8") as file:
        # safe_load avoids constructing arbitrary Python objects from the file.
        config = yaml.safe_load(file)

    # safe_load returns None for an empty file (or a list/scalar for other
    # documents); fail here with a clear message instead of a TypeError later.
    if not isinstance(config, dict):
        raise ValueError(
            f"Configuration file {config_path} must contain a YAML mapping, "
            f"got {type(config).__name__}"
        )
    return config


if __name__ == "__main__":
    main()