import hydra
from omegaconf import DictConfig, OmegaConf
import rootutils
from dotenv import load_dotenv, find_dotenv

# Load environment variables
load_dotenv(find_dotenv(".env"))

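# Locate the project root (the directory containing a `.project-root` marker
# file) and add it to sys.path so that project modules are importable
root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)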


# Stand-in checkpoint class: its arguments mirror the keys expected under
# cfg.callbacks.model_checkpoint. It only stores and prints them.
class ModelCheckpoint:
    def __init__(
        self,
        dirpath,
        filename,
        monitor,
        verbose=False,
        save_last=True,
        save_top_k=1,
        mode="max",
        auto_insert_metric_name=False,
        save_weights_only=False,
        every_n_train_steps=None,
        train_time_interval=None,
        every_n_epochs=None,
        save_on_train_epoch_end=None,
    ):
        self.dirpath = dirpath
        self.filename = filename
        self.monitor = monitor
        self.verbose = verbose
        self.save_last = save_last
        self.save_top_k = save_top_k
        self.mode = mode
        self.auto_insert_metric_name = auto_insert_metric_name
        self.save_weights_only = save_weights_only
        self.every_n_train_steps = every_n_train_steps
        self.train_time_interval = train_time_interval
        self.every_n_epochs = every_n_epochs
        self.save_on_train_epoch_end = save_on_train_epoch_end

    def display(self):
        print("Initialized ModelCheckpoint with the following configuration:")
        for attr, value in self.__dict__.items():
            print(f"{attr}: {value}")


# func4 instantiates ModelCheckpoint from the keyword arguments taken from cfg.callbacks.model_checkpoint
def func4(**kwargs):
    # Initialize ModelCheckpoint with the kwargs
    checkpoint = ModelCheckpoint(**kwargs)
    checkpoint.display()  # Display the configuration for confirmation


@hydra.main(config_path="../configs", config_name="train", version_base="1.1")
def hydra_test(cfg: DictConfig):
    # Print the full composed configuration
    print("Full Configuration:")
    print(OmegaConf.to_yaml(cfg))

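    # Convert the node to a plain dict and drop Hydra's `_target_` key if it is
    # present, because ModelCheckpoint.__init__ does not accept it
    checkpoint_kwargs = OmegaConf.to_container(cfg.callbacks.model_checkpoint, resolve=True)
    checkpoint_kwargs.pop("_target_", None)
    func4(**checkpoint_kwargs)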


if __name__ == "__main__":
    hydra_test()
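
# Example invocations. This is a sketch: "hydra_test.py" is a hypothetical
# filename, and it assumes train.yaml composes a `callbacks` group containing a
# `model_checkpoint` node. Hydra resolves config_path relative to this file.
# Command-line `key=value` overrides replace config entries before func4 sees them:
#
#   python hydra_test.py
#   python hydra_test.py callbacks.model_checkpoint.monitor=val/loss callbacks.model_checkpoint.mode=min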