import logging

import torch
from transformers.modeling_utils import (
    cached_path,
    WEIGHTS_NAME,
    TF2_WEIGHTS_NAME,
    TF_WEIGHTS_NAME,
)

logger = logging.getLogger(__name__)


def get_checkpoint_from_transformer_cache(
        archive_file, pretrained_model_name_or_path, pretrained_model_archive_map,
        cache_dir, force_download, proxies, resume_download):
    """Download (or resolve from the local cache) a pretrained checkpoint and load it on CPU."""
    try:
        resolved_archive_file = cached_path(
            archive_file, cache_dir=cache_dir, force_download=force_download,
            proxies=proxies, resume_download=resume_download)
    except EnvironmentError:
        if pretrained_model_name_or_path in pretrained_model_archive_map:
            msg = "Couldn't reach server at '{}' to download pretrained weights.".format(
                archive_file)
        else:
            msg = "Model name '{}' was not found in model name list ({}). " \
                  "We assumed '{}' was a path or url to model weight files named one of {} but " \
                  "couldn't find any such file at this path or url.".format(
                      pretrained_model_name_or_path,
                      ', '.join(pretrained_model_archive_map.keys()),
                      archive_file,
                      [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME])
        raise EnvironmentError(msg)

    if resolved_archive_file == archive_file:
        logger.info("loading weights file {}".format(archive_file))
    else:
        logger.info("loading weights file {} from cache at {}".format(
            archive_file, resolved_archive_file))

    return torch.load(resolved_archive_file, map_location='cpu')


def hf_roberta_to_hf_bert(state_dict):
    """Rename RoBERTa state dict keys to the BERT naming scheme."""
    logger.info(" * Convert Huggingface RoBERTa format to Huggingface BERT format * ")
    new_state_dict = {}
    for key in state_dict:
        value = state_dict[key]
        if key == 'roberta.embeddings.position_embeddings.weight':
            # RoBERTa reserves the first two position indices, so drop them.
            value = value[2:]
        if key == 'roberta.embeddings.token_type_embeddings.weight':
            # RoBERTa does not use token type embeddings.
            continue
        if key.startswith('roberta'):
            key = 'bert.' + key[8:]
        elif key.startswith('lm_head'):
            if 'layer_norm' in key or 'dense' in key:
                key = 'cls.predictions.transform.' + key[8:]
            else:
                key = 'cls.predictions.' + key[8:]
            key = key.replace('layer_norm', 'LayerNorm')
        new_state_dict[key] = value
    return new_state_dict


def hf_distilbert_to_hf_bert(state_dict):
    """Rename DistilBERT state dict keys to the BERT naming scheme."""
    # NOTE: this conversion is currently identical to hf_roberta_to_hf_bert.
    logger.info(" * Convert Huggingface DistilBERT format to Huggingface BERT format * ")
    new_state_dict = {}
    for key in state_dict:
        value = state_dict[key]
        if key == 'roberta.embeddings.position_embeddings.weight':
            value = value[2:]
        if key == 'roberta.embeddings.token_type_embeddings.weight':
            continue
        if key.startswith('roberta'):
            key = 'bert.' + key[8:]
        elif key.startswith('lm_head'):
            if 'layer_norm' in key or 'dense' in key:
                key = 'cls.predictions.transform.' + key[8:]
            else:
                key = 'cls.predictions.' + key[8:]
            key = key.replace('layer_norm', 'LayerNorm')
        new_state_dict[key] = value
    return new_state_dict


def hf_bert_to_hf_bert(state_dict):
    """Keep BERT keys as-is, but drop the pretrained 'cls' prediction head."""
    new_state_dict = {}
    for key in state_dict:
        value = state_dict[key]
        if key.startswith('cls'):
            # NOTE: all cls states are used for prediction,
            # we predict the index so omit all pretrained states for prediction.
            continue
        new_state_dict[key] = value
    return new_state_dict


def hf_layoutlm_to_hf_bert(state_dict):
    """Rename LayoutLM state dict keys to the BERT naming scheme and drop the 'cls' head."""
    logger.info(" * Convert Huggingface LayoutLM format to Huggingface BERT format * ")
    new_state_dict = {}
    for key in state_dict:
        value = state_dict[key]
        if key.startswith('layoutlm'):
            key = 'bert.' + key[9:]
        elif key.startswith('cls'):
            # NOTE: all cls states are used for prediction,
            # we predict the index so omit all pretrained states for prediction.
            continue
        new_state_dict[key] = value
    return new_state_dict


# Dispatch table: map a model type to the function that converts its
# Huggingface state dict into the BERT-style key layout used here.
state_dict_convert = {
    'bert': hf_bert_to_hf_bert,
    'unilm': hf_bert_to_hf_bert,
    'minilm': hf_bert_to_hf_bert,
    'layoutlm': hf_layoutlm_to_hf_bert,
    'roberta': hf_roberta_to_hf_bert,
    'xlm-roberta': hf_roberta_to_hf_bert,
    'distilbert': hf_distilbert_to_hf_bert,
}
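# Example usage (a minimal sketch; the model name, archive map, and the
# "<url-or-path-to-pytorch_model.bin>" entry below are placeholders for
# illustration, not constants defined in this module):
#
#   archive_map = {'roberta-base': '<url-or-path-to-pytorch_model.bin>'}
#   state_dict = get_checkpoint_from_transformer_cache(
#       archive_file=archive_map['roberta-base'],
#       pretrained_model_name_or_path='roberta-base',
#       pretrained_model_archive_map=archive_map,
#       cache_dir=None, force_download=False, proxies=None, resume_download=False)
#   bert_style_state_dict = state_dict_convert['roberta'](state_dict)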