Spaces:
Sleeping
Sleeping
File size: 594 Bytes
854728f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
def load_matched_state_dict(model, state_dict, print_stats=True):
    """
    Only loads weights that matched in key and shape. Ignore other weights.

    For every entry of ``model.state_dict()``, the value from ``state_dict``
    is copied in when ``state_dict`` has the same key and both values expose
    an equal ``.shape``. Every other entry keeps the model's current value.
    Keys that exist only in ``state_dict`` are ignored.

    Args:
        model: Object providing ``state_dict()`` and ``load_state_dict()``
            (presumably a ``torch.nn.Module`` -- confirm against callers).
        state_dict: Mapping from parameter/buffer names to candidate values.
        print_stats: If true, print how many of the model's entries were
            taken from ``state_dict``.

    Returns:
        Tuple ``(num_matched, num_total)``: the number of entries loaded from
        ``state_dict`` and the total number of entries in the model's own
        state dict.
    """
    curr_state_dict = model.state_dict()
    num_total = len(curr_state_dict)
    num_matched = 0
    for key, curr_value in curr_state_dict.items():
        if key not in state_dict:
            continue
        new_value = state_dict[key]
        # Values without a ``.shape`` (non-tensor entries) cannot be
        # shape-compared; skip them rather than raising AttributeError.
        curr_shape = getattr(curr_value, "shape", None)
        new_shape = getattr(new_value, "shape", None)
        if curr_shape is None or new_shape is None or curr_shape != new_shape:
            continue
        # Overwriting an existing key does not resize the dict, so this is
        # safe while iterating.
        curr_state_dict[key] = new_value
        num_matched += 1
    # curr_state_dict still contains exactly the model's own keys, so the
    # default strict loading cannot fail on missing/unexpected keys.
    model.load_state_dict(curr_state_dict)
    if print_stats:
        print(f'Loaded state_dict: {num_matched}/{num_total} matched')
    return num_matched, num_total