|
import logging |
|
import os |
|
from pathlib import Path |
|
|
|
import numpy as np |
|
import pytorch_lightning as pl |
|
import torch |
|
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint |
|
from pytorch_lightning.utilities import rank_zero_only |
|
|
|
from utils.utils_verbalisation_module import save_json |
|
from pytorch_lightning.utilities import rank_zero_info |
|
|
|
def count_trainable_parameters(model):
    """Return the number of trainable scalar parameters in ``model``.

    Only parameters with ``requires_grad`` set are counted.

    Args:
        model: Any object exposing ``parameters()`` that yields tensors
            (e.g. a ``torch.nn.Module``).

    Returns:
        The total element count as a plain Python ``int`` (``0`` when the
        model has no trainable parameters).
    """
    # ``numel()`` gives an exact Python int, including for 0-dim parameters
    # (where a product over an empty shape would come out as float 1.0).
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
|
|
|
|
# Module-level logger used by the callback below.
logger = logging.getLogger(__name__)


class Seq2SeqLoggingCallback(pl.Callback):
    """Lightning callback that logs learning rates, parameter counts and metrics.

    It reports per-group learning rates after every batch, parameter counts
    when training starts, and the contents of ``trainer.callback_metrics``
    after validation and test runs.
    """

    # Entries of ``trainer.callback_metrics`` that are skipped when metrics
    # are logged or written to the results file.
    _NON_METRIC_KEYS = ("log", "progress_bar", "preds")

    def on_batch_end(self, trainer, pl_module):
        """Log the learning rate of each parameter group of the first optimizer."""
        param_groups = pl_module.trainer.optimizers[0].param_groups
        lrs = {f"lr_group_{i}": group["lr"] for i, group in enumerate(param_groups)}
        pl_module.logger.log_metrics(lrs)

    @rank_zero_only
    def _write_logs(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
    ) -> None:
        """Log the callback metrics and write them to files under ``hparams.output_dir``.

        Args:
            trainer: Trainer whose ``callback_metrics`` and ``global_step`` are read.
            pl_module: Module providing ``hparams.output_dir``.
            type_path: Split name; ``"test"`` writes fixed file names, any other
                value writes per-step files in ``<type_path>_results`` /
                ``<type_path>_generations`` sub-directories.
            save_generations: When true and ``"preds"`` is among the metrics,
                the predictions are written one per line to the generations file.
        """
        logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****")
        metrics = trainer.callback_metrics

        # Substring match on purpose: any key *containing* one of the markers
        # is excluded from what is sent to the trainer's logger.
        new_metrics = {
            name: value
            for name, value in metrics.items()
            if not any(marker in name for marker in self._NON_METRIC_KEYS)
        }

        print(new_metrics)
        trainer.logger.log_metrics(new_metrics)

        od = Path(pl_module.hparams.output_dir)
        if type_path == "test":
            results_file = od / "test_results.txt"
            generations_file = od / "test_generations.txt"
        else:
            results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt"
            generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt"

        # ``parents=True`` so a not-yet-created output_dir does not abort the run.
        results_file.parent.mkdir(parents=True, exist_ok=True)
        generations_file.parent.mkdir(parents=True, exist_ok=True)

        with open(results_file, "a+") as writer:
            for key in sorted(metrics):
                if key in self._NON_METRIC_KEYS:
                    continue
                val = metrics[key]
                try:
                    if isinstance(val, torch.Tensor):
                        val = val.item()
                    msg = f"{key}: {val:.6f}\n"
                except (TypeError, ValueError, RuntimeError):
                    # Best effort: entries that are not a single number
                    # (strings, lists, multi-element tensors, ...) are skipped.
                    logger.debug("Skipping non-scalar metric %r in %s", key, results_file)
                    continue
                writer.write(msg)

        if not save_generations:
            return

        if "preds" in metrics:
            content = "\n".join(metrics["preds"])
            # Context manager so the file handle is flushed and closed deterministically.
            with generations_file.open("w+") as gen_writer:
                gen_writer.write(content)

    @rank_zero_only
    def on_train_start(self, trainer, pl_module):
        """Log total and trainable parameter counts once training begins."""
        try:
            npars = pl_module.model.model.num_parameters()
        except AttributeError:
            # Fall back to the outer model when there is no nested ``.model``.
            npars = pl_module.model.num_parameters()

        n_trainable_pars = count_trainable_parameters(pl_module)

        # "mp" / "grad_mp" are the counts expressed in millions of parameters.
        trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6})

    @rank_zero_only
    def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        """Save the module's metrics as JSON and write the test result files."""
        save_json(pl_module.metrics, pl_module.metrics_save_path)
        return self._write_logs(trainer, pl_module, "test")

    @rank_zero_only
    def on_validation_end(self, trainer: pl.Trainer, pl_module):
        """Save the module's metrics as JSON and report validation metrics via ``rank_zero_info``."""
        save_json(pl_module.metrics, pl_module.metrics_save_path)

        rank_zero_info("***** Validation results *****")
        metrics = trainer.callback_metrics

        for key in sorted(metrics):
            if key not in self._NON_METRIC_KEYS:
                rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
|
|
|
|
|
|
|
|
|
def get_checkpoint_callback(output_dir, metric, save_top_k=1, lower_is_better=False):
    """Build a ``ModelCheckpoint`` that keeps the best models by a validation metric.

    Args:
        output_dir: Directory under which checkpoint files are written.
        metric: One of ``"rouge2"``, ``"bleu"`` or ``"loss"``.
        save_top_k: Forwarded to ``ModelCheckpoint`` as ``save_top_k``.
        lower_is_better: When true, smaller monitored values count as better.
            Metrics whose name contains ``"loss"`` are always minimised,
            regardless of this flag.

    Returns:
        The configured ``ModelCheckpoint`` callback.

    Raises:
        NotImplementedError: If ``metric`` is not one of the supported names.
    """
    filename_templates = {
        "rouge2": "{val_avg_rouge2:.4f}-{step_count}",
        "bleu": "{val_avg_bleu:.4f}-{step_count}",
        "loss": "{val_avg_loss:.4f}-{step_count}",
    }
    try:
        exp = filename_templates[metric]
    except (KeyError, TypeError):
        # TypeError covers unhashable values so every unsupported input
        # surfaces as the same NotImplementedError.
        raise NotImplementedError(
            f"seq2seq callbacks only support rouge2, bleu and loss, got {metric}, You can make your own by adding to this function."
        ) from None

    minimise = lower_is_better or "loss" in metric

    # NOTE(review): the filename templates reference ``val_avg_<metric>`` while
    # the monitored key is ``val_<metric>`` - confirm both are logged by the module.
    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(output_dir, exp),
        monitor=f"val_{metric}",
        mode="min" if minimise else "max",
        save_top_k=save_top_k,
        period=0,
    )
    return checkpoint_callback
|
|
|
|
|
def get_early_stopping_callback(metric, patience):
    """Create an ``EarlyStopping`` callback monitoring ``val_<metric>``.

    The monitored value is minimised when ``metric`` contains ``"loss"`` and
    maximised otherwise; ``patience`` is forwarded unchanged and verbose
    output is enabled.
    """
    monitored_key = f"val_{metric}"
    if "loss" in metric:
        direction = "min"
    else:
        direction = "max"
    stopper = EarlyStopping(monitor=monitored_key, mode=direction, patience=patience, verbose=True)
    return stopper
|
|