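# Hugging Face file-viewer residue, kept as comments so the module parses:
# uploaded by Ababababababbababa; duplicated from arbml/Ashaar at commit
# 6faf7e7; file size 16.6 kB. ("raw", "history", "blame" and "No virus" were
# viewer UI labels.)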
import os
from typing import Dict
from diacritization_evaluation import der, wer
import torch
from torch import nn
from torch import optim
from torch.cuda.amp import autocast
from torch.utils.tensorboard.writer import SummaryWriter
from tqdm import tqdm
from tqdm import trange
from .config_manager import ConfigManager
from dataset import load_iterators
from diacritizer import CBHGDiacritizer, Seq2SeqDiacritizer, GPTDiacritizer
from poetry_diacritizer.util.learning_rates import LearningRateDecay
from poetry_diacritizer.options import OptimizerType
from poetry_diacritizer.util.utils import (
categorical_accuracy,
count_parameters,
initialize_weights,
plot_alignment,
repeater,
)
import wandb
# Authenticate with Weights & Biases as soon as this module is imported, before
# any trainer calls wandb.init.
# NOTE(review): this is an import-time side effect (it may prompt for an API key
# when no credentials are configured); consider moving it into the trainer.
wandb.login()
class Trainer:
    """Base interface for training loops.

    Concrete trainers are expected to override :meth:`run`.
    """

    def run(self):
        """Execute training; the base class provides no implementation."""
        raise NotImplementedError
class GeneralTrainer(Trainer):
    """Config-driven, step-based training loop shared by the diacritization models.

    Builds the model, optimizer, loggers (TensorBoard ``SummaryWriter`` and
    Weights & Biases) and a diacritizer from the config at ``config_path``,
    optionally resumes from a checkpoint, and trains in :meth:`run`.
    Subclasses may override the :meth:`plot_attention` and :meth:`report` hooks.
    """

    def __init__(self, config_path: str, model_kind: str, model_desc: str = "") -> None:
        """Set up configuration, device, model, optimizer and logging.

        Args:
            config_path: path handed to ``ConfigManager``.
            model_kind: model identifier (e.g. "cbhg", "baseline", "seq2seq",
                "tacotron_based", "gpt").
            model_desc: Weights & Biases run name; falls back to ``model_kind``
                when empty or None.
        """
        self.config_path = config_path
        self.model_kind = model_kind
        self.config_manager = ConfigManager(
            config_path=config_path, model_kind=model_kind
        )
        self.config = self.config_manager.config
        self.losses = []
        self.lr = 0
        self.pad_idx = 0
        # Positions holding the padding index do not contribute to the loss.
        self.criterion = nn.CrossEntropyLoss(ignore_index=self.pad_idx)
        self.set_device()
        self.config_manager.create_remove_dirs()
        self.text_encoder = self.config_manager.text_encoder
        self.start_symbol_id = self.text_encoder.start_symbol_id
        self.summary_manager = SummaryWriter(log_dir=self.config_manager.log_dir)
        if not model_desc:
            model_desc = self.model_kind
        wandb.init(project="diacratization", name=model_desc, config=self.config)
        self.model = self.config_manager.get_model()
        # torch.optim recommends moving the model to its device before the
        # optimizer is constructed over its parameters.
        self.model = self.model.to(self.device)
        self.optimizer = self.get_optimizer()
        self.load_model(model_path=self.config.get("train_resume_model_path"))
        self.load_diacritizer()
        self.initialize_model()
        self.print_config()

    def set_device(self):
        """Use ``config["device"]`` if set, else CUDA when available, else CPU."""
        if self.config.get("device"):
            self.device = self.config["device"]
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def print_config(self):
        """Dump/print the config, the resume step and the trainable-parameter count."""
        self.config_manager.dump_config()
        self.config_manager.print_config()
        if self.global_step > 1:
            print(f"loaded from {self.global_step}")
        parameters_count = count_parameters(self.model)
        print(f"The model has {parameters_count} trainable parameters")

    def load_diacritizer(self):
        """Create the diacritizer used for DER/WER evaluation.

        Model kinds without a matching diacritizer get ``None``; DER/WER
        evaluation then raises a descriptive error instead of AttributeError.
        """
        if self.model_kind in ["cbhg", "baseline"]:
            self.diacritizer = CBHGDiacritizer(self.config_path, self.model_kind)
        elif self.model_kind in ["seq2seq", "tacotron_based"]:
            self.diacritizer = Seq2SeqDiacritizer(self.config_path, self.model_kind)
        elif self.model_kind in ["gpt"]:
            self.diacritizer = GPTDiacritizer(self.config_path, self.model_kind)
        else:
            self.diacritizer = None

    def initialize_model(self):
        """Apply ``initialize_weights`` to fresh (non-resumed) transformer models."""
        if self.global_step > 1:
            return
        if self.model_kind == "transformer":
            print("Initializing using xavier_uniform_")
            self.model.apply(initialize_weights)

    def print_losses(self, step_results, tqdm):
        """Log the current loss and its configured moving averages.

        Args:
            step_results: output of :meth:`run_one_step` (must contain "loss").
            tqdm: progress bar used for on-screen display.
        """
        loss_value = step_results["loss"].item()
        self.summary_manager.add_scalar(
            "loss/loss", loss_value, global_step=self.global_step
        )
        tqdm.display(f"loss: {loss_value}", pos=3)
        for pos, n_steps in enumerate(self.config["n_steps_avg_losses"]):
            # A non-positive window would divide by zero (or average the whole
            # history via losses[-0:]); the window is usable once it is full.
            if n_steps <= 0 or len(self.losses) < n_steps:
                continue
            average = sum(self.losses[-n_steps:]) / n_steps
            self.summary_manager.add_scalar(
                f"loss/loss-{n_steps}", average, global_step=self.global_step
            )
            tqdm.display(f"{n_steps}-steps average loss: {average}", pos=pos + 4)

    def evaluate(self, iterator, tqdm, use_target=True, log=True):
        """Compute mean loss and accuracy over ``iterator`` without gradients.

        Args:
            iterator: batches with "src", "lengths" and "target" entries.
            tqdm: progress bar advanced once per batch (reset at the end).
            use_target: when False the model gets ``target=None``; targets are
                still used to score the predictions. NOTE(review): this assumes
                the model then emits predictions aligned with the targets.
            log: whether to send per-batch metrics to Weights & Biases.

        Returns:
            ``(mean_loss, mean_accuracy)``; both NaN when ``iterator`` is empty.
        """
        epoch_loss = 0.0
        epoch_acc = 0.0
        n_batches = 0
        self.model.eval()
        tqdm.set_description(f"Eval: {self.global_step}")
        with torch.no_grad():
            for batch_inputs in iterator:
                src = batch_inputs["src"].to(self.device)
                lengths = batch_inputs["lengths"].to("cpu")
                targets = batch_inputs["target"].to(self.device)
                outputs = self.model(
                    src=src,
                    target=targets if use_target else None,
                    lengths=lengths,
                )
                predictions = outputs["diacritics"].contiguous()
                predictions = predictions.view(-1, predictions.shape[-1])
                flat_targets = targets.contiguous().view(-1)
                loss = self.criterion(predictions, flat_targets)
                acc = categorical_accuracy(predictions, flat_targets, self.pad_idx)
                epoch_loss += loss.item()
                epoch_acc += acc.item()
                n_batches += 1
                if log:
                    wandb.log({"evaluate_loss": loss.item(), "evaluate_acc": acc.item()})
                tqdm.update()
        tqdm.reset()
        if n_batches == 0:
            return float("nan"), float("nan")
        return epoch_loss / n_batches, epoch_acc / n_batches

    def evaluate_with_error_rates(self, iterator, tqdm, log=True):
        """Diacritize validation batches and compute DER/WER (with and without case endings).

        At most ``min(error_rates_n_batches, max_eval_batches)`` batches are
        used. Originals and predictions are written to ``original.txt`` and
        ``predicted.txt`` in the prediction directory, then scored from disk.

        Returns:
            ``(results, summary_texts)``: ``results`` maps "DER", "DER*", "WER"
            and "WER*" to scores; ``summary_texts`` holds ``(tag, text)`` pairs
            for up to ``n_predicted_text_tensorboard`` examples.

        Raises:
            RuntimeError: if no diacritizer exists for this model kind.
        """
        diacritizer = getattr(self, "diacritizer", None)
        if diacritizer is None:
            raise RuntimeError(
                f"No diacritizer available for model kind {self.model_kind!r}; "
                "cannot compute DER/WER"
            )
        diacritizer.set_model(self.model)
        # Both settings cap the number of batches; honour the tighter one.
        max_batches = min(
            int(self.config["error_rates_n_batches"]),
            int(self.config["max_eval_batches"]),
        )
        all_orig = []
        all_predicted = []
        tqdm.set_description(f"Calculating DER/WER {self.global_step}: ")
        for n_evaluated, batch in enumerate(iterator):
            if n_evaluated >= max_batches:
                break
            all_predicted += diacritizer.diacritize_batch(batch)
            all_orig += batch["original"]
            tqdm.update()

        prediction_dir = self.config_manager.prediction_dir
        orig_path = os.path.join(prediction_dir, "original.txt")
        predicted_path = os.path.join(prediction_dir, "predicted.txt")
        with open(orig_path, "w", encoding="utf8") as file:
            file.writelines(f"{sentence}\n" for sentence in all_orig)
        with open(predicted_path, "w", encoding="utf8") as file:
            file.writelines(f"{sentence}\n" for sentence in all_predicted)

        n_texts = int(self.config["n_predicted_text_tensorboard"])
        summary_texts = []
        table = wandb.Table(columns=["original", "predicted"])
        # zip bounds the loop by the shorter list, avoiding an IndexError.
        for i, (original, predicted) in enumerate(zip(all_orig, all_predicted)):
            if i >= n_texts:
                break
            summary_texts.append((f"eval-text/{i}", f"{original} |-> {predicted}"))
            if i < 10:
                table.add_data(original, predicted)
        if log:
            wandb.log({f"prediction_{self.global_step}": table}, commit=False)

        results = {}
        results["DER"] = der.calculate_der_from_path(orig_path, predicted_path)
        results["DER*"] = der.calculate_der_from_path(
            orig_path, predicted_path, case_ending=False
        )
        results["WER"] = wer.calculate_wer_from_path(orig_path, predicted_path)
        results["WER*"] = wer.calculate_wer_from_path(
            orig_path, predicted_path, case_ending=False
        )
        if log:
            wandb.log(results)
        tqdm.reset()
        return results, summary_texts

    def run(self):
        """Train until ``max_steps``, periodically checkpointing and evaluating."""
        if self.global_step > self.config["max_steps"]:
            print("Training Done.")
            return
        # Any CUDA device string ("cuda", "cuda:1", ...) can use mixed precision.
        use_amp = bool(
            str(self.device).startswith("cuda") and self.config["use_mixed_precision"]
        )
        # A disabled scaler avoids the warning GradScaler emits without CUDA.
        scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
        train_iterator, _, validation_iterator = load_iterators(self.config_manager)
        print("data loaded")
        print("----------------------------------------------------------")
        tqdm_eval = trange(0, len(validation_iterator), leave=True)
        tqdm_error_rates = trange(0, len(validation_iterator), leave=True)
        tqdm_eval.set_description("Eval")
        tqdm_error_rates.set_description("WER/DER : ")
        progress = trange(self.global_step, self.config["max_steps"] + 1, leave=True)
        try:
            for batch_inputs in repeater(train_iterator):
                progress.set_description(f"Global Step {self.global_step}")
                step_results = self._train_step(batch_inputs, scaler, use_amp)
                loss_value = step_results["loss"].item()
                self.losses.append(loss_value)
                wandb.log({"train_loss": loss_value})
                self.print_losses(step_results, progress)
                self.summary_manager.add_scalar(
                    "meta/learning_rate", self.lr, global_step=self.global_step
                )
                if self._is_due("model_save_frequency"):
                    self._save_checkpoint()
                if self._is_due("evaluate_frequency"):
                    self._run_evaluation(validation_iterator, tqdm_eval, progress)
                if self._is_due("evaluate_with_error_rates_frequency"):
                    self._run_error_rate_evaluation(
                        validation_iterator, tqdm_error_rates, progress
                    )
                if self._is_due("train_plotting_frequency"):
                    self.plot_attention(step_results)
                self.report(step_results, progress)
                self.global_step += 1
                if self.global_step > self.config["max_steps"]:
                    print("Training Done.")
                    return
                progress.update()
        finally:
            for bar in (tqdm_eval, tqdm_error_rates, progress):
                bar.close()

    def _is_due(self, frequency_key):
        """Return True when the current step is a multiple of ``config[frequency_key]``."""
        return self.global_step % self.config[frequency_key] == 0

    def _train_step(self, batch_inputs, scaler, use_amp):
        """Update the learning rate, run forward/backward and step the optimizer once."""
        if self.config["use_decay"]:
            self.lr = self.adjust_learning_rate(
                self.optimizer, global_step=self.global_step
            )
        else:
            # Report the optimizer's real rate rather than the initial 0.
            self.lr = self.optimizer.param_groups[0]["lr"]
        self.optimizer.zero_grad()
        if use_amp:
            with autocast():
                step_results = self.run_one_step(batch_inputs)
            # Backward runs outside autocast, as the AMP docs recommend.
            scaler.scale(step_results["loss"]).backward()
            # Unscale first so clipping applies to the true gradient norms.
            scaler.unscale_(self.optimizer)
            self._clip_gradients()
            scaler.step(self.optimizer)
            scaler.update()
        else:
            step_results = self.run_one_step(batch_inputs)
            step_results["loss"].backward()
            self._clip_gradients()
            self.optimizer.step()
        return step_results

    def _clip_gradients(self):
        """Clip the global gradient norm to ``config["CLIP"]`` when it is set and truthy."""
        clip = self.config.get("CLIP")
        if clip:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip)

    def _save_checkpoint(self):
        """Save step, model and optimizer state as ``<step>-snapshot.pt`` in ``models_dir``."""
        torch.save(
            {
                "global_step": self.global_step,
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
            },
            os.path.join(
                self.config_manager.models_dir,
                f"{self.global_step}-snapshot.pt",
            ),
        )

    def _run_evaluation(self, validation_iterator, tqdm_eval, progress):
        """Run :meth:`evaluate`, log loss/accuracy, then return the model to train mode."""
        loss, acc = self.evaluate(validation_iterator, tqdm_eval)
        self.summary_manager.add_scalar(
            "evaluate/loss", loss, global_step=self.global_step
        )
        self.summary_manager.add_scalar(
            "evaluate/acc", acc, global_step=self.global_step
        )
        progress.display(
            f"Evaluate {self.global_step}: accuracy, {acc}, loss: {loss}", pos=8
        )
        self.model.train()

    def _run_error_rate_evaluation(self, validation_iterator, tqdm_error_rates, progress):
        """Run :meth:`evaluate_with_error_rates` and log rates and example texts."""
        error_rates, summary_texts = self.evaluate_with_error_rates(
            validation_iterator, tqdm_error_rates
        )
        if error_rates:
            for name in ("WER", "DER", "DER*", "WER*"):
                # Rates are divided by 100 before logging.
                self.summary_manager.add_scalar(
                    f"error_rates/{name}",
                    error_rates[name] / 100,
                    global_step=self.global_step,
                )
            error_rates_text = (
                f"DER: {error_rates['DER']}, WER: {error_rates['WER']}, "
                f"DER*: {error_rates['DER*']}, WER*: {error_rates['WER*']}"
            )
            progress.display(f"WER/DER {self.global_step}: {error_rates_text}", pos=9)
        for tag, text in summary_texts:
            # Attach the step so each evaluation's examples are kept separately.
            self.summary_manager.add_text(tag, text, global_step=self.global_step)
        self.model.train()

    def run_one_step(self, batch_inputs: Dict[str, torch.Tensor]):
        """Forward one training batch and attach the cross-entropy loss.

        ``batch_inputs`` is updated in place with device-placed tensors.

        Returns:
            The model's output dict with an added "loss" entry.
        """
        batch_inputs["src"] = batch_inputs["src"].to(self.device)
        batch_inputs["lengths"] = batch_inputs["lengths"].to("cpu")
        batch_inputs["target"] = batch_inputs["target"].to(self.device)
        outputs = self.model(
            src=batch_inputs["src"],
            target=batch_inputs["target"],
            lengths=batch_inputs["lengths"],
        )
        predictions = outputs["diacritics"].contiguous()
        targets = batch_inputs["target"].contiguous()
        predictions = predictions.view(-1, predictions.shape[-1])
        targets = targets.view(-1)
        loss = self.criterion(predictions.to(self.device), targets)
        outputs.update({"loss": loss})
        return outputs

    def predict(self, iterator):
        """Placeholder for inference; not implemented in this trainer."""
        pass

    def load_model(self, model_path: str = None, load_optimizer: bool = True):
        """Write the model architecture to disk and resume from a checkpoint if one exists.

        Args:
            model_path: checkpoint to load; when empty or None the latest
                checkpoint from ``config_manager.get_last_model_path()`` is used.
            load_optimizer: also restore the optimizer state.

        Sets ``self.global_step`` to 1 for a fresh run, otherwise to the saved
        step + 1.
        """
        network_path = self.config_manager.base_dir / f"{self.model_kind}_network.txt"
        with open(network_path, "w", encoding="utf8") as file:
            file.write(str(self.model))
        if not model_path:
            model_path = self.config_manager.get_last_model_path()
            if model_path is None:
                self.global_step = 1
                return
        print(f"loading from {model_path}")
        # map_location lets a checkpoint saved on GPU be resumed on CPU.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        saved_model = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(saved_model["model_state_dict"])
        if load_optimizer:
            self.optimizer.load_state_dict(saved_model["optimizer_state_dict"])
        self.global_step = saved_model["global_step"] + 1

    def get_optimizer(self):
        """Build the Adam or SGD optimizer selected by ``config["optimizer"]``.

        Raises:
            ValueError: for any other optimizer option.
        """
        if self.config["optimizer"] == OptimizerType.Adam:
            optimizer = optim.Adam(
                self.model.parameters(),
                lr=self.config["learning_rate"],
                betas=(self.config["adam_beta1"], self.config["adam_beta2"]),
                weight_decay=self.config["weight_decay"],
            )
        elif self.config["optimizer"] == OptimizerType.SGD:
            optimizer = optim.SGD(
                self.model.parameters(), lr=self.config["learning_rate"], momentum=0.9
            )
        else:
            raise ValueError("Optimizer option is not valid")
        return optimizer

    def get_learning_rate(self):
        """Return a ``LearningRateDecay`` schedule (``warmup_steps`` defaults to 4000)."""
        return LearningRateDecay(
            lr=self.config["learning_rate"],
            warmup_steps=self.config.get("warmup_steps", 4000.0),
        )

    def adjust_learning_rate(self, optimizer, global_step):
        """Set every param group's rate from the schedule and return that rate."""
        learning_rate = self.get_learning_rate()(global_step=global_step)
        for param_group in optimizer.param_groups:
            param_group["lr"] = learning_rate
        return learning_rate

    def plot_attention(self, results):
        """Hook for attention plotting; a no-op in the general trainer."""
        pass

    def report(self, results, tqdm):
        """Hook for extra per-step reporting; a no-op in the general trainer."""
        pass
class Seq2SeqTrainer(GeneralTrainer):
    """Trainer for attention-based models that also records alignment plots."""

    def plot_attention(self, results):
        """Plot the first example's attention map to disk and to TensorBoard.

        Args:
            results: step output containing an "attention" entry indexable by
                batch position.
        """
        first_alignment = results["attention"][0]
        step = self.global_step
        plot_alignment(first_alignment, str(self.config_manager.plot_dir), step)
        # Add a leading dimension before passing the map to add_image.
        image = first_alignment.unsqueeze(0)
        self.summary_manager.add_image("Train/attention", image, global_step=step)
class GPTTrainer(GeneralTrainer):
    """Trainer for the "gpt" model kind; reuses GeneralTrainer unchanged."""
class CBHGTrainer(GeneralTrainer):
    """Trainer for the CBHG/baseline model kinds; reuses GeneralTrainer unchanged."""