Spaces:
Running
on
A10G
Running
on
A10G
from lightning.pytorch.utilities import rank_zero_only | |
from fish_speech.utils import logger as log | |
# NOTE(review): `rank_zero_only` is imported at the top of this file but is not
# applied to this function in the visible source -- confirm whether it should
# be decorated so that only one process logs in multi-process runs.
def log_hyperparameters(object_dict: dict) -> None:
    """Controls which config parts are saved by lightning loggers.

    Reads ``cfg``, ``model`` and ``trainer`` from ``object_dict``, assembles a
    flat dictionary of hyperparameters and passes it to ``log_hyperparams`` on
    every logger in ``trainer.loggers``. If ``trainer.logger`` is falsy, a
    warning is emitted and nothing is logged.

    Additionally saves:
        - Number of model parameters (total, trainable, non-trainable)

    Args:
        object_dict: Mapping with the keys ``"cfg"``, ``"model"`` and
            ``"trainer"``. ``cfg`` is indexed for ``"model"``, ``"data"`` and
            ``"trainer"``; the remaining sections are read through
            ``cfg.get`` and stored as whatever it returns for a missing key
            (``None`` for a plain ``dict``).
    """
    cfg = object_dict["cfg"]
    model = object_dict["model"]
    trainer = object_dict["trainer"]

    if not trainer.logger:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    # Count parameters in a single traversal; the non-trainable count is the
    # remainder, so the three figures are always mutually consistent.
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count

    hparams = {
        "model": cfg["model"],
        # save number of model parameters
        "model/params/total": total_params,
        "model/params/trainable": trainable_params,
        "model/params/non_trainable": total_params - trainable_params,
        "data": cfg["data"],
        "trainer": cfg["trainer"],
        # optional sections: looked up with .get so a missing one is tolerated
        "callbacks": cfg.get("callbacks"),
        "extras": cfg.get("extras"),
        "task_name": cfg.get("task_name"),
        "tags": cfg.get("tags"),
        "ckpt_path": cfg.get("ckpt_path"),
        "seed": cfg.get("seed"),
    }

    # send hparams to all loggers
    for logger in trainer.loggers:
        logger.log_hyperparams(hparams)