Spaces:
Build error
Build error
File size: 549 Bytes
1ed7deb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
import torch
import sys
def extract_submodel(state_dict, submodel="cond_stage_model"):
    """Return the entries of *state_dict* that belong to *submodel*.

    Args:
        state_dict: Mapping from dotted parameter names to values.
        submodel: Dotted name of the submodule to extract. Only keys that
            start with ``submodel + "."`` are kept.

    Returns:
        A new dict holding the matching entries, with the leading
        ``submodel + "."`` prefix removed from every key.
    """
    # Match on "name." rather than "name" so that a sibling whose name merely
    # shares a prefix (e.g. "cond_stage_model_ema.") is not picked up.
    prefix = submodel + "."
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


if __name__ == "__main__":
    if len(sys.argv) < 3:
        sys.exit("Usage: {} INPATH OUTPATH [SUBMODEL]".format(sys.argv[0]))
    inpath = sys.argv[1]
    outpath = sys.argv[2]
    # The submodel name is optional and defaults to "cond_stage_model".
    submodel = sys.argv[3] if len(sys.argv) > 3 else "cond_stage_model"
    print("Extracting {} from {} to {}.".format(submodel, inpath, outpath))
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code; only run this on checkpoints from a trusted source.
    sd = torch.load(inpath, map_location="cpu")
    # Honor the requested submodel instead of a hard-coded name.
    new_sd = {"state_dict": extract_submodel(sd["state_dict"], submodel)}
    torch.save(new_sd, outpath)
|