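"""Configuration manager for the poetry_diacritizer training pipeline.

Loads a YAML experiment config, derives the session/log directory layout, and
builds the text encoder and the requested model kind (baseline, cbhg, seq2seq,
tacotron_based, or gpt).
"""
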
from enum import Enum
import os
from pathlib import Path
import shutil
import subprocess
from typing import Any, Dict, Optional

import ruamel.yaml
import torch

from poetry_diacritizer.models.baseline import BaseLineModel
from poetry_diacritizer.models.cbhg import CBHGModel
from poetry_diacritizer.models.gpt import GPTModel
from poetry_diacritizer.models.seq2seq import (
    Decoder as Seq2SeqDecoder,
    Encoder as Seq2SeqEncoder,
    Seq2Seq,
)
from poetry_diacritizer.models.tacotron_based import (
    Decoder as TacotronDecoder,
    Encoder as TacotronEncoder,
    Tacotron,
)

from poetry_diacritizer.options import AttentionType, LossType, OptimizerType
from poetry_diacritizer.util.text_encoders import (
    ArabicEncoderWithStartSymbol,
    BasicArabicEncoder,
    TextEncoder,
)


class ConfigManager:
    """Config Manager"""

    def __init__(self, config_path: str, model_kind: str):
        available_models = ["baseline", "cbhg", "seq2seq", "tacotron_based", "gpt"]
        if model_kind not in available_models:
            raise ValueError(f"model_kind must be one of {available_models}")
        self.config_path = Path(config_path)
        self.model_kind = model_kind
        self.yaml = ruamel.yaml.YAML()
        self.config: Dict[str, Any] = self._load_config()
        self.git_hash = self._get_git_hash()
        self.session_name = ".".join(
            [
                self.config["data_type"],
                self.config["session_name"],
                model_kind,
            ]
        )

        self.data_dir = Path(
            os.path.join(self.config["data_directory"], self.config["data_type"])
        )
        self.base_dir = Path(
            os.path.join(self.config["log_directory"], self.session_name)
        )
        self.log_dir = Path(os.path.join(self.base_dir, "logs"))
        self.prediction_dir = Path(os.path.join(self.base_dir, "predictions"))
        self.plot_dir = Path(os.path.join(self.base_dir, "plots"))
        self.models_dir = Path(os.path.join(self.base_dir, "models"))
        self.sp_model_path = self.config.get("sp_model_path")
        self.text_encoder: TextEncoder = self.get_text_encoder()
        self.config["len_input_symbols"] = len(self.text_encoder.input_symbols)
        self.config["len_target_symbols"] = len(self.text_encoder.target_symbols)
        if self.model_kind in ["seq2seq", "tacotron_based"]:
            self.config["attention_type"] = AttentionType[self.config["attention_type"]]
        self.config["optimizer"] = OptimizerType[self.config["optimizer_type"]]
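
    # A minimal sketch of the config keys that __init__ and the model getters
    # read. The key names come from this file; the values below are
    # illustrative placeholders, not the project's shipped defaults:
    #
    #   session_name: base
    #   data_type: CA_MSA
    #   data_directory: data
    #   log_directory: log_dir
    #   text_encoder: ArabicEncoderWithStartSymbol
    #   text_cleaner: valid_arabic_cleaners
    #   optimizer_type: Adam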

    def _load_config(self):
        with open(self.config_path, "rb") as model_yaml:
            _config = self.yaml.load(model_yaml)
        return _config

    @staticmethod
    def _get_git_hash():
        try:
            return (
                subprocess.check_output(["git", "describe", "--always"])
                .strip()
                .decode()
            )
        except Exception as e:
            print(f"WARNING: could not retrieve git hash. {e}")
            return None

    def _check_hash(self):
        try:
            git_hash = (
                subprocess.check_output(["git", "describe", "--always"])
                .strip()
                .decode()
            )
            if self.config["git_hash"] != git_hash:
                print(
                    f"WARNING: git hash mismatch. Current: {git_hash}. "
                    f"Config hash: {self.config['git_hash']}"
                )
        except Exception as e:
            print(f"WARNING: could not check git hash. {e}")

    @staticmethod
    def _print_dict_values(values, key_name, level=0, tab_size=2):
        tab = level * tab_size * " "
        print(tab + "-", key_name, ":", values)

    def _print_dictionary(self, dictionary, recursion_level=0):
        for key in dictionary.keys():
            if isinstance(dictionary[key], dict):
                # Recurse into nested config sections, one indent level deeper.
                self._print_dictionary(dictionary[key], recursion_level + 1)
            else:
                self._print_dict_values(
                    dictionary[key], key_name=key, level=recursion_level
                )

    def print_config(self):
        print("\nCONFIGURATION", self.session_name)
        self._print_dictionary(self.config)

    def update_config(self):
        self.config["git_hash"] = self._get_git_hash()

    def dump_config(self):
        self.update_config()
        _config = {}
        for key, val in self.config.items():
            if isinstance(val, Enum):
                _config[key] = val.name
            else:
                _config[key] = val
        with open(self.base_dir / "config.yml", "w") as model_yaml:
            self.yaml.dump(_config, model_yaml)

    def create_remove_dirs(
        self,
        clear_dir: bool = False,
        clear_logs: bool = False,
        clear_weights: bool = False,
        clear_all: bool = False,
    ):
        self.base_dir.mkdir(exist_ok=True, parents=True)
        self.plot_dir.mkdir(exist_ok=True)
        self.prediction_dir.mkdir(exist_ok=True)
        if clear_dir:
            delete = input(f"Delete {self.log_dir} AND {self.models_dir}? (y/[n])")
            if delete == "y":
                shutil.rmtree(self.log_dir, ignore_errors=True)
                shutil.rmtree(self.models_dir, ignore_errors=True)
        if clear_logs:
            delete = input(f"Delete {self.log_dir}? (y/[n])")
            if delete == "y":
                shutil.rmtree(self.log_dir, ignore_errors=True)
        if clear_weights:
            delete = input(f"Delete {self.models_dir}? (y/[n])")
            if delete == "y":
                shutil.rmtree(self.models_dir, ignore_errors=True)
        self.log_dir.mkdir(exist_ok=True)
        self.models_dir.mkdir(exist_ok=True)

    def get_last_model_path(self):
        """Return the path of the latest checkpoint in models_dir.

        Checkpoints are expected to be named "<step>-snapshot.pt"; returns
        None when no ".pt" files are present.
        """
        models = os.listdir(self.models_dir)
        models = [model for model in models if model.endswith(".pt")]
        if len(models) == 0:
            return None
        _max = max(int(m.split(".")[0].split("-")[0]) for m in models)
        model_name = f"{_max}-snapshot.pt"
        last_model_path = os.path.join(self.models_dir, model_name)

        return last_model_path

    def load_model(self, model_path: Optional[str] = None):
        """Build the configured model and load a checkpoint into it.

        Args:
            model_path (str, optional): path to a saved checkpoint; when None,
                the latest checkpoint in models_dir is used, and a freshly
                initialized model is returned at global step 1 if none exists.

        Returns:
            A (model, global_step) tuple.
        """

        model = self.get_model()

        with open(self.base_dir / f"{self.model_kind}_network.txt", "w") as file:
            file.write(str(model))

        if model_path is None:
            last_model_path = self.get_last_model_path()
            if last_model_path is None:
                return model, 1
        else:
            last_model_path = model_path

        saved_model = torch.load(last_model_path)
        out = model.load_state_dict(saved_model["model_state_dict"])
        print(out)
        global_step = saved_model["global_step"] + 1
        return model, global_step
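
    # Typical resume-from-checkpoint flow (a sketch; the config and checkpoint
    # paths below are hypothetical):
    #
    #   manager = ConfigManager("config/cbhg-config.yml", "cbhg")
    #   model, step = manager.load_model()  # latest snapshot, or fresh at step 1
    #   model, step = manager.load_model("log_dir/run/models/2000-snapshot.pt")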

    def get_model(self, ignore_hash=False):
        if not ignore_hash:
            self._check_hash()
        if self.model_kind == "cbhg":
            return self.get_cbhg()
        elif self.model_kind == "seq2seq":
            return self.get_seq2seq()
        elif self.model_kind == "tacotron_based":
            return self.get_tacotron_based()
        elif self.model_kind == "baseline":
            return self.get_baseline()
        elif self.model_kind == "gpt":
            return self.get_gpt()

    def get_gpt(self):
        model = GPTModel(
            self.config["base_model_path"],
            freeze=self.config["freeze"],
            n_layer=self.config["n_layer"],
            use_lstm=self.config["use_lstm"],
        )
        return model

    def get_baseline(self):
        model = BaseLineModel(
            embedding_dim=self.config["embedding_dim"],
            inp_vocab_size=self.config["len_input_symbols"],
            targ_vocab_size=self.config["len_target_symbols"],
            layers_units=self.config["layers_units"],
            use_batch_norm=self.config["use_batch_norm"],
        )

        return model

    def get_cbhg(self):
        model = CBHGModel(
            embedding_dim=self.config["embedding_dim"],
            inp_vocab_size=self.config["len_input_symbols"],
            targ_vocab_size=self.config["len_target_symbols"],
            use_prenet=self.config["use_prenet"],
            prenet_sizes=self.config["prenet_sizes"],
            cbhg_gru_units=self.config["cbhg_gru_units"],
            cbhg_filters=self.config["cbhg_filters"],
            cbhg_projections=self.config["cbhg_projections"],
            post_cbhg_layers_units=self.config["post_cbhg_layers_units"],
            post_cbhg_use_batch_norm=self.config["post_cbhg_use_batch_norm"],
        )

        return model

    def get_seq2seq(self):
        encoder = Seq2SeqEncoder(
            embedding_dim=self.config["encoder_embedding_dim"],
            inp_vocab_size=self.config["len_input_symbols"],
            layers_units=self.config["encoder_units"],
            use_batch_norm=self.config["use_batch_norm"],
        )

        # The seq2seq variant pairs the seq2seq encoder with the Tacotron
        # attention decoder and the Tacotron wrapper module.
        decoder = TacotronDecoder(
            self.config["len_target_symbols"],
            start_symbol_id=self.text_encoder.start_symbol_id,
            embedding_dim=self.config["decoder_embedding_dim"],
            encoder_dim=self.config["encoder_dim"],
            decoder_units=self.config["decoder_units"],
            decoder_layers=self.config["decoder_layers"],
            attention_type=self.config["attention_type"],
            attention_units=self.config["attention_units"],
            is_attention_accumulative=self.config["is_attention_accumulative"],
            use_prenet=self.config["use_decoder_prenet"],
            prenet_depth=self.config["decoder_prenet_depth"],
            teacher_forcing_probability=self.config["teacher_forcing_probability"],
        )

        model = Tacotron(encoder=encoder, decoder=decoder)

        return model

    def get_tacotron_based(self):
        encoder = TacotronEncoder(
            embedding_dim=self.config["encoder_embedding_dim"],
            inp_vocab_size=self.config["len_input_symbols"],
            prenet_sizes=self.config["prenet_sizes"],
            use_prenet=self.config["use_encoder_prenet"],
            cbhg_gru_units=self.config["cbhg_gru_units"],
            cbhg_filters=self.config["cbhg_filters"],
            cbhg_projections=self.config["cbhg_projections"],
        )

        decoder = TacotronDecoder(
            self.config["len_target_symbols"],
            start_symbol_id=self.text_encoder.start_symbol_id,
            embedding_dim=self.config["decoder_embedding_dim"],
            encoder_dim=self.config["encoder_dim"],
            decoder_units=self.config["decoder_units"],
            decoder_layers=self.config["decoder_layers"],
            attention_type=self.config["attention_type"],
            attention_units=self.config["attention_units"],
            is_attention_accumulative=self.config["is_attention_accumulative"],
            use_prenet=self.config["use_decoder_prenet"],
            prenet_depth=self.config["decoder_prenet_depth"],
            teacher_forcing_probability=self.config["teacher_forcing_probability"],
        )

        model = Tacotron(encoder=encoder, decoder=decoder)

        return model

    def get_text_encoder(self):
        """Build the TextEncoder subclass named in the config."""
        if self.config["text_cleaner"] not in [
            "basic_cleaners",
            "valid_arabic_cleaners",
            None,
        ]:
            raise ValueError(f"Unknown text cleaner: {self.config['text_cleaner']}")

        if self.config["text_encoder"] == "BasicArabicEncoder":
            text_encoder = BasicArabicEncoder(
                cleaner_fn=self.config["text_cleaner"], sp_model_path=self.sp_model_path
            )
        elif self.config["text_encoder"] == "ArabicEncoderWithStartSymbol":
            text_encoder = ArabicEncoderWithStartSymbol(
                cleaner_fn=self.config["text_cleaner"], sp_model_path=self.sp_model_path
            )
        else:
            raise ValueError(
                f"Unknown text encoder: {self.config['text_encoder']}"
            )

        return text_encoder

    def get_loss_type(self):
        try:
            loss_type = LossType[self.config["loss_type"]]
        except KeyError as error:
            raise ValueError(
                f"Unknown loss type: {self.config['loss_type']}"
            ) from error
        return loss_type


if __name__ == "__main__":
    config_path = "config/tacotron-base-config.yml"
    model_kind = "tacotron_based"
    config = ConfigManager(config_path=config_path, model_kind=model_kind)
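    # Possible next steps after constructing the manager (a sketch; the actual
    # training entry point lives elsewhere in the project):
    #   config.create_remove_dirs()
    #   config.print_config()
    #   model, global_step = config.load_model()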