|
from mmengine.runner.checkpoint import CheckpointLoader |
|
|
|
|
|
def load_checkpoint_with_prefix(filename, prefix=None, map_location='cpu', logger='current'):
    """Load a checkpoint and keep only the parameters under ``prefix``.

    The checkpoint is loaded with ``CheckpointLoader.load_checkpoint``. If the
    loaded object contains a ``'state_dict'`` entry, that entry is used;
    otherwise the loaded object itself is treated as the state dict. When
    ``prefix`` is given, only keys starting with ``prefix`` (a trailing ``'.'``
    is appended if missing) are kept, and the prefix is stripped from them.

    Args:
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        prefix (str, optional): The prefix of the sub-module, e.g.
            ``'backbone'`` or ``'backbone.'``. If ``None`` or empty, the
            state dict is returned unfiltered. Defaults to None.
        map_location (str | None): Same as :func:`torch.load`.
            Defaults to ``'cpu'``.
        logger: Logger forwarded to ``CheckpointLoader.load_checkpoint``.
            Defaults to ``'current'``.

    Returns:
        dict or OrderedDict: The loaded state dict. When ``prefix`` is given,
        a new ``dict`` holding only the matching entries with the prefix
        removed from their keys.

    Raises:
        AssertionError: If ``prefix`` is given but no key starts with it.
    """
    checkpoint = CheckpointLoader.load_checkpoint(
        filename, map_location=map_location, logger=logger)

    # Some checkpoints nest the weights under 'state_dict'; otherwise the
    # loaded object is used as the state dict directly.
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    if not prefix:
        return state_dict

    # Normalise to end with '.' so that e.g. 'backbone' does not also match
    # keys of a sibling module such as 'backbone2.*'.
    if not prefix.endswith('.'):
        prefix += '.'
    prefix_len = len(prefix)

    state_dict = {
        k[prefix_len:]: v
        for k, v in state_dict.items() if k.startswith(prefix)
    }

    # Explicit raise instead of ``assert`` so the check is not stripped under
    # ``python -O`` (which would silently return an empty dict). The
    # AssertionError type is kept for callers that already catch it.
    if not state_dict:
        raise AssertionError(f'{prefix} is not in the pretrained model')

    return state_dict
|
|