chore: file not needed
Former-commit-id: 01a923196312ced88a2f0ca2010e793c26c84855
dalle_mini/vqgan_jax/convert_pt_model_to_jax.py
DELETED
@@ -1,109 +0,0 @@
-import re
-
-import jax.numpy as jnp
-from flax.traverse_util import flatten_dict, unflatten_dict
-
-import torch
-
-from modeling_flax_vqgan import VQModel
-from configuration_vqgan import VQGANConfig
-
-
-regex = r"\w+[.]\d+"
-
-
-def rename_key(key):
-    pats = re.findall(regex, key)
-    for pat in pats:
-        key = key.replace(pat, "_".join(pat.split(".")))
-    return key
-
-
-# Adapted from https://github.com/huggingface/transformers/blob/ff5cdc086be1e0c3e2bbad8e3469b34cffb55a85/src/transformers/modeling_flax_pytorch_utils.py#L61
-def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
-    # convert pytorch tensor to numpy
-    pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
-
-    random_flax_state_dict = flatten_dict(flax_model.params)
-    flax_state_dict = {}
-
-    remove_base_model_prefix = (flax_model.base_model_prefix not in flax_model.params) and (
-        flax_model.base_model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
-    )
-    add_base_model_prefix = (flax_model.base_model_prefix in flax_model.params) and (
-        flax_model.base_model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
-    )
-
-    # Need to change some parameters name to match Flax names so that we don't have to fork any layer
-    for pt_key, pt_tensor in pt_state_dict.items():
-        pt_tuple_key = tuple(pt_key.split("."))
-
-        has_base_model_prefix = pt_tuple_key[0] == flax_model.base_model_prefix
-        require_base_model_prefix = (flax_model.base_model_prefix,) + pt_tuple_key in random_flax_state_dict
-
-        if remove_base_model_prefix and has_base_model_prefix:
-            pt_tuple_key = pt_tuple_key[1:]
-        elif add_base_model_prefix and require_base_model_prefix:
-            pt_tuple_key = (flax_model.base_model_prefix,) + pt_tuple_key
-
-        # Correctly rename weight parameters
-        if (
-            "norm" in pt_key
-            and (pt_tuple_key[-1] == "bias")
-            and (pt_tuple_key[:-1] + ("bias",) in random_flax_state_dict)
-        ):
-            pt_tensor = pt_tensor[None, None, None, :]
-        elif (
-            "norm" in pt_key
-            and (pt_tuple_key[-1] == "bias")
-            and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
-        ):
-            pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
-            pt_tensor = pt_tensor[None, None, None, :]
-        elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
-            pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
-            pt_tensor = pt_tensor[None, None, None, :]
-        if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
-            pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
-        elif pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4 and pt_tuple_key not in random_flax_state_dict:
-            # conv layer
-            pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
-            pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
-        elif pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict:
-            # linear layer
-            pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
-            pt_tensor = pt_tensor.T
-        elif pt_tuple_key[-1] == "gamma":
-            pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
-        elif pt_tuple_key[-1] == "beta":
-            pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
-
-        if pt_tuple_key in random_flax_state_dict:
-            if pt_tensor.shape != random_flax_state_dict[pt_tuple_key].shape:
-                raise ValueError(
-                    f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
-                    f"{random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}."
-                )
-
-        # also add unexpected weight so that warning is thrown
-        flax_state_dict[pt_tuple_key] = jnp.asarray(pt_tensor)
-
-    return unflatten_dict(flax_state_dict)
-
-
-def convert_model(config_path, pt_state_dict_path, save_path):
-    config = VQGANConfig.from_pretrained(config_path)
-    model = VQModel(config)
-
-    state_dict = torch.load(pt_state_dict_path, map_location="cpu")["state_dict"]
-    keys = list(state_dict.keys())
-    for key in keys:
-        if key.startswith("loss"):
-            state_dict.pop(key)
-            continue
-        renamed_key = rename_key(key)
-        state_dict[renamed_key] = state_dict.pop(key)
-
-    state = convert_pytorch_state_dict_to_flax(state_dict, model)
-    model.params = unflatten_dict(state)
-    model.save_pretrained(save_path)
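Note: the deleted rename_key helper collapses dotted, numbered PyTorch module names (down.0, block.1) into the underscored form the Flax modules use. A minimal sketch of its behavior, using a hypothetical checkpoint key for illustration:

    import re

    regex = r"\w+[.]\d+"

    def rename_key(key):
        pats = re.findall(regex, key)
        for pat in pats:
            key = key.replace(pat, "_".join(pat.split(".")))
        return key

    # Hypothetical VQGAN checkpoint key, shown only to illustrate the renaming.
    print(rename_key("encoder.down.0.block.1.norm1.weight"))
    # -> encoder.down_0.block_1.norm1.weight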
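The layout-sensitive step in convert_pytorch_state_dict_to_flax is the convolution kernel transpose: PyTorch stores conv weights as (out_channels, in_channels, kH, kW), while Flax convolutions expect (kH, kW, in_channels, out_channels), which is exactly what transpose(2, 3, 1, 0) produces. A standalone sketch with made-up shapes:

    import numpy as np

    # PyTorch conv weight layout: (out_channels, in_channels, kH, kW)
    pt_kernel = np.zeros((64, 3, 4, 4))

    # Flax conv kernel layout: (kH, kW, in_channels, out_channels)
    flax_kernel = pt_kernel.transpose(2, 3, 1, 0)
    assert flax_kernel.shape == (4, 4, 3, 64)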
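Before its removal, the script's entry point was convert_model, which loads the VQGAN config, drops the loss.* weights from the Lightning checkpoint, renames the remaining keys, converts them to Flax, and saves the result. A usage sketch; all paths below are hypothetical:

    from convert_pt_model_to_jax import convert_model

    convert_model(
        config_path="./config.json",        # VQGANConfig JSON (hypothetical path)
        pt_state_dict_path="./vqgan.ckpt",  # PyTorch checkpoint with a "state_dict" entry (hypothetical path)
        save_path="./vqgan-jax",            # output directory for the Flax weights (hypothetical path)
    )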