ViTGaze / scripts /convert_pth.py
yhsong's picture
initial commit
f9561b9 verified
raw
history blame contribute delete
991 Bytes
# Convert official pretrained weights into the format detectron2 (d2) expects,
# i.e. with every parameter key prefixed by "backbone.".
import argparse
from collections import OrderedDict
import torch
def convert(src: str, dst: str, prefix: str = "backbone") -> None:
    """Re-key an official checkpoint so its weights load under ``prefix``.

    The source may be a bare state dict, a dict with a ``"model"`` entry,
    and/or a dict with a ``"state_dict"`` entry. The parameter mapping is
    unwrapped from whichever of those is present, and every key ``k`` is
    renamed to ``f"{prefix}.{k}"``. If the source had a top-level ``"model"``
    entry, the result is re-wrapped as ``{"model": ...}``; otherwise the flat
    mapping is saved.

    Args:
        src: Path of the checkpoint to read.
        dst: Path the converted checkpoint is written to.
        prefix: Module name prepended to every key. Defaults to
            ``"backbone"``. An empty string leaves keys unchanged.
    """
    # Load onto the CPU so checkpoints saved from CUDA tensors can also be
    # converted on machines without a GPU; the tensors are only re-keyed here.
    # NOTE(review): on recent torch releases torch.load defaults to
    # weights_only=True; checkpoints that pickle non-tensor objects may need
    # weights_only=False (only ever for trusted files).
    checkpoint = torch.load(src, map_location="cpu")

    has_model = "model" in checkpoint
    if has_model:
        checkpoint = checkpoint["model"]
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    # Avoid producing ".key" when no prefix is wanted.
    def _rename(key: str) -> str:
        return f"{prefix}.{key}" if prefix else key

    converted = OrderedDict(
        (_rename(key), value) for key, value in checkpoint.items()
    )
    torch.save({"model": converted} if has_model else converted, dst)
if __name__ == "__main__":
    # Command-line entry point: both the input and output paths are required.
    cli = argparse.ArgumentParser()
    for flags, help_text in (
        (("--src", "-s"), "Path to src weights.pth"),
        (("--dst", "-d"), "Path to dst weights.pth"),
    ):
        cli.add_argument(*flags, type=str, required=True, help=help_text)
    opts = cli.parse_args()
    convert(opts.src, opts.dst)