File size: 3,708 Bytes
41d1e96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import numpy as np
import torch
import argparse
from loguru import logger
from astropy.table import Table
import pandas as pd
from pathlib import Path

from temps.archive import Archive
from temps.temps import TempsModule
from temps.temps_arch import EncoderPhotometry, MeasureZ

def train(config: dict) -> None:
    """
    Trains the TempsModule using photometry data and saves the trained weights.

    Parameters:
    -----------
    config : dict
        Configuration dictionary. The keys read by this function are
        "path_calib", "path_valid", "output_model", "only_zs", "bands" and
        "hyperparams" (with sub-keys "nepochs" and "learning_rate").

    Returns:
    --------
    None

    Raises:
    -------
    KeyError
        If one of the configuration keys listed above is missing.
    """

    # Paths
    path_calib = Path(config["path_calib"])
    path_valid = Path(config["path_valid"])
    output_model = Path(config["output_model"])

    # Create the output directory before training: torch.save does not create
    # missing parent directories, and failing only after a full training run
    # would throw the trained weights away.
    output_model.mkdir(parents=True, exist_ok=True)

    hyperparams = config["hyperparams"]

    # Initialize neural network modules for photometry features and redshift measurement
    nn_features = EncoderPhotometry()
    nn_z = MeasureZ(num_gauss=6)  # Example for Gaussian mixture model with 6 components

    # Initialize the TempsModule with the defined neural networks
    temps_module = TempsModule(nn_features, nn_z)

    # Retrieve photometry and spectroscopic data for training
    photoz_archive = Archive(
        path_calib=path_calib,
        path_valid=path_valid,
        drop_stars=False,
        clean_photometry=False,
        only_zspec=config["only_zs"],
        columns_photometry=config["bands"],
    )

    # get_training_data() yields five values; the second-to-last three
    # positions used here are f, specz and f_DA. The other two are not needed
    # for training and are bound to underscore-prefixed names.
    f, specz, _vis_mag, f_DA, _z_da = photoz_archive.get_training_data()

    # Train the TempsModule
    logger.info("Starting model training...")
    temps_module.train(
        input_data=f,
        input_data_da=f_DA,
        target_data=specz,
        nepochs=hyperparams["nepochs"],
        # NOTE(review): step_size is fed from "nepochs", exactly as before.
        # Confirm whether a dedicated "step_size" hyperparameter was intended.
        step_size=hyperparams["nepochs"],
        val_fraction=0.1,  # Validation fraction of 10%
        lr=hyperparams["learning_rate"],
    )
    logger.info("Model training complete.")

    # Save the trained models
    logger.info("Saving trained models...")
    torch.save(temps_module.modelF.state_dict(), output_model / "modelF_zs_test.pt")
    torch.save(temps_module.modelZ.state_dict(), output_model / "modelZ_zs_test.pt")
    logger.info("Models saved at: {}", output_model)

def main() -> None:
    """
    Script entry point.

    Parses the command-line arguments, logs which configuration file is being
    used, loads it with `read_config`, and hands the result to `train`.
    """
    cli_args = get_args()
    cfg_file = cli_args.config_path

    logger.info("Loading configuration from {}", cfg_file)

    # Load the configuration, then launch training with it
    settings = read_config(cfg_file)
    train(settings)

def get_args() -> argparse.Namespace:
    """
    Builds the command-line parser and parses the process arguments.

    The only option is the mandatory ``--config-path``, which is converted to
    a ``Path`` and exposed as ``config_path`` on the returned namespace.

    Returns:
    --------
    argparse.Namespace
        Parsed command-line arguments.
    """
    # Settings of the single required option
    config_option = {
        "type": Path,
        "required": True,
        "help": "Path to the configuration file (YAML format)",
    }

    cli = argparse.ArgumentParser(description="Training script for TempsModule")
    cli.add_argument("--config-path", **config_option)

    parsed = cli.parse_args()
    return parsed

def read_config(config_path: Path) -> dict:
    """
    Reads the configuration from a YAML file.

    Parameters:
    -----------
    config_path : Path
        Path to the configuration YAML file.

    Returns:
    --------
    dict
        Parsed configuration dictionary.

    Raises:
    -------
    ValueError
        If the file is empty or its top-level YAML node is not a mapping.
    """
    import yaml

    # Read explicitly as UTF-8 so parsing does not depend on the platform's
    # default locale encoding.
    with open(config_path, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)

    # safe_load returns None for an empty document and may return a list or a
    # scalar; reject those here instead of failing later on a key lookup.
    if not isinstance(config, dict):
        raise ValueError(
            f"Configuration file {config_path} must contain a YAML mapping "
            f"at the top level, got {type(config).__name__}"
        )

    return config

# Run main() only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()