Find this notebook on [Kaggle](https://www.kaggle.com/code/amankhandelia/convert-mms-alignment-checkpoint-to-jax).

In [1]:
!pip install --pre torchaudio==2.1.0.dev20230627+cu118 --index-url https://download.pytorch.org/whl/nightly/cu118
!pip install transformers==4.31.0
!pip install loguru

Looking in indexes: https://download.pytorch.org/whl/nightly/cu118
Collecting torchaudio==2.1.0.dev20230627+cu118
  Downloading https://download.pytorch.org/whl/nightly/cu118/torchaudio-2.1.0.dev20230627%2Bcu118-cp310-cp310-linux_x86_64.whl (4.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.1/4.1 MB[0m [31m6.7 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting torch==2.1.0.dev20230627 (from torchaudio==2.1.0.dev20230627+cu118)
  Downloading https://download.pytorch.org/whl/nightly/cu118/torch-2.1.0.dev20230627%2Bcu118-cp310-cp310-linux_x86_64.whl (2316.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 GB[0m [31m?[0m eta [36m0:00:00[0m [36m0:00:01[0m00:03[0mm
Collecting pytorch-triton==2.1.0+440fd1bf20 (from torch==2.1.0.dev20230627->torchaudio==2.1.0.dev20230627+cu118)
  Downloading https://download.pytorch.org/whl/nightly/pytorch_triton-2.1.0%2B440fd1bf20-cp310-cp310-linux_x86_64.whl (93.1 MB)
[2K     [90m━━

In [7]:
%%writefile /kaggle/working/wav2vec2_alignment_config.json

{
  "activation_dropout": 0.1,
  "adapter_attn_dim": null,
  "adapter_kernel_size": 3,
  "adapter_stride": 2,
  "add_adapter": false,
  "apply_spec_augment": false,
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "classifier_proj_size": 256,
  "codevector_dim": 256,
  "contrastive_logits_temperature": 0.1,
  "conv_bias": true,
  "conv_dim": [
    512,
    512,
    512,
    512,
    512,
    512,
    512
  ],
  "conv_kernel": [
    10,
    3,
    3,
    3,
    3,
    2,
    2
  ],
  "conv_stride": [
    5,
    2,
    2,
    2,
    2,
    2,
    2
  ],
  "ctc_loss_reduction": "sum",
  "ctc_zero_infinity": false,
  "diversity_loss_weight": 0.1,
  "do_stable_layer_norm": true,
  "eos_token_id": 2,
  "feat_extract_activation": "gelu",
  "feat_extract_norm": "layer",
  "feat_proj_dropout": 0.0,
  "feat_quantizer_dropout": 0.0,
  "final_dropout": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout": 0.0,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-05,
  "layerdrop": 0.1,
  "num_labels":31,
  "mask_feature_length": 10,
  "mask_feature_min_masks": 0,
  "mask_feature_prob": 0.0,
  "mask_time_length": 10,
  "mask_time_min_masks": 2,
  "mask_time_prob": 0.0,
  "model_type": "wav2vec2",
  "num_adapter_layers": 3,
  "num_attention_heads": 16,
  "num_codevector_groups": 2,
  "num_codevectors_per_group": 320,
  "num_conv_pos_embedding_groups": 16,
  "num_conv_pos_embeddings": 128,
  "num_feat_extract_layers": 7,
  "num_hidden_layers": 24,
  "num_negatives": 100,
  "output_hidden_size": 1024,
  "pad_token_id": 0,
  "proj_codevector_dim": 256,
  "tdnn_dilation": [
    1,
    2,
    3,
    1,
    1
  ],
  "tdnn_dim": [
    512,
    512,
    512,
    512,
    1500
  ],
  "tdnn_kernel": [
    5,
    3,
    3,
    1,
    1
  ],
  "transformers_version": "4.31.0",
  "use_weighted_layer_sum": false,
  "vocab_size": 32,
  "xvector_output_dim": 512
}


Writing /kaggle/working/wav2vec2_alignment_config.json


In [6]:
"""Convert Wav2Vec2 checkpoint."""

# ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py

import os
import re
import torch
from transformers import Wav2Vec2ForAudioFrameClassification
from transformers import Wav2Vec2Config

from torchaudio.models import wav2vec2_model
from loguru import logger


# Source-name fragment -> target module path inside the Hugging Face model.
# `load_wav2vec2_layer` matches each key as a substring of the source weight name after integer
# path components have been replaced by "*" (see `replace_int_with_asterisk`). A "*" in the
# target path is replaced by the encoder layer index. Targets that are not listed in
# `TOP_LEVEL_KEYS` get a "wav2vec2." prefix before being resolved on the model.
# Iteration order matters: the first matching entry wins.
MAPPING = {
    "feature_projection.projection": "feature_projection.projection",
    "encoder.transformer.pos_conv_embed.conv.weight_g": "encoder.pos_conv_embed.conv.parametrizations.weight",
    "encoder.transformer.pos_conv_embed.conv.weight_v": "encoder.pos_conv_embed.conv.parametrizations.weight",
    "encoder.transformer.pos_conv_embed.conv.bias": "encoder.pos_conv_embed.conv",
    "attention.k_proj": "encoder.layers.*.attention.k_proj",
    "attention.v_proj": "encoder.layers.*.attention.v_proj",
    "attention.q_proj": "encoder.layers.*.attention.q_proj",
    "attention.out_proj": "encoder.layers.*.attention.out_proj",
    "transformer.layers.*.layer_norm": "encoder.layers.*.layer_norm",
    "feed_forward.intermediate_dense": "encoder.layers.*.feed_forward.intermediate_dense",
    "feed_forward.output_dense": "encoder.layers.*.feed_forward.output_dense",
    "final_layer_norm": "encoder.layers.*.final_layer_norm",
    "encoder.transformer.layer_norm": "encoder.layer_norm",
    "aux": "classifier",
    "adapter_layer": "encoder.layers.*.adapter_layer",
    "feature_projection.layer_norm": "feature_projection.layer_norm",
    "quantizer.weight_proj": "quantizer.weight_proj",
    "quantizer.vars": "quantizer.codevectors",
    "project_q": "project_q",
    "final_proj": "project_hid",
    "w2v_encoder.proj": "lm_head",
    "mask_emb": "masked_spec_embed",
    "pooling_layer.linear": "projector",
    "pooling_layer.projection": "classifier",
}
# Mapped targets that are resolved directly on the model; every other target in `MAPPING`
# is prefixed with "wav2vec2." by `load_wav2vec2_layer`.
TOP_LEVEL_KEYS = [
    "lm_head",
    "quantizer.weight_proj",
    "quantizer.codevectors",
    "project_q",
    "project_hid",
    "projector",
    "classifier",
]


def check_or_download_model_weights(model_path_name="/tmp/ctc_alignment_mling_uroman_model.pt"):
    """Load the MMS alignment checkpoint, downloading it first if it is not cached locally.

    Args:
        model_path_name: Local path where the checkpoint file is cached.

    Returns:
        The object returned by ``torch.load`` for the checkpoint, mapped to CPU.

    Raises:
        FileNotFoundError: If the checkpoint is still missing after the download call returned.
    """
    # this model has 315,469,471 parameters
    logger.info("Downloading model and dictionary...")
    if os.path.exists(model_path_name):
        logger.info("Model path already exists. Skipping downloading....")
    else:
        torch.hub.download_url_to_file(
            "https://dl.fbaipublicfiles.com/mms/torchaudio/ctc_alignment_mling_uroman/model.pt",
            model_path_name,
        )
        # Explicit check instead of `assert`: assertions are stripped when Python runs with -O.
        if not os.path.exists(model_path_name):
            raise FileNotFoundError(f"Checkpoint was not written to {model_path_name}")
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary code. The source
    # here is a fixed upstream URL; do not point this function at untrusted checkpoint files.
    return torch.load(model_path_name, map_location="cpu")


def read_txt_into_dict(filename):
    """Map zero-based line numbers to the first whitespace-separated token of each line.

    Lines that are empty after stripping are skipped, but they still count towards the
    line numbering, so the resulting keys may have gaps.
    """
    with open(filename, "r") as handle:
        stripped_lines = [(index, raw.strip()) for index, raw in enumerate(handle)]
    return {index: text.split()[0] for index, text in stripped_lines if text}


def set_recursively(key, value, full_name, weight_type, hf_pointer):
    """Copy one source tensor into the target model and report which target key was filled.

    Args:
        key: Dotted attribute path (relative to ``hf_pointer``) of the target module/parameter.
        value: Source tensor. For adapter parameters (names ending in a ``PARAM_MAPPING`` key)
            only ``value[0]`` is copied.
        full_name: Name of the weight in the source state dict (used for matching and messages).
        weight_type: Attribute of the target module to fill (e.g. ``"weight"``, ``"bias"``,
            ``"original0"``, ``"original1"``) or ``None`` when ``key`` is itself the parameter.
        hf_pointer: Root object the dotted ``key`` is resolved against.

    Returns:
        The full dotted name of the parameter that was written, i.e. ``key`` itself when
        ``weight_type`` is ``None``, ``key + "." + weight_type`` for regular weights, and
        ``key + "." + <mapped adapter attribute>`` for adapter parameters.

    Raises:
        ValueError: If the target and source shapes differ.
    """
    for attribute in key.split("."):
        hf_pointer = getattr(hf_pointer, attribute)

    hf_param_name = None
    if any(full_name.endswith(param_key) for param_key in PARAM_MAPPING):
        hf_param_name = PARAM_MAPPING[full_name.split(".")[-1]]
        weight_type = "param"

    if weight_type == "param":
        target = hf_pointer
        for attribute in hf_param_name.split("."):
            target = getattr(target, attribute)
        # let's reduce dimension
        value = value[0]
        # Report the real parameter path (not the internal "param" marker) so that callers can
        # look the key up in the target state dict.
        initialized_key = key + "." + hf_param_name
    elif weight_type is not None:
        target = getattr(hf_pointer, weight_type)
        initialized_key = key + "." + weight_type
    else:
        # `key` already points at the parameter itself.
        target = hf_pointer
        initialized_key = key

    if target.shape != value.shape:
        raise ValueError(
            f"Shape of hf {initialized_key} is {target.shape}, but should be"
            f" {value.shape} for {full_name}"
        )

    target.data = value
    logger.info(f"{initialized_key} was initialized from {full_name}.")

    return initialized_key


def rename_dict(key, value, full_name, weight_type, hf_dict):
    """Store ``value`` in ``hf_dict`` under the target-side name derived from ``key``.

    Names ending in one of the ``PARAM_MAPPING`` keys are treated as adapter parameters and
    use the mapped attribute path as suffix; otherwise ``weight_type`` (when given) is the
    suffix. Entries whose final key contains ``"lm_head"`` keep ``value`` unchanged, all
    others store ``value[0]``.
    """
    hf_param_name = None
    if any(full_name.endswith(suffix) for suffix in PARAM_MAPPING):
        hf_param_name = PARAM_MAPPING[full_name.split(".")[-1]]
        weight_type = "param"

    if weight_type is None:
        full_key = key
    elif weight_type == "param":
        full_key = ".".join([key, hf_param_name])
    else:
        full_key = ".".join([key, weight_type])

    if "lm_head" in full_key:
        hf_dict[full_key] = value
    else:
        hf_dict[full_key] = value[0]


def replace_int_with_asterisk(input_string):
    """Return ``input_string`` with every ``.<digits>.`` path component replaced by ``.*.``.

    Matches do not overlap, so of two adjacent integer components (``a.1.2.b``) only the
    first is replaced.
    """
    return re.sub(r"\.\d+\.", ".*.", input_string)


def get_layer_id(string: str) -> str:
    """Extract the layer index from a name containing ``encoder.transformer.layers.<N>.``.

    Returns the index as a string of digits, or ``None`` when the pattern does not occur
    (note that the return annotation does not reflect the ``None`` case).
    """
    found = re.search(r"encoder\.transformer\.layers\.(\d+)\.", string)
    return found.group(1) if found else None


# Last dotted component of a source weight name -> attribute path below the mapped target module.
# `set_recursively` and `rename_dict` treat any source name ending in one of these keys as a
# "param"-type weight and resolve the mapped path instead of a plain weight/bias attribute.
PARAM_MAPPING = {
    "W_a": "linear_1.weight",
    "W_b": "linear_2.weight",
    "b_a": "linear_1.bias",
    "b_b": "linear_2.bias",
    "ln_W": "norm.weight",
    "ln_b": "norm.bias",
}


def load_conv_layer(full_name, value, feature_extractor, unused_weights, uninitialized_weights, use_group_norm):
    """Copy one feature-extractor weight (conv or layer norm) into the target feature extractor.

    Args:
        full_name: Source weight name; the part after ``"conv_layers."`` must look like
            ``"<layer index>.<conv|layer_norm>.<weight|bias>"``.
        value: Source tensor to copy.
        feature_extractor: Target object exposing ``conv_layers[i].conv`` / ``.layer_norm``.
        unused_weights: List that receives ``full_name`` when the weight cannot be placed
            (unknown layer type, or neither a weight nor a bias).
        uninitialized_weights: Set of target state-dict keys still waiting to be filled, or
            ``None``. Keys ending in the path of the parameter written here are removed.
        use_group_norm: Unused; kept so existing callers keep working.

    Raises:
        ValueError: If the source and target shapes differ.
    """
    name = full_name.split("conv_layers.")[-1]
    items = name.split(".")
    layer_id = int(items[0])
    layer_type = items[1]

    if "bias" in name:
        param_name = "bias"
    elif "weight" in name:
        param_name = "weight"
    else:
        param_name = None

    if layer_type not in ("conv", "layer_norm") or param_name is None:
        # Record weights that cannot be placed instead of dropping them silently.
        unused_weights.append(full_name)
        return

    target = getattr(getattr(feature_extractor.conv_layers[layer_id], layer_type), param_name)
    if value.shape != target.data.shape:
        raise ValueError(
            f"{full_name} has size {value.shape}, but"
            f" {target.data.shape} was found."
        )
    target.data = value

    if uninitialized_weights is not None:
        # The target key may carry a model-level prefix (for example "wav2vec2."), so match on
        # the suffix below the feature extractor instead of assuming a particular prefix.
        hf_suffix = f"feature_extractor.conv_layers.{layer_id}.{layer_type}.{param_name}"
        filled = [hf_key for hf_key in uninitialized_weights if hf_key.endswith(hf_suffix)]
        uninitialized_weights.difference_update(filled)

    if layer_type == "conv":
        logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
    else:
        logger.info(f"Feat extract layer norm {param_name} of layer {layer_id} was initialized from {full_name}.")


def load_wav2vec2_layer(name, value, hf_model=None, hf_dict=None):
    """Find the target location for one source weight via ``MAPPING`` and copy it there.

    Args:
        name: Weight name in the source state dict.
        value: Source tensor.
        hf_model: Target model the mapped key is resolved against by ``set_recursively``.
        hf_dict: Unused; kept so existing callers keep working.

    Returns:
        ``(is_used, initialized_key)``: ``is_used`` is True when a ``MAPPING`` entry matched;
        ``initialized_key`` is whatever ``set_recursively`` returned, or ``None`` if nothing
        matched.

    Raises:
        ValueError: If a layer-wise ``MAPPING`` key matched but ``name`` carries no encoder
            layer index.
    """
    is_used = False
    # Both values are loop invariant, so compute them once.
    wildcard_name = replace_int_with_asterisk(name)
    name_head = name.split(".")[0]
    for key, mapped_key in MAPPING.items():
        mapped_key = "wav2vec2." + mapped_key if mapped_key not in TOP_LEVEL_KEYS else mapped_key
        if key in wildcard_name or key.split("encoder.")[-1] == name_head:
            is_used = True
            if "*" in key:
                layer_index = get_layer_id(name)
                # Check the extracted index (not `name`, which is always truthy here); a missing
                # index would otherwise surface as an opaque TypeError from str.replace.
                if layer_index is None:
                    raise ValueError(f"Name {name} matched with key {key} but contains no encoder layer index")
                mapped_key = mapped_key.replace("*", layer_index)
            elif "*" in mapped_key:
                # The layer index is the path component right before the matched fragment.
                layer_index = name.split(key)[0].split(".")[-2]
                mapped_key = mapped_key.replace("*", layer_index)
            if "weight_g" in name:
                weight_type = "original0"
            elif "weight_v" in name:
                weight_type = "original1"
            elif "bias" in name:
                weight_type = "bias"
            elif "weight" in name:
                # TODO: don't match quantizer.weight_proj
                weight_type = "weight"
            else:
                weight_type = None

            initialized_key = set_recursively(mapped_key, value, name, weight_type, hf_model)
            return is_used, initialized_key
    return is_used, None


def recursively_load_weights(alignment_model, hf_model):
    """Copy every weight of ``alignment_model`` into ``hf_model`` and log what was left over.

    Weights whose name contains ``"conv_layers"`` are handled by ``load_conv_layer``; all other
    weights go through ``load_wav2vec2_layer``. At the end two warnings are logged: the source
    weights that were not used and the target state-dict keys that were not reported as filled.

    Args:
        alignment_model: Source model; only its ``state_dict()`` is read.
        hf_model: Target model exposing ``state_dict()``, ``wav2vec2.feature_extractor`` and
            ``config.feat_extract_norm``.
    """
    unused_weights = []
    uninitialized_weights = set(hf_model.state_dict().keys())
    alignment_dict = alignment_model.state_dict()

    feature_extractor = hf_model.wav2vec2.feature_extractor
    use_group_norm = hf_model.config.feat_extract_norm == "group"

    for name, value in alignment_dict.items():
        if "conv_layers" in name:
            # load_conv_layer records weights it cannot place in `unused_weights` itself.
            load_conv_layer(
                name,
                value,
                feature_extractor,
                unused_weights,
                uninitialized_weights,
                use_group_norm,
            )
            continue

        # The fourth parameter of load_wav2vec2_layer (`hf_dict`) is unrelated to the pending-key
        # set, so it is no longer passed there.
        is_used, initialized_key = load_wav2vec2_layer(name, value, hf_model)
        if not is_used:
            unused_weights.append(name)
        elif initialized_key:
            if initialized_key in uninitialized_weights:
                uninitialized_weights.remove(initialized_key)
            else:
                # Report an unexpected key instead of aborting the whole conversion with KeyError.
                logger.warning(f"{initialized_key} (from {name}) is not a pending key of the target state dict.")

    logger.warning(f"Unused weights: {unused_weights}")

    logger.warning(f"Unintialized weights: {uninitialized_weights}")


@torch.no_grad()
def convert_wav2vec2_alignment_checkpoint(
    pytorch_dump_folder_path: str, config_path: str, save_pretrained: bool = False
):
    """
    Build the torchaudio alignment model, load its checkpoint and copy the weights into a
    freshly created ``Wav2Vec2ForAudioFrameClassification``.

    Args:
        pytorch_dump_folder_path: Folder passed to ``save_pretrained`` when saving is requested.
        config_path: Location handed to ``Wav2Vec2Config.from_pretrained``.
        save_pretrained: Whether to write the converted model to ``pytorch_dump_folder_path``.

    Returns:
        The converted ``Wav2Vec2ForAudioFrameClassification`` instance.
    """

    hf_wav2vec = Wav2Vec2ForAudioFrameClassification(Wav2Vec2Config.from_pretrained(config_path))

    # (out_channels, kernel_size, stride) of the seven feature-extractor convolutions.
    conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2

    # Architecture of the source torchaudio model, spelled out as keyword arguments.
    torchaudio_kwargs = dict(
        extractor_mode="layer_norm",
        extractor_conv_layer_config=conv_layer_config,
        extractor_conv_bias=True,
        encoder_embed_dim=1024,
        encoder_projection_dropout=0.0,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=24,
        encoder_num_heads=16,
        encoder_attention_dropout=0.0,
        encoder_ff_interm_features=4096,
        encoder_ff_interm_dropout=0.1,
        encoder_dropout=0.0,
        encoder_layer_norm_first=True,
        encoder_layer_drop=0.1,
        aux_num_out=31,
    )
    model = wav2vec2_model(**torchaudio_kwargs)
    model.load_state_dict(check_or_download_model_weights())
    model = model.eval()

    recursively_load_weights(model, hf_wav2vec)

    if save_pretrained:
        hf_wav2vec.save_pretrained(pytorch_dump_folder_path)

    return hf_wav2vec

In [11]:
# Use the same absolute locations as the cell that writes the config file and the cell that
# later reloads the converted checkpoint, so this cell does not depend on the working directory.
config_path = "/kaggle/working/wav2vec2_alignment_config.json"
pytorch_model_path = "/kaggle/working/torch_mms_alignment_model"
convert_wav2vec2_alignment_checkpoint(pytorch_model_path, config_path, save_pretrained = True)

[32m2023-08-14 06:38:03.585[0m | [1mINFO    [0m | [36m__main__[0m:[36mcheck_or_download_model_weights[0m:[36m54[0m - [1mDownloading model and dictionary...[0m
[32m2023-08-14 06:38:03.587[0m | [1mINFO    [0m | [36m__main__[0m:[36mcheck_or_download_model_weights[0m:[36m56[0m - [1mModel path already exists. Skipping downloading....[0m
[32m2023-08-14 06:38:05.083[0m | [1mINFO    [0m | [36m__main__[0m:[36mload_conv_layer[0m:[36m221[0m - [1mFeat extract layer norm weight of layer 0 was initialized from feature_extractor.conv_layers.0.layer_norm.weight.[0m
[32m2023-08-14 06:38:05.085[0m | [1mINFO    [0m | [36m__main__[0m:[36mload_conv_layer[0m:[36m213[0m - [1mFeat extract layer norm weight of layer 0 was initialized from feature_extractor.conv_layers.0.layer_norm.bias.[0m
[32m2023-08-14 06:38:05.087[0m | [1mINFO    [0m | [36m__main__[0m:[36mload_conv_layer[0m:[36m204[0m - [1mFeat extract conv layer 0 was initialized from feature_extra

Wav2Vec2ForAudioFrameClassification(
  (wav2vec2): Wav2Vec2Model(
    (feature_extractor): Wav2Vec2FeatureEncoder(
      (conv_layers): ModuleList(
        (0): Wav2Vec2LayerNormConvLayer(
          (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,))
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation): GELUActivation()
        )
        (1-4): 4 x Wav2Vec2LayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation): GELUActivation()
        )
        (5-6): 2 x Wav2Vec2LayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,))
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation): GELUActivation()
        )
      )
    )
    (feature_projection): Wav2Vec2FeatureProjection(
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affin

In [12]:
from typing import Optional, Union

import jax
import jax.numpy as jnp
import flax.linen as nn

from transformers.modeling_flax_outputs import FlaxCausalLMOutput
from transformers.models.wav2vec2.configuration_wav2vec2 import Wav2Vec2Config
from transformers.models.wav2vec2.modeling_flax_wav2vec2 import (
    FlaxWav2Vec2FeatureEncoder,
    FlaxWav2Vec2FeatureProjection,
    FlaxWav2Vec2StableLayerNormEncoder,
    FlaxWav2Vec2Adapter,
    FlaxWav2Vec2PreTrainedModel,
    FlaxWav2Vec2BaseModelOutput,
)


class FlaxWav2Vec2Module(nn.Module):
    """Flax module chaining a feature encoder, a feature projection and a stable-layer-norm encoder.

    Only configurations with ``do_stable_layer_norm=True`` are supported; ``setup`` raises
    ``NotImplementedError`` otherwise. An adapter is applied after the encoder when
    ``config.add_adapter`` is set.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Create the sub-modules (and the optional mask embedding) from ``self.config``."""
        self.feature_extractor = FlaxWav2Vec2FeatureEncoder(self.config, dtype=self.dtype)
        self.feature_projection = FlaxWav2Vec2FeatureProjection(self.config, dtype=self.dtype)
        # The mask embedding parameter only exists when one of the mask probabilities is > 0.
        if self.config.mask_time_prob > 0.0 or self.config.mask_feature_prob > 0.0:
            self.masked_spec_embed = self.param(
                "masked_spec_embed", jax.nn.initializers.uniform(), (self.config.hidden_size,)
            )

        if self.config.do_stable_layer_norm:
            self.encoder = FlaxWav2Vec2StableLayerNormEncoder(self.config, dtype=self.dtype)
        else:
            raise NotImplementedError("``config.do_stable_layer_norm is False`` is currently not supported.")

        self.adapter = FlaxWav2Vec2Adapter(self.config, dtype=self.dtype) if self.config.add_adapter else None

    def __call__(
        self,
        input_values,
        attention_mask=None,
        mask_time_indices=None,
        deterministic=True,
        output_attentions=None,
        output_hidden_states=None,
        freeze_feature_encoder=False,
        return_dict=None,
    ):
        """Run feature extraction, projection, the encoder and (if configured) the adapter.

        Returns a tuple ``(hidden_states, extract_features, *encoder_extras)`` when
        ``return_dict`` is falsy, otherwise a ``FlaxWav2Vec2BaseModelOutput``.
        """
        extract_features = self.feature_extractor(input_values, freeze_feature_encoder=freeze_feature_encoder)

        # make sure that no loss is computed on padded inputs
        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)
        # NOTE(review): self.masked_spec_embed is only created in setup when a mask probability
        # is > 0; confirm callers never pass mask_time_indices with both probabilities at 0.
        if mask_time_indices is not None:  # apply SpecAugment along time axis with given indices
            hidden_states = jnp.where(
                jnp.broadcast_to(mask_time_indices[:, :, None], hidden_states.shape),
                jnp.broadcast_to(self.masked_spec_embed[None, None, :], hidden_states.shape),
                hidden_states,
            )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)

        if not return_dict:
            return (hidden_states, extract_features) + encoder_outputs[1:]

        return FlaxWav2Vec2BaseModelOutput(
            last_hidden_state=hidden_states,
            extract_features=extract_features,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
    ):
        """
        Computes the output length of the convolutional layers

        ``add_adapter`` defaults to ``config.add_adapter``; when true, the length is further
        reduced once per adapter layer using kernel size 1 and ``config.adapter_stride``.
        """

        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return (input_length - kernel_size) // stride + 1

        # Apply the length formula once per configured convolution, in order.
        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        if add_adapter:
            for _ in range(self.config.num_adapter_layers):
                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)

        return input_lengths

    def _get_feature_vector_attention_mask(
        self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
    ):
        """Build a boolean mask of shape ``(batch_size, feature_vector_length)`` for the features.

        For each row, positions up to and including ``output_length - 1`` are True, where
        ``output_length`` is derived from the number of non-padded input positions.
        """
        # Effectively attention_mask.sum(-1), but not inplace to be able to run
        # on inference mode.
        non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]

        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)

        batch_size = attention_mask.shape[0]

        attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
        # these two operations makes sure that all values
        # before the output lengths indices are attended to
        attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
        attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
        return attention_mask


class FlaxWav2Vec2Model(FlaxWav2Vec2PreTrainedModel):
    """Pretrained-model wrapper whose module class is the notebook's ``FlaxWav2Vec2Module``."""

    module_class = FlaxWav2Vec2Module


class FlaxWav2Vec2ForAudioFrameClassificationModule(nn.Module):
    """``FlaxWav2Vec2Module`` backbone followed by a dense classifier applied to every frame."""

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Create the backbone and the ``config.num_labels``-way dense head."""
        self.wav2vec2 = FlaxWav2Vec2Module(self.config, dtype=self.dtype)
        head_init = jax.nn.initializers.normal(self.config.initializer_range)
        self.classifier = nn.Dense(self.config.num_labels, kernel_init=head_init, dtype=self.dtype)

    def __call__(
        self,
        input_values,
        attention_mask=None,
        mask_time_indices=None,
        deterministic=True,
        output_attentions=None,
        output_hidden_states=None,
        freeze_feature_encoder=False,
        return_dict=None,
    ):
        """Return per-frame logits, as a tuple or as a ``FlaxCausalLMOutput`` if ``return_dict``."""
        backbone_outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            mask_time_indices=mask_time_indices,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            freeze_feature_encoder=freeze_feature_encoder,
            return_dict=return_dict,
        )

        # The first backbone output holds the final hidden states.
        logits = self.classifier(backbone_outputs[0])

        if return_dict:
            return FlaxCausalLMOutput(
                logits=logits,
                hidden_states=backbone_outputs.hidden_states,
                attentions=backbone_outputs.attentions,
            )

        # Tuple form: drop the backbone's first two entries and put the logits in front.
        return (logits,) + backbone_outputs[2:]


class FlaxWav2Vec2ForAudioFrameClassification(FlaxWav2Vec2PreTrainedModel):
    """Pretrained-model wrapper around ``FlaxWav2Vec2ForAudioFrameClassificationModule``."""

    module_class = FlaxWav2Vec2ForAudioFrameClassificationModule

In [14]:
# Load the converted PyTorch checkpoint saved by the conversion cell into the Flax model class
# defined above (from_pt=True requests loading from PyTorch weights).
model = FlaxWav2Vec2ForAudioFrameClassification.from_pretrained("/kaggle/working/torch_mms_alignment_model", from_pt=True)

In [18]:
# Write the Flax model to a folder relative to the current working directory.
model.save_pretrained("flax_mms_alignment_model")