"""Finetuning methods.""" import logging import os import torch from collections import OrderedDict from espnet.asr.asr_utils import get_model_conf from espnet.asr.asr_utils import torch_load from espnet.nets.asr_interface import ASRInterface from espnet.nets.mt_interface import MTInterface from espnet.nets.pytorch_backend.transducer.utils import custom_torch_load from espnet.nets.tts_interface import TTSInterface from espnet.utils.dynamic_import import dynamic_import def freeze_modules(model, modules): """Freeze model parameters according to modules list. Args: model (torch.nn.Module): main model to update modules (list): specified module list for freezing Return: model (torch.nn.Module): updated model model_params (filter): filtered model parameters """ for mod, param in model.named_parameters(): if any(mod.startswith(m) for m in modules): logging.info(f"freezing {mod}, it will not be updated.") param.requires_grad = False model_params = filter(lambda x: x.requires_grad, model.parameters()) return model, model_params def transfer_verification(model_state_dict, partial_state_dict, modules): """Verify tuples (key, shape) for input model modules match specified modules. Args: model_state_dict (OrderedDict): the initial model state_dict partial_state_dict (OrderedDict): the trained model state_dict modules (list): specified module list for transfer Return: (boolean): allow transfer """ modules_model = [] partial_modules = [] for key_p, value_p in partial_state_dict.items(): if any(key_p.startswith(m) for m in modules): partial_modules += [(key_p, value_p.shape)] for key_m, value_m in model_state_dict.items(): if any(key_m.startswith(m) for m in modules): modules_model += [(key_m, value_m.shape)] len_match = len(modules_model) == len(partial_modules) module_match = sorted(modules_model, key=lambda x: (x[0], x[1])) == sorted( partial_modules, key=lambda x: (x[0], x[1]) ) return len_match and module_match def get_partial_state_dict(model_state_dict, modules): """Create state_dict with specified modules matching input model modules. Note that get_partial_lm_state_dict is used if a LM specified. Args: model_state_dict (OrderedDict): trained model state_dict modules (list): specified module list for transfer Return: new_state_dict (OrderedDict): the updated state_dict """ new_state_dict = OrderedDict() for key, value in model_state_dict.items(): if any(key.startswith(m) for m in modules): new_state_dict[key] = value return new_state_dict def get_lm_state_dict(lm_state_dict): """Create compatible ASR decoder state dict from LM state dict. Args: lm_state_dict (OrderedDict): pre-trained LM state_dict Return: new_state_dict (OrderedDict): LM state_dict with updated keys """ new_state_dict = OrderedDict() for key, value in list(lm_state_dict.items()): if key == "predictor.embed.weight": new_state_dict["dec.embed.weight"] = value elif key.startswith("predictor.rnn."): _split = key.split(".") new_key = "dec.decoder." + _split[2] + "." + _split[3] + "_l0" new_state_dict[new_key] = value return new_state_dict def filter_modules(model_state_dict, modules): """Filter non-matched modules in module_state_dict. 


def filter_modules(model_state_dict, modules):
    """Filter out modules that do not match the trained model's state dict.

    Args:
        model_state_dict (OrderedDict): trained model state_dict
        modules (list): specified module list for transfer

    Return:
        new_mods (list): the updated module list

    """
    new_mods = []
    incorrect_mods = []

    mods_model = list(model_state_dict.keys())
    for mod in modules:
        if any(key.startswith(mod) for key in mods_model):
            new_mods += [mod]
        else:
            incorrect_mods += [mod]

    if incorrect_mods:
        logging.warning(
            "module(s) %s don't match or only partially match "
            "the available modules in the model.",
            incorrect_mods,
        )
        logging.warning("for information, the existing modules in model are:")
        logging.warning("%s", mods_model)

    return new_mods


def load_trained_model(model_path, training=True):
    """Load the trained model for recognition.

    Args:
        model_path (str): Path to model.***.best
        training (bool): Training mode flag, forwarded to the transducer
            model constructor and loader.

    """
    idim, odim, train_args = get_model_conf(
        model_path, os.path.join(os.path.dirname(model_path), "model.json")
    )

    logging.warning("reading model parameters from " + model_path)

    if hasattr(train_args, "model_module"):
        model_module = train_args.model_module
    else:
        model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"
    # CTC loss is not needed at recognition time; default to the built-in
    # implementation to prevent import errors.
    if hasattr(train_args, "ctc_type"):
        train_args.ctc_type = "builtin"

    model_class = dynamic_import(model_module)

    if "transducer" in model_module:
        model = model_class(idim, odim, train_args, training=training)
        custom_torch_load(model_path, model, training=training)
    else:
        model = model_class(idim, odim, train_args)
        torch_load(model_path, model)

    return model, train_args


def get_trained_model_state_dict(model_path):
    """Extract the trained model state dict for pre-initialization.

    Args:
        model_path (str): Path to model.***.best

    Return:
        model.state_dict() (OrderedDict): the loaded model state_dict

    """
    conf_path = os.path.join(os.path.dirname(model_path), "model.json")
    if "rnnlm" in model_path:
        logging.warning("reading model parameters from %s", model_path)

        return get_lm_state_dict(torch.load(model_path))

    idim, odim, args = get_model_conf(model_path, conf_path)

    logging.warning("reading model parameters from " + model_path)

    if hasattr(args, "model_module"):
        model_module = args.model_module
    else:
        model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"

    model_class = dynamic_import(model_module)
    model = model_class(idim, odim, args)
    torch_load(model_path, model)
    assert (
        isinstance(model, MTInterface)
        or isinstance(model, ASRInterface)
        or isinstance(model, TTSInterface)
    )

    return model.state_dict()
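

# Illustrative usage sketch (hypothetical checkpoint path and module prefix, not
# part of the original module): load a trained snapshot, freeze its encoder
# parameters, and pass only the remaining trainable parameters to the optimizer
# when fine-tuning.
#
#   model, train_args = load_trained_model("exp/train/results/model.acc.best")
#   model, model_params = freeze_modules(model, ["enc."])
#   optimizer = torch.optim.Adam(model_params, lr=1e-4)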
""" def print_new_keys(state_dict, modules, model_path): logging.warning("loading %s from model: %s", modules, model_path) for k in state_dict.keys(): logging.warning("override %s" % k) enc_model_path = args.enc_init dec_model_path = args.dec_init enc_modules = args.enc_init_mods dec_modules = args.dec_init_mods model_class = dynamic_import(args.model_module) main_model = model_class(idim, odim, args) assert isinstance(main_model, interface) main_state_dict = main_model.state_dict() logging.warning("model(s) found for pre-initialization") for model_path, modules in [ (enc_model_path, enc_modules), (dec_model_path, dec_modules), ]: if model_path is not None: if os.path.isfile(model_path): model_state_dict = get_trained_model_state_dict(model_path) modules = filter_modules(model_state_dict, modules) partial_state_dict = get_partial_state_dict(model_state_dict, modules) if partial_state_dict: if transfer_verification( main_state_dict, partial_state_dict, modules ): print_new_keys(partial_state_dict, modules, model_path) main_state_dict.update(partial_state_dict) else: logging.warning( f"modules {modules} in model {model_path} " f"don't match your training config", ) else: logging.warning("model was not found : %s", model_path) main_model.load_state_dict(main_state_dict) return main_model