|
|
|
import argparse
|
|
from collections import OrderedDict
|
|
import torch
|
|
|
|
|
|
def convert(src: str, dst: str, prefix: str = "backbone") -> None:
    """Re-save a checkpoint with every parameter key prefixed by ``prefix``.

    The object loaded from ``src`` is unwrapped as follows:

    * If it has a ``"model"`` entry, that entry is used. The result is saved
      re-wrapped as ``{"model": ...}``, and any other top-level entries are
      not carried over.
    * If the (possibly unwrapped) object has a ``"state_dict"`` entry, that
      entry is used. The result is saved without the ``"state_dict"``
      wrapper.

    Args:
        src: Path to the source checkpoint file.
        dst: Path the converted checkpoint is written to.
        prefix: Namespace joined to each key with ``"."``. Defaults to
            ``"backbone"``, e.g. ``"conv1.weight"`` ->
            ``"backbone.conv1.weight"``.
    """
    # Load onto CPU: only the keys are rewritten, so there is no need to
    # restore tensors to the device they were saved from. That device (e.g. a
    # GPU) may not exist on the machine running the conversion.
    checkpoint = torch.load(src, map_location="cpu")

    has_model = "model" in checkpoint
    if has_model:
        checkpoint = checkpoint["model"]
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    # Same container type and key order as the input mapping, keys prefixed.
    out_cp = OrderedDict((f"{prefix}.{k}", v) for k, v in checkpoint.items())

    torch.save({"model": out_cp} if has_model else out_cp, dst)
|
|
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: both paths are mandatory; hand them to convert().
    cli = argparse.ArgumentParser()
    for long_flag, short_flag, role in (
        ("--src", "-s", "src"),
        ("--dst", "-d", "dst"),
    ):
        cli.add_argument(
            long_flag,
            short_flag,
            type=str,
            required=True,
            help=f"Path to {role} weights.pth",
        )
    parsed = cli.parse_args()
    convert(parsed.src, parsed.dst)
|
|
|