Spaces:
Build error
Build error
from itertools import product
from typing import Dict, Optional

import tensorflow as tf
import torch
from keras_cv.models import stable_diffusion
def port_transformer_block(
    transformer_block: tf.keras.Model,
    up_down: str,
    block_id: Optional[int],
    attention_id: int,
) -> Dict[str, torch.Tensor]:
    """Populates a Transformer block.

    Converts the weights of one KerasCV transformer block into PyTorch
    tensors keyed by their PyTorch parameter names.

    Args:
        transformer_block: Block whose ``norm1``-``norm3``, ``attn1``/``attn2``
            (each with ``to_q``, ``to_k``, ``to_v``, ``out_proj``),
            ``geglu.dense`` and ``dense`` sub-layers are read.
        up_down: Section name used in the key prefix (``"down"`` yields
            ``down_blocks``, ``"up"`` yields ``up_blocks``). Ignored when
            ``block_id`` is ``None``.
        block_id: Index of the enclosing down/up block, or ``None`` for the
            middle block (prefix ``mid_block``).
        attention_id: Index of the attention module inside that block.

    Returns:
        Mapping from keys such as
        ``"down_blocks.0.attentions.1.transformer_blocks.0.norm1.weight"``
        to tensors built with ``torch.from_numpy``.
    """
    block_prefix = "mid_block" if block_id is None else f"{up_down}_blocks.{block_id}"
    prefix = f"{block_prefix}.attentions.{attention_id}.transformer_blocks.0"
    transformer_dict: Dict[str, torch.Tensor] = {}

    # Norms: norm1..norm3 are copied across without transposition.
    for i in range(1, 4):
        norm_weights = getattr(transformer_block, f"norm{i}").get_weights()
        transformer_dict[f"{prefix}.norm{i}.weight"] = torch.from_numpy(norm_weights[0])
        transformer_dict[f"{prefix}.norm{i}.bias"] = torch.from_numpy(norm_weights[1])

    # Attentions: kernels are transposed (2-D ``.transpose()`` swaps the two
    # axes); only the kernel is ported for q/k/v, kernel + bias for the output.
    for i in (1, 2):
        attn = getattr(transformer_block, f"attn{i}")
        for proj in ("to_q", "to_k", "to_v"):
            kernel = getattr(attn, proj).get_weights()[0]
            transformer_dict[f"{prefix}.attn{i}.{proj}.weight"] = torch.from_numpy(kernel.transpose())
        out_weights = attn.out_proj.get_weights()
        transformer_dict[f"{prefix}.attn{i}.to_out.0.weight"] = torch.from_numpy(out_weights[0].transpose())
        transformer_dict[f"{prefix}.attn{i}.to_out.0.bias"] = torch.from_numpy(out_weights[1])

    # Feed-forward: ``geglu.dense`` -> ff.net.0.proj, ``dense`` -> ff.net.2.
    feed_forward = (
        (f"{prefix}.ff.net.0.proj", transformer_block.geglu.dense),
        (f"{prefix}.ff.net.2", transformer_block.dense),
    )
    for key, layer in feed_forward:
        ff_weights = layer.get_weights()
        transformer_dict[f"{key}.weight"] = torch.from_numpy(ff_weights[0].transpose())
        transformer_dict[f"{key}.bias"] = torch.from_numpy(ff_weights[1])
    return transformer_dict
def _port_conv(layer: tf.keras.layers.Layer, key: str, state_dict: Dict[str, torch.Tensor]) -> None:
    """Store a conv layer's kernel/bias under ``{key}.weight`` / ``{key}.bias``."""
    weights = layer.get_weights()
    # Transposition axes taken from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_tf_pytorch_utils.py#L104
    # (Keras conv kernels are (H, W, in, out); PyTorch expects (out, in, H, W)).
    state_dict[f"{key}.weight"] = torch.from_numpy(weights[0].transpose(3, 2, 0, 1))
    state_dict[f"{key}.bias"] = torch.from_numpy(weights[1])


def _port_dense(layer: tf.keras.layers.Layer, key: str, state_dict: Dict[str, torch.Tensor]) -> None:
    """Store a dense layer's kernel (axes swapped) and bias under ``key``."""
    weights = layer.get_weights()
    state_dict[f"{key}.weight"] = torch.from_numpy(weights[0].transpose())
    state_dict[f"{key}.bias"] = torch.from_numpy(weights[1])


def _port_norm(layer: tf.keras.layers.Layer, key: str, state_dict: Dict[str, torch.Tensor]) -> None:
    """Store a normalization layer's first two weights as ``weight``/``bias``."""
    weights = layer.get_weights()
    state_dict[f"{key}.weight"] = torch.from_numpy(weights[0])
    state_dict[f"{key}.bias"] = torch.from_numpy(weights[1])


def _layer_index(layer: tf.keras.layers.Layer) -> int:
    """Return the numeric suffix of ``layer.name``.

    A name with exactly two ``_``-separated parts (e.g. ``res_block``) is
    treated as index 0; otherwise the last part is parsed as an int
    (e.g. ``res_block_7`` -> 7).
    """
    parts = layer.name.split("_")
    return 0 if len(parts) == 2 else int(parts[-1])


