import re

import jax.numpy as jnp
from flax.traverse_util import flatten_dict, unflatten_dict

import torch

from modeling_flax_vqgan import VQModel
from configuration_vqgan import VQGANConfig

regex = r"\w+[.]\d+"


def rename_key(key):
    # Collapse dotted list indices (e.g. "down.0") into underscored names
    # ("down_0") so the PyTorch keys line up with Flax submodule names.
    pats = re.findall(regex, key)
    for pat in pats:
        key = key.replace(pat, "_".join(pat.split(".")))
    return key
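

# Illustrative example (this key is an assumption, not taken from a real
# checkpoint): rename_key("encoder.down.0.block.1.norm1.weight")
# -> "encoder.down_0.block_1.norm1.weight"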


# Adapted from https://github.com/huggingface/transformers/blob/ff5cdc086be1e0c3e2bbad8e3469b34cffb55a85/src/transformers/modeling_flax_pytorch_utils.py#L61
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
    # convert PyTorch tensors to numpy
    pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}

    random_flax_state_dict = flatten_dict(flax_model.params)
    flax_state_dict = {}

    remove_base_model_prefix = (flax_model.base_model_prefix not in flax_model.params) and (
        flax_model.base_model_prefix in {k.split(".")[0] for k in pt_state_dict.keys()}
    )
    add_base_model_prefix = (flax_model.base_model_prefix in flax_model.params) and (
        flax_model.base_model_prefix not in {k.split(".")[0] for k in pt_state_dict.keys()}
    )

    # Some parameter names need to change to match the Flax names so that we don't have to fork any layer
    for pt_key, pt_tensor in pt_state_dict.items():
        pt_tuple_key = tuple(pt_key.split("."))

        has_base_model_prefix = pt_tuple_key[0] == flax_model.base_model_prefix
        require_base_model_prefix = (flax_model.base_model_prefix,) + pt_tuple_key in random_flax_state_dict

        if remove_base_model_prefix and has_base_model_prefix:
            pt_tuple_key = pt_tuple_key[1:]
        elif add_base_model_prefix and require_base_model_prefix:
            pt_tuple_key = (flax_model.base_model_prefix,) + pt_tuple_key

        # Correctly rename weight parameters
        if (
            "norm" in pt_key
            and (pt_tuple_key[-1] == "bias")
            and (pt_tuple_key[:-1] + ("bias",) in random_flax_state_dict)
        ):
            # norm bias keeps its name; reshape (C,) -> (1, 1, 1, C) to match
            # the Flax parameter shape
            pt_tensor = pt_tensor[None, None, None, :]
        elif (
            "norm" in pt_key
            and (pt_tuple_key[-1] == "bias")
            and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
        ):
            # same reshape when the Flax side only has a "scale" parameter at this path
            pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
            pt_tensor = pt_tensor[None, None, None, :]
        elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
            # norm weight -> Flax "scale", with the same broadcast reshape
            pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
            pt_tensor = pt_tensor[None, None, None, :]

        if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
            # embedding layer
            pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
        elif pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4 and pt_tuple_key not in random_flax_state_dict:
            # conv layer: PyTorch stores kernels as (out_ch, in_ch, kh, kw),
            # Flax as (kh, kw, in_ch, out_ch)
            pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
            pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
        elif pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict:
            # linear layer: PyTorch stores (out_features, in_features), Flax the transpose
            pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
            pt_tensor = pt_tensor.T
        elif pt_tuple_key[-1] == "gamma":
            pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
        elif pt_tuple_key[-1] == "beta":
            pt_tuple_key = pt_tuple_key[:-1] + ("bias",)

        if pt_tuple_key in random_flax_state_dict:
            if pt_tensor.shape != random_flax_state_dict[pt_tuple_key].shape:
                raise ValueError(
                    f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
                    f"{random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}."
                )

        # also add unexpected weights so that a warning is thrown when loading
        flax_state_dict[pt_tuple_key] = jnp.asarray(pt_tensor)

    return unflatten_dict(flax_state_dict)
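

# Note (added for clarity; the key below is illustrative): `flatten_dict`
# produces tuple keys such as ("encoder", "conv_in", "kernel"), which is why
# the dotted PyTorch keys are split into tuples above before being matched,
# and why the result is re-nested with `unflatten_dict` at the end.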


def convert_model(config_path, pt_state_dict_path, save_path):
    config = VQGANConfig.from_pretrained(config_path)
    model = VQModel(config)

    state_dict = torch.load(pt_state_dict_path, map_location="cpu")["state_dict"]
    keys = list(state_dict.keys())
    for key in keys:
        # drop the training-loss weights (e.g. the discriminator), which the
        # Flax inference model does not need
        if key.startswith("loss"):
            state_dict.pop(key)
            continue
        renamed_key = rename_key(key)
        state_dict[renamed_key] = state_dict.pop(key)

    # `convert_pytorch_state_dict_to_flax` already returns a nested dict;
    # calling `unflatten_dict` on it a second time would corrupt the keys
    model.params = convert_pytorch_state_dict_to_flax(state_dict, model)
    model.save_pretrained(save_path)
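

# Minimal CLI wrapper (a sketch added here; the flag names are assumptions,
# not part of the original script):
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Convert a PyTorch VQGAN checkpoint to Flax.")
    parser.add_argument("--config_path", required=True, help="path or repo id of the VQGANConfig")
    parser.add_argument("--pt_state_dict_path", required=True, help="path to the PyTorch .ckpt file")
    parser.add_argument("--save_path", required=True, help="output directory for the Flax model")
    args = parser.parse_args()

    convert_model(args.config_path, args.pt_state_dict_path, args.save_path)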