# Spaces: Sleeping
import argparse
import os

import torch
from torch import Tensor
from safetensors.torch import save_file, load_file
# Command-line interface. The arguments are parsed at import time and the
# resulting namespace is stored in the module-level ``cmds``.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-f", "--file",
    type=str,
    default="model.ckpt",
    help="path to model",
)
parser.add_argument(
    "-p", "--precision",
    default="fp32",
    help="precision fp32(full)/fp16/bf16",
)
parser.add_argument(
    "-t", "--type",
    type=str,
    default="full",
    help="convert types full/ema-only/no-ema",
)
parser.add_argument(
    "-st", "--safe-tensors",
    action="store_true",
    default=False,
    help="use safetensors model format",
)
cmds = parser.parse_args()
def conv_fp16(t: Tensor):
    """Return ``t`` cast to float16 when it is a floating-point tensor.

    Non-tensor values are returned unchanged. Integer and boolean tensors
    are also returned unchanged: casting them to half precision would
    change their dtype and can lose integer precision.
    """
    if not isinstance(t, Tensor) or not t.is_floating_point():
        return t
    return t.half()
def conv_bf16(t: Tensor):
    """Return ``t`` cast to bfloat16 when it is a floating-point tensor.

    Non-tensor values are returned unchanged. Integer and boolean tensors
    are also returned unchanged: casting them to bfloat16 would change
    their dtype and can lose integer precision.
    """
    if not isinstance(t, Tensor) or not t.is_floating_point():
        return t
    return t.bfloat16()
def conv_full(t):
    """Return ``t`` unchanged (identity conversion).

    No cast is performed, so a tensor keeps whatever dtype it already has.
    """
    return t
# Maps each accepted precision name to its conversion function.
# "full" and "fp32" are aliases for the identity conversion; "half" and
# "fp16" are aliases for the float16 conversion.
_g_precision_func = {
    "full": conv_full,
    "fp32": conv_full,
    "half": conv_fp16,
    "fp16": conv_fp16,
    "bf16": conv_bf16,
}
def convert(path: str, conv_type: str):
    """Load a checkpoint and return its state dict with converted values.

    Args:
        path: Checkpoint file to read. A ``.safetensors`` suffix selects the
            safetensors loader; any other path goes through ``torch.load``.
        conv_type: ``"ema-only"`` (or ``"prune"``) replaces each ``model.*``
            weight with its EMA counterpart when one exists and drops the
            remaining ``model_ema.*`` entries except ``num_updates`` and
            ``decay``; ``"no-ema"`` drops every key containing
            ``model_ema``; any other value keeps every key.

    Returns:
        A dict mapping state-dict keys to values passed through the
        conversion function selected by ``cmds.precision``.

    Raises:
        KeyError: if ``cmds.precision`` is not a known precision name.
    """
    cast = _g_precision_func[cmds.precision]

    if path.endswith(".safetensors"):
        checkpoint = load_file(path, device="cpu")
    else:
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only use it on checkpoints from a trusted source.
        checkpoint = torch.load(path, map_location="cpu")
    # Some checkpoints wrap the weights in a "state_dict" entry, others are
    # the flat mapping itself.
    state_dict = checkpoint.get("state_dict", checkpoint)

    model_prefix = "model."
    kept_ema_keys = ("model_ema.num_updates", "model_ema.decay")
    converted = {}

    if conv_type in ("ema-only", "prune"):
        for key in state_dict:
            # The EMA copy of "model.<name>" is looked up under
            # "model_ema.<name with dots removed>". Keys outside "model."
            # have no EMA counterpart.
            ema_key = None
            if key.startswith(model_prefix):
                stripped = key[len(model_prefix):].replace(".", "")
                ema_key = "model_ema." + stripped

            if ema_key is not None and ema_key in state_dict:
                converted[key] = cast(state_dict[ema_key])
                print("ema: " + ema_key + " > " + key)
            elif not key.startswith("model_ema.") or key in kept_ema_keys:
                converted[key] = cast(state_dict[key])
                print(key)
            else:
                print("skipped: " + key)
    elif conv_type == "no-ema":
        converted = {
            key: cast(value)
            for key, value in state_dict.items()
            if "model_ema" not in key
        }
    else:
        converted = {key: cast(value) for key, value in state_dict.items()}
    return converted
def main():
    """Convert the checkpoint named on the command line and save the result.

    The output path is the input path without its extension, followed by
    ``-<type>-<precision>`` and a ``.safetensors`` or ``.ckpt`` suffix
    depending on ``--safe-tensors``.
    """
    # splitext removes only a real trailing extension, so an input without
    # one (or with dots only in a directory part, e.g. "./models/sd") keeps
    # its full name instead of collapsing to an empty stem.
    model_name, _ = os.path.splitext(cmds.file)
    converted = convert(cmds.file, cmds.type)
    save_name = f"{model_name}-{cmds.type}-{cmds.precision}"
    print("convert ok, saving model")
    if cmds.safe_tensors:
        save_file(converted, save_name + ".safetensors")
    else:
        torch.save({"state_dict": converted}, save_name + ".ckpt")
    print("convert finish.")


if __name__ == "__main__":
    main()