---
license: apache-2.0
pipeline_tag: image-classification
---

PyTorch weights for the Kornia ViT, converted from the original Google JAX Vision Transformer repository.

The original weights come from https://github.com/google-research/vision_transformer. This checkpoint is based on the [original ViT-S/16 pretrained on ImageNet-21k](https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz).

The weights were converted to PyTorch for the Kornia ViT implementation by [@gau-nernst](https://github.com/gau-nernst) in [kornia/kornia#2786](https://github.com/kornia/kornia/pull/2786#discussion_r1482339811).

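The conversion function in this card takes a flat `dict[str, np.ndarray]` keyed by the JAX parameter paths (for example `cls` or `embedding/bias`). A minimal sketch of building that dict from an `.npz` archive follows. It assumes the checkpoint has already been downloaded to disk. `load_jax_checkpoint` and the toy archive are hypothetical stand-ins for illustration and are not part of Kornia or the original repository.

```python
import os
import tempfile

import numpy as np


def load_jax_checkpoint(path: str) -> dict[str, np.ndarray]:
    """Read every array of an .npz archive into a plain dict (hypothetical helper)."""
    with np.load(path) as npz:
        return {key: npz[key] for key in npz.files}


# Stand-in for the downloaded checkpoint: a tiny archive holding two made-up
# arrays under key names that the converter reads, so the sketch runs offline.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_checkpoint.npz")
    np.savez(path, **{"cls": np.zeros((1, 1, 4)), "embedding/bias": np.ones(4)})
    np_state_dict = load_jax_checkpoint(path)

print(sorted(np_state_dict))
```

With the real file, the same helper would be pointed at the downloaded `.npz` path and its result passed to the converter.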
<details>
<summary>Convert jax checkpoint function</summary>

```python
import numpy as np
import torch


def convert_jax_checkpoint(np_state_dict: dict[str, np.ndarray]) -> dict[str, torch.Tensor]:

    def get_weight(key: str) -> torch.Tensor:
        return torch.from_numpy(np_state_dict[key])

    state_dict = dict()

    # Patch embedding: class token, conv projection, position embeddings.
    state_dict["patch_embedding.cls_token"] = get_weight("cls")
    state_dict["patch_embedding.backbone.weight"] = get_weight("embedding/kernel").permute(3, 2, 0, 1)  # conv »
    state_dict["patch_embedding.backbone.bias"] = get_weight("embedding/bias")
    state_dict["patch_embedding.positions"] = get_weight("Transformer/posembed_input/pos_embedding").squeeze(0)

    # Walk the encoder blocks until the checkpoint has no more (at most 100).
    for i in range(100):
        prefix1 = f"encoder.blocks.{i}"
        prefix2 = f"Transformer/encoderblock_{i}"

        if f"{prefix2}/LayerNorm_0/scale" not in np_state_dict:
            break

        # LayerNorm before attention.
        state_dict[f"{prefix1}.0.fn.0.weight"] = get_weight(f"{prefix2}/LayerNorm_0/scale")
        state_dict[f"{prefix1}.0.fn.0.bias"] = get_weight(f"{prefix2}/LayerNorm_0/bias")

        # Attention: fuse query, key and value into one qkv projection.
        mha_prefix = f"{prefix2}/MultiHeadDotProductAttention_1"
        qkv_weight = [get_weight(f"{mha_prefix}/{x}/kernel") for x in ["query", "key", "value"]]
        qkv_bias = [get_weight(f"{mha_prefix}/{x}/bias") for x in ["query", "key", "value"]]
        state_dict[f"{prefix1}.0.fn.1.qkv.weight"] = torch.cat(qkv_weight, 1).flatten(1).T
        state_dict[f"{prefix1}.0.fn.1.qkv.bias"] = torch.cat(qkv_bias, 0).flatten()
        state_dict[f"{prefix1}.0.fn.1.projection.weight"] = get_weight(f"{mha_prefix}/out/kernel").flatten(0, 1»
        state_dict[f"{prefix1}.0.fn.1.projection.bias"] = get_weight(f"{mha_prefix}/out/bias")

        # LayerNorm and MLP; dense kernels are transposed to (out, in).
        state_dict[f"{prefix1}.1.fn.0.weight"] = get_weight(f"{prefix2}/LayerNorm_2/scale")
        state_dict[f"{prefix1}.1.fn.0.bias"] = get_weight(f"{prefix2}/LayerNorm_2/bias")
        state_dict[f"{prefix1}.1.fn.1.0.weight"] = get_weight(f"{prefix2}/MlpBlock_3/Dense_0/kernel").T
        state_dict[f"{prefix1}.1.fn.1.0.bias"] = get_weight(f"{prefix2}/MlpBlock_3/Dense_0/bias")
        state_dict[f"{prefix1}.1.fn.1.3.weight"] = get_weight(f"{prefix2}/MlpBlock_3/Dense_1/kernel").T
        state_dict[f"{prefix1}.1.fn.1.3.bias"] = get_weight(f"{prefix2}/MlpBlock_3/Dense_1/bias")

    # Final encoder LayerNorm.
    state_dict["norm.weight"] = get_weight("Transformer/encoder_norm/scale")
    state_dict["norm.bias"] = get_weight("Transformer/encoder_norm/bias")
    return state_dict
```

</details>
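
The converter fuses the separate query, key and value kernels with `torch.cat(qkv_weight, 1).flatten(1).T`. The sketch below checks that reshaping on small random arrays, using NumPy so it runs without PyTorch. It assumes a Flax-style layout that this card does not state: q/k/v kernels of shape `(dim, heads, head_dim)`, biases of shape `(heads, head_dim)`, and an output kernel of shape `(heads, head_dim, dim)`. `fuse_qkv` and the toy sizes are hypothetical names chosen for illustration.

```python
import numpy as np


def fuse_qkv(kernels, biases):
    """Fuse per-projection parameters into one linear layer (hypothetical helper).

    Assumed layout: each kernel is (dim, heads, head_dim),
    each bias is (heads, head_dim).
    """
    dim = kernels[0].shape[0]
    # Mirrors torch.cat(qkv_weight, 1).flatten(1).T from the converter.
    weight = np.concatenate(kernels, axis=1).reshape(dim, -1).T  # (3 * dim, dim)
    # Mirrors torch.cat(qkv_bias, 0).flatten().
    bias = np.concatenate(biases, axis=0).reshape(-1)  # (3 * dim,)
    return weight, bias


dim, heads, tokens = 8, 2, 5
head_dim = dim // heads
rng = np.random.default_rng(0)
kernels = [rng.standard_normal((dim, heads, head_dim)) for _ in range(3)]
biases = [rng.standard_normal((heads, head_dim)) for _ in range(3)]
x = rng.standard_normal((tokens, dim))

weight, bias = fuse_qkv(kernels, biases)
fused = x @ weight.T + bias  # what a PyTorch-style linear layer computes

# Reference: apply query, key and value separately, then concatenate.
separate = np.concatenate(
    [
        (np.einsum("td,dhk->thk", x, k) + b).reshape(tokens, dim)
        for k, b in zip(kernels, biases)
    ],
    axis=1,
)
assert np.allclose(fused, separate)

# Under the same assumption, the output kernel is (heads, head_dim, dim).
# Flattening its first two axes gives an (in, out) matrix, whereas a
# PyTorch-style linear layer stores its weight as (out, in).
out_kernel = rng.standard_normal((heads, head_dim, dim))
attn = rng.standard_normal((tokens, heads, head_dim))
out_weight = out_kernel.reshape(dim, dim).T
assert np.allclose(
    attn.reshape(tokens, dim) @ out_weight.T,
    np.einsum("thk,hkd->td", attn, out_kernel),
)
```

If the real checkpoint uses a different layout, the assertions here say nothing about it. Comparing the converted model's outputs against the JAX model on the same input is the more reliable check.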