RegionSpot / tools /re_save_ckpt.py
bklg's picture
Upload 37 files
dccc960
raw
history blame
869 Bytes
import argparse

import torch
# Strip selected entries from a saved checkpoint and write a slimmed copy.
#
# When run as a script, this loads a checkpoint and drops the top-level
# 'trainer' and 'iteration' entries. It then deletes every entry of
# checkpoint['model'] whose key starts with 'model.clip_model' or 'model.sam',
# and saves the result to a new file. Both paths default to the original
# hard-coded locations and can be overridden on the command line.

DEFAULT_SRC_CKPT = './pretrained_model/model_final.pth'
DEFAULT_DST_CKPT = './pretrained_model/model_final_modified.pth'

# Top-level checkpoint entries to drop (missing ones are ignored).
TOP_LEVEL_KEYS_TO_REMOVE = ('trainer', 'iteration')

# Prefixes of checkpoint['model'] keys to drop. The match is on the full
# prefix, including 'model.'. A key such as 'clip_model.x' is therefore kept.
MODEL_KEY_PREFIXES_TO_REMOVE = ('model.clip_model', 'model.sam')


def strip_checkpoint(checkpoint,
                     top_level_keys=TOP_LEVEL_KEYS_TO_REMOVE,
                     model_key_prefixes=MODEL_KEY_PREFIXES_TO_REMOVE):
    """Remove unwanted entries from *checkpoint* in place.

    Args:
        checkpoint: Mapping loaded from a checkpoint file. It must contain a
            'model' entry that is itself a mutable mapping with str keys.
        top_level_keys: Top-level keys to delete. Keys that are absent are
            skipped silently.
        model_key_prefixes: Any key of ``checkpoint['model']`` starting with
            one of these prefixes is deleted.

    Returns:
        The list of removed ``checkpoint['model']`` keys, in iteration order.

    Raises:
        KeyError: if *checkpoint* has no 'model' entry.
    """
    for key in top_level_keys:
        checkpoint.pop(key, None)

    try:
        model_state = checkpoint['model']
    except KeyError:
        raise KeyError("checkpoint has no 'model' entry") from None

    # str.startswith accepts a tuple, so all prefixes are checked in one call.
    prefixes = tuple(model_key_prefixes)
    # Collect the keys first: deleting while iterating a dict is an error.
    removed = [key for key in model_state if key.startswith(prefixes)]
    for key in removed:
        del model_state[key]
    return removed


def main(argv=None):
    """Load a checkpoint, strip it, save it, and print what was kept.

    Args:
        argv: Optional argument list (defaults to ``sys.argv[1:]``). The
            optional positionals are SRC and DST. Omitting them uses the
            original hard-coded paths.
    """
    parser = argparse.ArgumentParser(
        description="Drop trainer state and CLIP/SAM weights from a checkpoint.")
    parser.add_argument('src', nargs='?', default=DEFAULT_SRC_CKPT,
                        help='input checkpoint (default: %(default)s)')
    parser.add_argument('dst', nargs='?', default=DEFAULT_DST_CKPT,
                        help='output checkpoint (default: %(default)s)')
    args = parser.parse_args(argv)

    # NOTE(review): torch.load unpickles the file, so only run this on
    # checkpoints you trust. On newer torch releases, weights_only may default
    # to True; if loading fails because the file contains non-tensor objects,
    # that is likely the cause. TODO confirm against the installed torch version.
    checkpoint = torch.load(args.src, map_location='cpu')

    for key in strip_checkpoint(checkpoint):
        print(key)

    torch.save(checkpoint, args.dst)
    print(checkpoint['model'].keys())


if __name__ == '__main__':
    main()