import os
from typing import Dict, Optional

import gdown
import torch
from diacritization_evaluation import der, util, wer
from torch import nn, optim
from torch.cuda.amp import autocast
from torch.utils.tensorboard.writer import SummaryWriter
from tqdm import trange
from tqdm.notebook import tqdm

from .config_manager import ConfigManager
from .dataset import load_iterators
from .diacritizer import CBHGDiacritizer, Seq2SeqDiacritizer
from .options import OptimizerType
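

# `initialize_weights` is used by `GeneralTrainer.initialize_model` below but is
# not defined or imported in this module. The following is a minimal sketch of
# the assumed helper, matching the "xavier_uniform_" message printed there:
# apply Xavier-uniform initialization to every parameter matrix with more than
# one dimension, leaving biases and 1-D weights (e.g. LayerNorm) untouched.
def initialize_weights(module: nn.Module) -> None:
    weight = getattr(module, "weight", None)
    if isinstance(weight, torch.Tensor) and weight.dim() > 1:
        nn.init.xavier_uniform_(weight.data)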


class Trainer:
    def run(self):
        raise NotImplementedError


class GeneralTrainer(Trainer):
    def __init__(self, config_path: str, model_kind: str) -> None:
        self.config_path = config_path
        self.model_kind = model_kind
        self.config_manager = ConfigManager(
            config_path=config_path, model_kind=model_kind
        )
        self.config = self.config_manager.config
        self.losses = []
        self.lr = 0
        self.pad_idx = 0
        self.criterion = nn.CrossEntropyLoss(ignore_index=self.pad_idx)
        self.set_device()

        self.config_manager.create_remove_dirs()
        self.text_encoder = self.config_manager.text_encoder
        self.start_symbol_id = self.text_encoder.start_symbol_id
        self.summary_manager = SummaryWriter(log_dir=self.config_manager.log_dir)

        self.model = self.config_manager.get_model()

        self.optimizer = self.get_optimizer()
        self.model = self.model.to(self.device)

        self.load_model(model_path=self.config.get("train_resume_model_path"))
        self.load_diacritizer()

        self.initialize_model()

    def set_device(self):
        # Prefer an explicit device from the config; otherwise fall back to
        # CUDA when available.
        if self.config.get("device"):
            self.device = self.config["device"]
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def load_diacritizer(self):
        if self.model_kind in ["cbhg", "baseline"]:
            self.diacritizer = CBHGDiacritizer(self.config_path, self.model_kind)
        elif self.model_kind in ["seq2seq", "tacotron_based"]:
            self.diacritizer = Seq2SeqDiacritizer(self.config_path, self.model_kind)

    def initialize_model(self):
        # Only initialize weights for a fresh run; a resumed checkpoint
        # (global_step > 1) already carries trained weights.
        if self.global_step > 1:
            return
        if self.model_kind == "transformer":
            print("Initializing using xavier_uniform_")
            self.model.apply(initialize_weights)

    def load_model(self, model_path: Optional[str] = None, load_optimizer: bool = True):
        with open(
            self.config_manager.base_dir / f"{self.model_kind}_network.txt", "w"
        ) as file:
            file.write(str(self.model))

        if model_path is None:
            last_model_path = self.config_manager.get_last_model_path()
            if last_model_path is None:
                self.global_step = 1
                return
        else:
            last_model_path = model_path

        print(f"loading from {last_model_path}")
        # Map onto the device resolved by set_device(); the original passed
        # self.config.get("device") directly, which crashes when the config
        # omits "device".
        saved_model = torch.load(
            last_model_path, map_location=torch.device(self.device)
        )
        self.model.load_state_dict(saved_model["model_state_dict"])
        if load_optimizer:
            self.optimizer.load_state_dict(saved_model["optimizer_state_dict"])
        self.global_step = saved_model["global_step"] + 1


class DiacritizationTester(GeneralTrainer):
    def __init__(self, config_path: str, model_kind: str, model_path: str) -> None:
        self.config_path = config_path
        self.model_kind = model_kind
        self.config_manager = ConfigManager(
            config_path=config_path, model_kind=model_kind
        )
        self.config = self.config_manager.config

        self.pad_idx = 0
        self.criterion = nn.CrossEntropyLoss(ignore_index=self.pad_idx)
        self.set_device()

        self.text_encoder = self.config_manager.text_encoder
        self.start_symbol_id = self.text_encoder.start_symbol_id

        self.model = self.config_manager.get_model()

        self.model = self.model.to(self.device)
        self.load_model(model_path=model_path, load_optimizer=False)
        self.load_diacritizer()
        self.diacritizer.set_model(self.model)
        self.initialize_model()

    def collate_fn(self, data):
        """
        Pad the source and target sequences of a batch to the length of the
        longest sequence in that batch.
        """

        def merge(sequences):
            lengths = [len(seq) for seq in sequences]
            padded_seqs = torch.zeros(len(sequences), max(lengths)).long()
            for i, seq in enumerate(sequences):
                end = lengths[i]
                padded_seqs[i, :end] = seq[:end]
            return padded_seqs, lengths

        # Sort by source length, longest first, so downstream code can rely on
        # descending lengths (e.g. for packed sequences).
        data.sort(key=lambda x: len(x[0]), reverse=True)

        src_seqs, trg_seqs, original = zip(*data)

        src_seqs, src_lengths = merge(src_seqs)
        trg_seqs, trg_lengths = merge(trg_seqs)

        batch = {
            "original": original,
            "src": src_seqs,
            "target": trg_seqs,
            "lengths": torch.LongTensor(src_lengths),
        }
        return batch

    def get_batch(self, sentence):
        data = self.text_encoder.clean(sentence)
        text, inputs, diacritics = util.extract_haraqat(data)
        inputs = torch.Tensor(self.text_encoder.input_to_sequence("".join(inputs)))
        diacritics = torch.Tensor(self.text_encoder.target_to_sequence(diacritics))
        batch = self.collate_fn([(inputs, diacritics, text)])
        return batch

    def infer(self, sentence):
        # Diacritize a single sentence and return the predicted string.
        self.model.eval()
        batch = self.get_batch(sentence)
        predicted = self.diacritizer.diacritize_batch(batch)
        return predicted[0]
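

# Minimal usage sketch for the tester. This module uses relative imports, so it
# is driven from calling code rather than run directly; the config path, model
# kind, and checkpoint path below are hypothetical placeholders for
# illustration, not files shipped with the project:
#
#     tester = DiacritizationTester(
#         config_path="config/cbhg.yml",
#         model_kind="cbhg",
#         model_path="models/cbhg_model.pt",
#     )
#     print(tester.infer("اللغة العربية"))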