def _port_res_block(layer: tf.keras.layers.Layer, key: str, state_dict: Dict[str, torch.Tensor]) -> None:
    """Port one KerasCV ``ResBlock`` to the PyTorch resnet whose prefix is ``key``."""
    entry_flow = layer.entry_flow
    exit_flow = layer.exit_flow
    # In each flow, index -1 is read as the conv and index 0 as the group norm.
    _port_conv(entry_flow[-1], f"{key}.conv1", state_dict)
    _port_conv(exit_flow[-1], f"{key}.conv2", state_dict)
    # Only a residual path that is an actual PaddedConv2D becomes conv_shortcut.
    residual = getattr(layer, "residual_projection", None)
    if isinstance(residual, stable_diffusion.__internal__.layers.padded_conv2d.PaddedConv2D):
        _port_conv(residual, f"{key}.conv_shortcut", state_dict)
    _port_dense(layer.embedding_flow[-1], f"{key}.time_emb_proj", state_dict)
    _port_norm(entry_flow[0], f"{key}.norm1", state_dict)
    _port_norm(exit_flow[0], f"{key}.norm2", state_dict)


def _port_spatial_transformer(
    layer: tf.keras.layers.Layer,
    up_down: str,
    block_id: Optional[int],
    attention_id: int,
    state_dict: Dict[str, torch.Tensor],
) -> None:
    """Port one KerasCV ``SpatialTransformer``; ``block_id=None`` means the mid block."""
    block_prefix = "mid_block" if block_id is None else f"{up_down}_blocks.{block_id}"
    key = f"{block_prefix}.attentions.{attention_id}"
    _port_conv(layer.proj1, f"{key}.proj_in", state_dict)
    _port_conv(layer.proj2, f"{key}.proj_out", state_dict)
    state_dict.update(port_transformer_block(layer.transformer_block, up_down, block_id, attention_id))
    _port_norm(layer.norm, f"{key}.norm", state_dict)


def populate_unet(tf_unet: tf.keras.Model) -> Dict[str, torch.Tensor]:
    """Populates the state dict from the provided TensorFlow model
    (applicable only for the UNet).

    Walks ``tf_unet.layers`` in order and maps each recognised layer onto
    PyTorch parameter names such as ``down_blocks.0.resnets.0.conv1.weight``.

    Several mappings rely on layer order (counters for the top-level Dense,
    PaddedConv2D and Upsample layers, and for the up-path ResBlock /
    SpatialTransformer slots) and on the numeric suffixes of layer names.
    NOTE(review): this presumably requires a freshly built model whose
    auto-generated names start from ``res_block`` / ``spatial_transformer``;
    confirm before reusing in a session that builds several models.

    Args:
        tf_unet: The KerasCV UNet whose weights are ported.

    Returns:
        Mapping from PyTorch parameter names to tensors.
    """
    unet_state_dict: Dict[str, torch.Tensor] = {}
    timestep_emb = 1
    padded_conv = 1
    up_block = 0
    # Up-path (block_id, index) slots, consumed in the order layers are seen.
    up_res_blocks = list(product([0, 1, 2, 3], [0, 1, 2]))
    up_res_block_flag = 0
    up_spatial_transformer_blocks = list(product([1, 2, 3], [0, 1, 2]))
    up_spatial_transformer_flag = 0
    internal_layers = stable_diffusion.__internal__.layers

    for layer in tf_unet.layers:
        # Timestep embedding.
        if isinstance(layer, tf.keras.layers.Dense):
            _port_dense(layer, f"time_embedding.linear_{timestep_emb}", unet_state_dict)
            timestep_emb += 1
        # Padded convs: conv_in, then the three downsamplers, then conv_out.
        elif isinstance(layer, internal_layers.padded_conv2d.PaddedConv2D):
            if padded_conv == 1:
                _port_conv(layer, "conv_in", unet_state_dict)
            elif padded_conv in (2, 3, 4):
                _port_conv(layer, f"down_blocks.{padded_conv - 2}.downsamplers.0.conv", unet_state_dict)
            elif padded_conv == 5:
                _port_conv(layer, "conv_out", unet_state_dict)
            padded_conv += 1
        # Upsamplers.
        elif isinstance(layer, stable_diffusion.diffusion_model.Upsample):
            _port_conv(layer.conv, f"up_blocks.{up_block}.upsamplers.0.conv", unet_state_dict)
            up_block += 1
        # Output norm.
        elif isinstance(layer, internal_layers.group_normalization.GroupNormalization):
            _port_norm(layer, "conv_norm_out", unet_state_dict)
        # ResBlocks: indices 0-7 are down, 8-9 are mid, the rest are up.
        elif isinstance(layer, stable_diffusion.diffusion_model.ResBlock):
            index = _layer_index(layer)
            if index < 8:
                _port_res_block(layer, f"down_blocks.{index // 2}.resnets.{index % 2}", unet_state_dict)
            elif index in (8, 9):
                _port_res_block(layer, f"mid_block.resnets.{index % 2}", unet_state_dict)
            elif up_res_block_flag < len(up_res_blocks):
                up_block_id, up_resnet_id = up_res_blocks[up_res_block_flag]
                _port_res_block(layer, f"up_blocks.{up_block_id}.resnets.{up_resnet_id}", unet_state_dict)
                up_res_block_flag += 1
        # SpatialTransformers: indices 0-5 are down, 6 is mid, the rest are up.
        elif isinstance(layer, stable_diffusion.diffusion_model.SpatialTransformer):
            index = _layer_index(layer)
            if index < 6:
                _port_spatial_transformer(layer, "down", index // 2, index % 2, unet_state_dict)
            elif index == 6:
                # Every mid-block key (including proj_out) uses the attention
                # index; it previously reused a stale ResBlock variable.
                _port_spatial_transformer(layer, "mid", None, index % 2, unet_state_dict)
            elif up_spatial_transformer_flag < len(up_spatial_transformer_blocks):
                up_block_id, up_attention_id = up_spatial_transformer_blocks[up_spatial_transformer_flag]
                _port_spatial_transformer(layer, "up", up_block_id, up_attention_id, unet_state_dict)
                up_spatial_transformer_flag += 1
    return unet_state_dict