File size: 348 Bytes
a166479
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15

import torch

# Strip a training checkpoint down to a subset of its entries: keep only the
# keys under checkpoint['model'] that contain one of KEEP_KEY_SUBSTRINGS, and
# save them as a plain dict to OUTPUT_PATH (presumably for a Gradio demo,
# judging by the output filename).

CHECKPOINT_PATH = 'model_best_refcoco_0508.pth'
OUTPUT_PATH = 'gradio.pth'
# A key is kept if any of these substrings occurs anywhere in it.
KEEP_KEY_SUBSTRINGS = ('image_model', 'language_model', 'classifier')


def filter_state_dict(state_dict, substrings=KEEP_KEY_SUBSTRINGS):
    """Return a new dict holding only the entries whose key contains a substring.

    Args:
        state_dict: Mapping from entry names to values. Values are copied by
            reference, not cloned.
        substrings: Iterable of substrings. A key is kept if any of them
            occurs in it.

    Returns:
        A new dict containing the selected entries in their original order.
    """
    return {
        key: value
        for key, value in state_dict.items()
        if any(sub in key for sub in substrings)
    }


# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source. map_location='cpu' remaps stored tensors onto the CPU.
model = torch.load(CHECKPOINT_PATH, map_location='cpu')

# The checkpoint is indexed as a dict with a 'model' entry. Presumably that
# entry is the model's state_dict — confirm against the training script.
print(model['model'].keys())

new_dict = filter_state_dict(model['model'])
torch.save(new_dict, OUTPUT_PATH)