Spaces:
Runtime error
Runtime error
File size: 3,708 Bytes
41d1e96 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
import numpy as np
import torch
import argparse
from loguru import logger
from astropy.table import Table
import pandas as pd
from pathlib import Path
from temps.archive import Archive
from temps.temps import TempsModule
from temps.temps_arch import EncoderPhotometry, MeasureZ
def train(config: dict) -> None:
    """
    Trains the TempsModule using photometry data and saves the fitted weights.

    Parameters:
    -----------
    config : dict
        Configuration dictionary. The keys read here are ``path_calib``,
        ``path_valid``, ``output_model``, ``only_zs``, ``bands`` and
        ``hyperparams`` (with ``nepochs`` and ``learning_rate``).

    Returns:
    --------
    None
        The state dicts of both networks are written to
        ``output_model / "modelF_zs_test.pt"`` and
        ``output_model / "modelZ_zs_test.pt"``.
    """
    # Paths
    path_calib = Path(config["path_calib"])
    path_valid = Path(config["path_valid"])
    output_model = Path(config["output_model"])

    # Create the output directory up front: torch.save does not create missing
    # parent directories, and failing only after training would discard the run.
    output_model.mkdir(parents=True, exist_ok=True)

    # Initialize neural network modules for photometry features and redshift measurement
    nn_features = EncoderPhotometry()
    nn_z = MeasureZ(num_gauss=6)  # Example for Gaussian mixture model with 6 components

    # Initialize the TempsModule with the defined neural networks
    temps_module = TempsModule(nn_features, nn_z)

    # Retrieve photometry and spectroscopic data for training
    photoz_archive = Archive(
        path_calib=path_calib,
        path_valid=path_valid,
        drop_stars=False,
        clean_photometry=False,
        only_zspec=config["only_zs"],
        columns_photometry=config["bands"],
    )
    # get_training_data() yields five values; the third and fifth are not
    # needed for training, hence the underscore-prefixed names.
    f, specz, _vis_mag, f_DA, _z_da = photoz_archive.get_training_data()

    hyperparams = config["hyperparams"]

    # Train the TempsModule
    logger.info("Starting model training...")
    temps_module.train(
        input_data=f,
        input_data_da=f_DA,
        target_data=specz,
        nepochs=hyperparams["nepochs"],
        # NOTE(review): step_size is fed from the epoch count, exactly as in the
        # original code -- confirm this is intended and not a missing
        # dedicated hyperparameter.
        step_size=hyperparams["nepochs"],
        val_fraction=0.1,  # Validation fraction of 10%
        lr=hyperparams["learning_rate"],
    )
    logger.info("Model training complete.")

    # Save the trained models
    logger.info("Saving trained models...")
    torch.save(temps_module.modelF.state_dict(), output_model / "modelF_zs_test.pt")
    torch.save(temps_module.modelZ.state_dict(), output_model / "modelZ_zs_test.pt")
    logger.info("Models saved at: {}", output_model)
def main() -> None:
    """
    Script entry point.

    Parses the command line, logs which configuration file is being used,
    loads that file and passes the resulting dictionary to `train`.
    """
    cli_options = get_args()

    # Announce the configuration source before reading it
    logger.info("Loading configuration from {}", cli_options.config_path)

    # Load the YAML configuration and start training with it
    train(read_config(cli_options.config_path))
def get_args() -> argparse.Namespace:
    """
    Builds the command-line interface and parses the process arguments.

    Returns:
    --------
    argparse.Namespace
        Namespace with a single attribute, ``config_path`` (a ``Path``),
        taken from the mandatory ``--config-path`` option.
    """
    cli = argparse.ArgumentParser(description="Training script for TempsModule")

    # The only option: where to find the configuration file
    config_option = {
        "type": Path,
        "required": True,
        "help": "Path to the configuration file (YAML format)",
    }
    cli.add_argument("--config-path", **config_option)

    parsed = cli.parse_args()
    return parsed
def read_config(config_path: Path) -> dict:
    """
    Reads the configuration from a YAML file.

    Parameters:
    -----------
    config_path : Path
        Path to the configuration YAML file.

    Returns:
    --------
    dict
        Parsed configuration dictionary.

    Raises:
    -------
    ValueError
        If the file does not contain a YAML mapping at the top level
        (for example, when it is empty).
    """
    import yaml

    # Decode explicitly as UTF-8 so parsing does not depend on the platform's
    # default locale encoding.
    with open(config_path, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)

    # safe_load returns None for an empty document and may return a list or
    # scalar; reject those here instead of failing later on a key lookup.
    if not isinstance(config, dict):
        raise ValueError(
            f"Configuration file {config_path} must contain a YAML mapping, "
            f"got {type(config).__name__}"
        )
    return config
# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|