Spaces:
Sleeping
Sleeping
# Prune a saved checkpoint: drop selected top-level entries and every
# 'model' entry whose name starts with one of the configured prefixes, then
# write the result to a new file.
import torch

# Input checkpoint and destination for the pruned copy.
pretrain_ckpt = './pretrained_model/model_final.pth'
modified_ckpt_path = './pretrained_model/model_final_modified.pth'

# Top-level checkpoint entries that are deleted when present.
top_level_keys_to_remove = ['trainer', 'iteration']
# Entries of checkpoint['model'] whose names start with any of these
# prefixes are deleted.
model_keys_to_remove = ['model.clip_model', 'model.sam']


def strip_checkpoint(checkpoint, top_level_keys=None, model_prefixes=None):
    """Remove unwanted entries from ``checkpoint`` in place.

    Args:
        checkpoint: Mapping that holds a ``'model'`` mapping keyed by
            parameter name.
        top_level_keys: Keys to delete from ``checkpoint`` itself; keys that
            are absent are ignored. Defaults to ``top_level_keys_to_remove``.
        model_prefixes: Name prefixes; every entry of ``checkpoint['model']``
            whose key starts with one of them is deleted. Defaults to
            ``model_keys_to_remove``.

    Returns:
        The list of keys deleted from ``checkpoint['model']``, in the order
        they appeared.

    Raises:
        KeyError: If ``checkpoint`` has no ``'model'`` entry.
    """
    if top_level_keys is None:
        top_level_keys = top_level_keys_to_remove
    if model_prefixes is None:
        model_prefixes = model_keys_to_remove

    for key in top_level_keys:
        checkpoint.pop(key, None)

    if 'model' not in checkpoint:
        raise KeyError("checkpoint has no 'model' entry to prune")
    model_state = checkpoint['model']

    # str.startswith accepts a tuple, so one call checks every prefix. This
    # is a plain prefix match: 'model.sam' also matches e.g. 'model.sample'.
    prefixes = tuple(model_prefixes)
    # Collect first, delete afterwards, so the dict is never resized while
    # it is being iterated.
    removed = [name for name in model_state if name.startswith(prefixes)]
    for name in removed:
        del model_state[name]
    return removed


def main():
    """Load the checkpoint, prune it, and save the pruned copy."""
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only run this on checkpoints from a trusted source.
    checkpoint = torch.load(pretrain_ckpt, map_location='cpu')

    for key in strip_checkpoint(checkpoint):
        print(key)

    # Save the modified checkpoint to a new file; the input is left intact.
    torch.save(checkpoint, modified_ckpt_path)
    print(checkpoint['model'].keys())


if __name__ == '__main__':
    main()