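# NOTE: "Spaces: Runtime error" -- this appears to be status text from the hosting page captured with the file; it is not part of the script.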
import numpy as np | |
import torch | |
import argparse | |
from loguru import logger | |
from astropy.table import Table | |
import pandas as pd | |
from pathlib import Path | |
from temps.archive import Archive | |
from temps.temps import TempsModule | |
from temps.temps_arch import EncoderPhotometry, MeasureZ | |
def train(config: dict) -> None:
    """
    Trains the TempsModule using photometry data and saves the trained weights.

    Parameters:
    -----------
    config : dict
        Configuration dictionary. The keys read here are "path_calib",
        "path_valid", "output_model", "only_zs", "bands" and "hyperparams"
        (the latter with "nepochs" and "learning_rate").

    Returns:
    --------
    None
        The state dicts of the two sub-models are written to
        ``output_model / "modelF_zs_test.pt"`` and
        ``output_model / "modelZ_zs_test.pt"``.
    """
    # Paths
    path_calib = Path(config["path_calib"])
    path_valid = Path(config["path_valid"])
    output_model = Path(config["output_model"])
    hyperparams = config["hyperparams"]

    # Create the output directory before training: a missing folder would
    # otherwise only surface in torch.save, after the whole run has finished.
    output_model.mkdir(parents=True, exist_ok=True)

    # Initialize neural network modules for photometry features and redshift measurement
    nn_features = EncoderPhotometry()
    nn_z = MeasureZ(num_gauss=6)  # Example for Gaussian mixture model with 6 components

    # Initialize the TempsModule with the defined neural networks
    temps_module = TempsModule(nn_features, nn_z)

    # Retrieve photometry and spectroscopic data for training
    photoz_archive = Archive(
        path_calib=path_calib,
        path_valid=path_valid,
        drop_stars=False,
        clean_photometry=False,
        only_zspec=config["only_zs"],
        columns_photometry=config["bands"],
    )
    # The archive returns five values; the third and fifth are not needed for
    # training, so they are bound to underscore-prefixed names.
    f, specz, _vis_mag, f_DA, _z_da = photoz_archive.get_training_data()

    # Train the TempsModule
    logger.info("Starting model training...")
    temps_module.train(
        input_data=f,
        input_data_da=f_DA,
        target_data=specz,
        nepochs=hyperparams["nepochs"],
        # NOTE(review): step_size is deliberately kept equal to nepochs, as in
        # the original code -- confirm whether a separate hyperparameter was
        # intended here.
        step_size=hyperparams["nepochs"],
        val_fraction=0.1,  # Validation fraction of 10%
        lr=hyperparams["learning_rate"],
    )
    logger.info("Model training complete.")

    # Save the trained models
    logger.info("Saving trained models...")
    torch.save(temps_module.modelF.state_dict(), output_model / "modelF_zs_test.pt")
    torch.save(temps_module.modelZ.state_dict(), output_model / "modelZ_zs_test.pt")
    logger.info("Models saved at: {}", output_model)
def main() -> None:
    """
    Entry point of the training script.

    Parses the command-line arguments, logs which configuration file is being
    used, loads it, and hands the resulting dictionary to `train`.
    """
    cli_args = get_args()

    # Announce the configuration file before reading it.
    logger.info("Loading configuration from {}", cli_args.config_path)

    # Load the YAML configuration and start training with it.
    train(read_config(cli_args.config_path))
def get_args() -> argparse.Namespace:
    """
    Builds the command-line interface of the script and parses the arguments.

    Returns:
    --------
    argparse.Namespace
        Namespace whose `config_path` attribute holds the value of the
        mandatory `--config-path` option, converted to a `Path`.
    """
    # Settings of the single, mandatory option.
    config_option = {
        "type": Path,
        "required": True,
        "help": "Path to the configuration file (YAML format)",
    }

    cli = argparse.ArgumentParser(description="Training script for TempsModule")
    cli.add_argument("--config-path", **config_option)

    parsed = cli.parse_args()
    return parsed
def read_config(config_path: Path) -> dict:
    """
    Reads the configuration from a YAML file.

    Parameters:
    -----------
    config_path : Path
        Path to the configuration YAML file.

    Returns:
    --------
    dict
        Parsed configuration dictionary.

    Raises:
    -------
    ValueError
        If the top-level content of the file is not a mapping (for example
        when the file is empty), since the rest of the script indexes the
        configuration by key.
    """
    import yaml

    # Decode explicitly as UTF-8 instead of relying on the platform's locale
    # encoding, so the same file parses identically on every system.
    with open(config_path, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)

    # safe_load returns None for an empty document and may return a list or a
    # scalar; reject those here with a clear message instead of returning a
    # value that violates the declared return type.
    if not isinstance(config, dict):
        raise ValueError(
            f"Configuration file {config_path} must contain a YAML mapping, "
            f"got {type(config).__name__}"
        )
    return config
# Run the training pipeline only when the file is executed as a script, so
# that importing the module has no side effects.
if __name__ == "__main__":
    main()