# -- Hugging Face Hub page residue (kept as comments so the file parses) --
# Tzktz's picture
# Upload 7664 files
# 6fc683c verified
# raw | history | blame | 4.61 kB
import torch
import logging
from transformers.modeling_utils import cached_path, WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME
logger = logging.getLogger(__name__)
def get_checkpoint_from_transformer_cache(
    archive_file, pretrained_model_name_or_path, pretrained_model_archive_map,
    cache_dir, force_download, proxies, resume_download,
):
    """Resolve *archive_file* through the transformers download cache and load it.

    Returns the checkpoint (as produced by ``torch.load``) mapped onto CPU.
    Raises ``EnvironmentError`` with a descriptive message when the archive
    cannot be downloaded or located.
    """
    try:
        resolved_archive_file = cached_path(
            archive_file,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
        )
    except EnvironmentError:
        # Distinguish "known model but server unreachable" from
        # "unknown model name / bad path or url".
        if pretrained_model_name_or_path in pretrained_model_archive_map:
            msg = "Couldn't reach server at '{}' to download pretrained weights.".format(
                archive_file)
        else:
            msg = (
                "Model name '{}' was not found in model name list ({}). "
                "We assumed '{}' was a path or url to model weight files named one of {} but "
                "couldn't find any such file at this path or url.".format(
                    pretrained_model_name_or_path,
                    ', '.join(pretrained_model_archive_map.keys()),
                    archive_file,
                    [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME],
                )
            )
        raise EnvironmentError(msg)

    if resolved_archive_file == archive_file:
        logger.info("loading weights file {}".format(archive_file))
    else:
        logger.info("loading weights file {} from cache at {}".format(
            archive_file, resolved_archive_file))
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    return torch.load(resolved_archive_file, map_location='cpu')
def hf_roberta_to_hf_bert(state_dict):
    """Rename Huggingface RoBERTa checkpoint keys to Huggingface BERT naming.

    Transformations applied:
    - position embedding: the first two rows (RoBERTa's padding offset)
      are sliced off;
    - token-type embedding: dropped entirely;
    - ``roberta.*`` -> ``bert.*``;
    - ``lm_head.*`` -> ``cls.predictions.transform.*`` (layer_norm / dense
      entries) or ``cls.predictions.*`` (everything else);
    - ``layer_norm`` -> ``LayerNorm`` in every resulting key.
    """
    logger.info(" * Convert Huggingface RoBERTa format to Huggingface BERT format * ")
    converted = {}
    for name, tensor in state_dict.items():
        if name == 'roberta.embeddings.token_type_embeddings.weight':
            # BERT's own token-type table is used instead.
            continue
        if name == 'roberta.embeddings.position_embeddings.weight':
            # Strip RoBERTa's two padding-offset rows.
            tensor = tensor[2:]
        if name.startswith('roberta'):
            name = 'bert.' + name[len('roberta.'):]
        elif name.startswith('lm_head'):
            suffix = name[len('lm_head.'):]
            if 'layer_norm' in name or 'dense' in name:
                name = 'cls.predictions.transform.' + suffix
            else:
                name = 'cls.predictions.' + suffix
        converted[name.replace('layer_norm', 'LayerNorm')] = tensor
    return converted
def hf_distilbert_to_hf_bert(state_dict):
    """Rename a "DistilBERT" checkpoint to Huggingface BERT naming.

    NOTE(review): this body is a byte-for-byte copy of
    ``hf_roberta_to_hf_bert`` — it only rewrites ``roberta.*`` and
    ``lm_head.*`` prefixes, so genuine DistilBERT keys (``distilbert.*``,
    ``vocab_*``) pass through unchanged apart from the
    ``layer_norm`` -> ``LayerNorm`` rename. Confirm this is intended for
    the checkpoints it is applied to.
    """
    logger.info(" * Convert Huggingface DistilBERT format to Huggingface BERT format * ")
    converted = {}
    for name in state_dict:
        tensor = state_dict[name]
        if name == 'roberta.embeddings.token_type_embeddings.weight':
            continue
        if name == 'roberta.embeddings.position_embeddings.weight':
            tensor = tensor[2:]
        if name.startswith('roberta'):
            new_name = 'bert.' + name[8:]
        elif name.startswith('lm_head'):
            head = ('cls.predictions.transform.'
                    if ('layer_norm' in name or 'dense' in name)
                    else 'cls.predictions.')
            new_name = head + name[8:]
        else:
            new_name = name
        converted[new_name.replace('layer_norm', 'LayerNorm')] = tensor
    return converted
def hf_bert_to_hf_bert(state_dict):
# NOTE: all cls states are used for prediction,
# we predict the index so omit all pretrained states for prediction.
new_state_dict = {}
for key in state_dict:
value = state_dict[key]
if key.startswith('cls'):
# NOTE: all cls states are used for prediction,
# we predict the index so omit all pretrained states for prediction.
continue
new_state_dict[key] = value
return new_state_dict
def hf_layoutlm_to_hf_bert(state_dict):
    """Rename Huggingface LayoutLM checkpoint keys to Huggingface BERT naming.

    ``layoutlm.*`` keys become ``bert.*``; every ``cls*`` entry is dropped
    (all cls states are used for prediction — we predict the index, so the
    pretrained prediction-head states are omitted). Other keys pass through.
    """
    logger.info(" * Convert Huggingface LayoutLM format to Huggingface BERT format * ")
    converted = {}
    for name, tensor in state_dict.items():
        if name.startswith('cls'):
            continue
        if name.startswith('layoutlm'):
            name = 'bert.' + name[len('layoutlm.'):]
        converted[name] = tensor
    return converted
# Dispatch table: model-type string -> converter that rewrites that model's
# Huggingface state dict into Huggingface BERT key naming.
state_dict_convert = {
    'bert': hf_bert_to_hf_bert,
    'unilm': hf_bert_to_hf_bert,  # UniLM/MiniLM checkpoints already use BERT keys
    'minilm': hf_bert_to_hf_bert,
    'layoutlm': hf_layoutlm_to_hf_bert,
    'roberta': hf_roberta_to_hf_bert,
    'xlm-roberta': hf_roberta_to_hf_bert,  # same key layout as RoBERTa
    'distilbert': hf_distilbert_to_hf_bert,
}