import argparse

import huggingface_hub
import k_diffusion as K
import torch
from diffusers import UNet2DConditionModel


UPSCALER_REPO = "pcuenq/k-upscaler"
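
# This script converts the k-diffusion latent upscaler checkpoint hosted at UPSCALER_REPO into a
# diffusers UNet2DConditionModel and saves it with save_pretrained. Example invocation (the script
# filename and output directory are illustrative, not mandated by the code):
#   python convert_k_upscaler_to_diffusers.py --dump_path ./k-upscaler-unet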


def resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
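    """Map one k-diffusion residual block onto the corresponding diffusers resnet.

    The original block is a positional sequential container, so its checkpoint keys are indexed:
    `main.0.mapper`/`main.4.mapper` feed diffusers' `norm1.linear`/`norm2.linear`, `main.2`/`main.6`
    become `conv1`/`conv2`, and the optional `skip` projection becomes `conv_shortcut`.
    """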
    rv = {
        # norm1
        f"{diffusers_resnet_prefix}.norm1.linear.weight": checkpoint[f"{resnet_prefix}.main.0.mapper.weight"],
        f"{diffusers_resnet_prefix}.norm1.linear.bias": checkpoint[f"{resnet_prefix}.main.0.mapper.bias"],
        # conv1
        f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.main.2.weight"],
        f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.main.2.bias"],
        # norm2
        f"{diffusers_resnet_prefix}.norm2.linear.weight": checkpoint[f"{resnet_prefix}.main.4.mapper.weight"],
        f"{diffusers_resnet_prefix}.norm2.linear.bias": checkpoint[f"{resnet_prefix}.main.4.mapper.bias"],
        # conv2
        f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.main.6.weight"],
        f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.main.6.bias"],
    }

    if resnet.conv_shortcut is not None:
        rv.update(
            {
                f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.skip.weight"],
            }
        )

    return rv


def self_attn_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix):
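    """Map a k-diffusion self-attention layer onto diffusers' `attn1`.

    The original model stores a fused `qkv_proj` projection (convolutional, hence the trailing
    singleton dimensions); it is chunked into separate q/k/v tensors and squeezed so the weights
    fit diffusers' linear attention projections.
    """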
    weight_q, weight_k, weight_v = checkpoint[f"{attention_prefix}.qkv_proj.weight"].chunk(3, dim=0)
    bias_q, bias_k, bias_v = checkpoint[f"{attention_prefix}.qkv_proj.bias"].chunk(3, dim=0)

    rv = {
        # norm
        f"{diffusers_attention_prefix}.norm1.linear.weight": checkpoint[f"{attention_prefix}.norm_in.mapper.weight"],
        f"{diffusers_attention_prefix}.norm1.linear.bias": checkpoint[f"{attention_prefix}.norm_in.mapper.bias"],
        # to_q
        f"{diffusers_attention_prefix}.attn1.to_q.weight": weight_q.squeeze(-1).squeeze(-1),
        f"{diffusers_attention_prefix}.attn1.to_q.bias": bias_q,
        # to_k
        f"{diffusers_attention_prefix}.attn1.to_k.weight": weight_k.squeeze(-1).squeeze(-1),
        f"{diffusers_attention_prefix}.attn1.to_k.bias": bias_k,
        # to_v
        f"{diffusers_attention_prefix}.attn1.to_v.weight": weight_v.squeeze(-1).squeeze(-1),
        f"{diffusers_attention_prefix}.attn1.to_v.bias": bias_v,
        # to_out
        f"{diffusers_attention_prefix}.attn1.to_out.0.weight": checkpoint[f"{attention_prefix}.out_proj.weight"]
        .squeeze(-1)
        .squeeze(-1),
        f"{diffusers_attention_prefix}.attn1.to_out.0.bias": checkpoint[f"{attention_prefix}.out_proj.bias"],
    }

    return rv


def cross_attn_to_diffusers_checkpoint(
    checkpoint, *, diffusers_attention_prefix, diffusers_attention_index, attention_prefix
):
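    """Map a k-diffusion cross-attention layer onto diffusers' `attn{diffusers_attention_index}`.

    The fused `kv_proj` projection over the encoder states is chunked into separate k/v tensors,
    `norm_dec` maps to the AdaGroupNorm preceding the attention, and `norm_enc` maps to the
    layer norm applied to the encoder hidden states (`norm_cross`).
    """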
    weight_k, weight_v = checkpoint[f"{attention_prefix}.kv_proj.weight"].chunk(2, dim=0)
    bias_k, bias_v = checkpoint[f"{attention_prefix}.kv_proj.bias"].chunk(2, dim=0)

    rv = {
        # norm2 (ada groupnorm)
        f"{diffusers_attention_prefix}.norm{diffusers_attention_index}.linear.weight": checkpoint[
            f"{attention_prefix}.norm_dec.mapper.weight"
        ],
        f"{diffusers_attention_prefix}.norm{diffusers_attention_index}.linear.bias": checkpoint[
            f"{attention_prefix}.norm_dec.mapper.bias"
        ],
        # layernorm on encoder_hidden_state
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.norm_cross.weight": checkpoint[
            f"{attention_prefix}.norm_enc.weight"
        ],
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.norm_cross.bias": checkpoint[
            f"{attention_prefix}.norm_enc.bias"
        ],
        # to_q
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_q.weight": checkpoint[
            f"{attention_prefix}.q_proj.weight"
        ]
        .squeeze(-1)
        .squeeze(-1),
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_q.bias": checkpoint[
            f"{attention_prefix}.q_proj.bias"
        ],
        # to_k
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_k.weight": weight_k.squeeze(-1).squeeze(-1),
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_k.bias": bias_k,
        # to_v
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_v.weight": weight_v.squeeze(-1).squeeze(-1),
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_v.bias": bias_v,
        # to_out
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_out.0.weight": checkpoint[
            f"{attention_prefix}.out_proj.weight"
        ]
        .squeeze(-1)
        .squeeze(-1),
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_out.0.bias": checkpoint[
            f"{attention_prefix}.out_proj.bias"
        ],
    }

    return rv


def block_to_diffusers_checkpoint(block, checkpoint, block_idx, block_type):
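    """Convert one down/up block, walking its resnets and (optional) attention layers.

    Each original block is stored as a flat sequence, so the per-layer checkpoint index is
    reconstructed from `n`, the number of sub-layers per repetition (resnet only, resnet plus
    cross-attention, or resnet plus self- and cross-attention). Down blocks start at offset 1
    in the original checkpoint, hence the `+ 1` in the index arithmetic for the "down" case.
    """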
block_prefix = "inner_model.u_net.u_blocks" if block_type == "up" else "inner_model.u_net.d_blocks" | |
block_prefix = f"{block_prefix}.{block_idx}" | |
diffusers_checkpoint = {} | |
if not hasattr(block, "attentions"): | |
n = 1 # resnet only | |
elif not block.attentions[0].add_self_attention: | |
n = 2 # resnet -> cross-attention | |
else: | |
n = 3 # resnet -> self-attention -> cross-attention) | |
for resnet_idx, resnet in enumerate(block.resnets): | |
# diffusers_resnet_prefix = f"{diffusers_up_block_prefix}.resnets.{resnet_idx}" | |
diffusers_resnet_prefix = f"{block_type}_blocks.{block_idx}.resnets.{resnet_idx}" | |
idx = n * resnet_idx if block_type == "up" else n * resnet_idx + 1 | |
resnet_prefix = f"{block_prefix}.{idx}" if block_type == "up" else f"{block_prefix}.{idx}" | |
diffusers_checkpoint.update( | |
resnet_to_diffusers_checkpoint( | |
resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix | |
) | |
) | |
if hasattr(block, "attentions"): | |
for attention_idx, attention in enumerate(block.attentions): | |
diffusers_attention_prefix = f"{block_type}_blocks.{block_idx}.attentions.{attention_idx}" | |
idx = n * attention_idx + 1 if block_type == "up" else n * attention_idx + 2 | |
self_attention_prefix = f"{block_prefix}.{idx}" | |
cross_attention_prefix = f"{block_prefix}.{idx }" | |
cross_attention_index = 1 if not attention.add_self_attention else 2 | |
idx = ( | |
n * attention_idx + cross_attention_index | |
if block_type == "up" | |
else n * attention_idx + cross_attention_index + 1 | |
) | |
cross_attention_prefix = f"{block_prefix}.{idx }" | |
diffusers_checkpoint.update( | |
cross_attn_to_diffusers_checkpoint( | |
checkpoint, | |
diffusers_attention_prefix=diffusers_attention_prefix, | |
diffusers_attention_index=2, | |
attention_prefix=cross_attention_prefix, | |
) | |
) | |
if attention.add_self_attention is True: | |
diffusers_checkpoint.update( | |
self_attn_to_diffusers_checkpoint( | |
checkpoint, | |
diffusers_attention_prefix=diffusers_attention_prefix, | |
attention_prefix=self_attention_prefix, | |
) | |
) | |
return diffusers_checkpoint | |


def unet_to_diffusers_checkpoint(model, checkpoint):
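    """Build a complete diffusers state dict from the original k-diffusion checkpoint.

    `model` is the freshly instantiated diffusers UNet; it is only inspected (block structure,
    presence of shortcut convolutions), its weights are never read.
    """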
    diffusers_checkpoint = {}

    # pre-processing
    diffusers_checkpoint.update(
        {
            "conv_in.weight": checkpoint["inner_model.proj_in.weight"],
            "conv_in.bias": checkpoint["inner_model.proj_in.bias"],
        }
    )

    # timestep and class embedding
    diffusers_checkpoint.update(
        {
            "time_proj.weight": checkpoint["inner_model.timestep_embed.weight"].squeeze(-1),
            "time_embedding.linear_1.weight": checkpoint["inner_model.mapping.0.weight"],
            "time_embedding.linear_1.bias": checkpoint["inner_model.mapping.0.bias"],
            "time_embedding.linear_2.weight": checkpoint["inner_model.mapping.2.weight"],
            "time_embedding.linear_2.bias": checkpoint["inner_model.mapping.2.bias"],
            "time_embedding.cond_proj.weight": checkpoint["inner_model.mapping_cond.weight"],
        }
    )

    # down_blocks
    for down_block_idx, down_block in enumerate(model.down_blocks):
        diffusers_checkpoint.update(block_to_diffusers_checkpoint(down_block, checkpoint, down_block_idx, "down"))

    # up_blocks
    for up_block_idx, up_block in enumerate(model.up_blocks):
        diffusers_checkpoint.update(block_to_diffusers_checkpoint(up_block, checkpoint, up_block_idx, "up"))

    # post-processing
    diffusers_checkpoint.update(
        {
            "conv_out.weight": checkpoint["inner_model.proj_out.weight"],
            "conv_out.bias": checkpoint["inner_model.proj_out.bias"],
        }
    )

    return diffusers_checkpoint


def unet_model_from_original_config(original_config):
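    """Instantiate a diffusers UNet2DConditionModel matching the k-upscaler architecture.

    Most hyperparameters are read from the original JSON config; the block types and a few values
    (attention head size, 1x1 in/out convolutions) are hard-coded for this particular upscaler.
    """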
    in_channels = original_config["input_channels"] + original_config["unet_cond_dim"]
    out_channels = original_config["input_channels"] + (1 if original_config["has_variance"] else 0)

    block_out_channels = original_config["channels"]

    assert (
        len(set(original_config["depths"])) == 1
    ), "UNet2DConditionModel currently does not support blocks with different numbers of layers"
    layers_per_block = original_config["depths"][0]

    class_labels_dim = original_config["mapping_cond_dim"]
    cross_attention_dim = original_config["cross_cond_dim"]

    # Attention layout per resolution, derived for reference only; the block types passed to
    # UNet2DConditionModel below are hard-coded for this checkpoint.
    attn1_types = []
    attn2_types = []
    for s, c in zip(original_config["self_attn_depths"], original_config["cross_attn_depths"]):
        if s:
            a1 = "self"
            a2 = "cross" if c else None
        elif c:
            a1 = "cross"
            a2 = None
        else:
            a1 = None
            a2 = None
        attn1_types.append(a1)
        attn2_types.append(a2)

    unet = UNet2DConditionModel(
        in_channels=in_channels,
        out_channels=out_channels,
        down_block_types=("KDownBlock2D", "KCrossAttnDownBlock2D", "KCrossAttnDownBlock2D", "KCrossAttnDownBlock2D"),
        mid_block_type=None,
        up_block_types=("KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KUpBlock2D"),
        block_out_channels=block_out_channels,
        layers_per_block=layers_per_block,
        act_fn="gelu",
        norm_num_groups=None,
        cross_attention_dim=cross_attention_dim,
        attention_head_dim=64,
        time_cond_proj_dim=class_labels_dim,
        resnet_time_scale_shift="scale_shift",
        time_embedding_type="fourier",
        timestep_post_act="gelu",
        conv_in_kernel=1,
        conv_out_kernel=1,
    )

    return unet


def main(args):
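    """Download the original config and weights, convert them, and save the diffusers UNet."""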
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    orig_config_path = huggingface_hub.hf_hub_download(UPSCALER_REPO, "config_laion_text_cond_latent_upscaler_2.json")
    orig_weights_path = huggingface_hub.hf_hub_download(
        UPSCALER_REPO, "laion_text_cond_latent_upscaler_2_1_00470000_slim.pth"
    )

    print(f"loading original model configuration from {orig_config_path}")
    print(f"loading original model checkpoint from {orig_weights_path}")

    print("converting to diffusers unet")
    with open(orig_config_path) as config_file:
        orig_config = K.config.load_config(config_file)["model"]
    model = unet_model_from_original_config(orig_config)

    orig_checkpoint = torch.load(orig_weights_path, map_location=device)["model_ema"]
    converted_checkpoint = unet_to_diffusers_checkpoint(model, orig_checkpoint)

    # strict=True verifies that every converted tensor is used and no diffusers parameter is left uninitialized
    model.load_state_dict(converted_checkpoint, strict=True)

    print(f"saving converted unet model to {args.dump_path}")
    model.save_pretrained(args.dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

    args = parser.parse_args()
    main(args)