import torch
import torch.utils.checkpoint
from torch import nn

from transformers.models.whisper.configuration_whisper import WhisperConfig
from transformers.models.whisper.modeling_whisper import (
    WhisperEncoderLayer,
    WhisperEncoder,
    WhisperModel,
    WhisperForConditionalGeneration,
)

from .configuration_lite_whisper import LiteWhisperConfig


class LinearLowRank(nn.Module):
    """Low-rank replacement for a dense linear layer: y = (x @ W1) @ W2 + b.

    The dense (in_features x out_features) weight is factored into
    W1 (in_features x r) and W2 (r x out_features), which saves parameters
    and FLOPs whenever r * (in_features + out_features) < in_features * out_features.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        low_rank_features: int,
    ):
        super().__init__()

        # Random init is a placeholder; the factors are overwritten by the
        # compressed weights when a pretrained checkpoint is loaded.
        self.weight1 = nn.Parameter(torch.randn(in_features, low_rank_features))
        self.weight2 = nn.Parameter(torch.randn(low_rank_features, out_features))
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Project down to the low-rank space, back up, then add the bias.
        return (x @ self.weight1) @ self.weight2 + self.bias


class LiteWhisperEncoderLayer(WhisperEncoderLayer):
    """Whisper encoder layer whose linear sublayers may be low-rank factorized.

    `low_rank_config` maps a sublayer name (e.g. "k_proj") to the rank used for
    its factorization; sublayers absent from the dict keep the dense nn.Linear
    built by the parent class.
    """

    def __init__(self, config: WhisperConfig, low_rank_config: dict[str, int]):
        super().__init__(config)

        # Self-attention projections: embed_dim -> embed_dim.
        if "k_proj" in low_rank_config:
            self.self_attn.k_proj = LinearLowRank(self.embed_dim, self.embed_dim, low_rank_config["k_proj"])

        if "v_proj" in low_rank_config:
            self.self_attn.v_proj = LinearLowRank(self.embed_dim, self.embed_dim, low_rank_config["v_proj"])

        if "q_proj" in low_rank_config:
            self.self_attn.q_proj = LinearLowRank(self.embed_dim, self.embed_dim, low_rank_config["q_proj"])

        if "out_proj" in low_rank_config:
            self.self_attn.out_proj = LinearLowRank(self.embed_dim, self.embed_dim, low_rank_config["out_proj"])

        # Feed-forward projections: embed_dim <-> encoder_ffn_dim.
        if "fc1" in low_rank_config:
            self.fc1 = LinearLowRank(self.embed_dim, config.encoder_ffn_dim, low_rank_config["fc1"])

        if "fc2" in low_rank_config:
            self.fc2 = LinearLowRank(config.encoder_ffn_dim, self.embed_dim, low_rank_config["fc2"])


class LiteWhisperEncoder(WhisperEncoder):
    def __init__(self, config: WhisperConfig, low_rank_config: list[dict[str, int]]):
        super().__init__(config)

        # Rebuild the layer stack with a per-layer low-rank config, so each
        # layer can be compressed to different ranks (or left dense).
        self.layers = nn.ModuleList([
            LiteWhisperEncoderLayer(config, low_rank_config[i])
            for i in range(config.encoder_layers)
        ])


class LiteWhisperModel(WhisperModel):
    def __init__(self, config: WhisperConfig, low_rank_config: list[dict[str, int]]):
        super().__init__(config)

        # Only the encoder is compressed; the decoder stays a stock WhisperDecoder.
        self.encoder = LiteWhisperEncoder(config, low_rank_config)


class LiteWhisperForConditionalGeneration(WhisperForConditionalGeneration):
    config_class = LiteWhisperConfig

    def __init__(self, config: LiteWhisperConfig):
        low_rank_config = getattr(config, "low_rank_config", None)
        if low_rank_config is None:
            # No low-rank information in the config: keep every sublayer dense
            # (an empty dict per layer triggers no substitution above).
            low_rank_config = [{} for _ in range(config.encoder_layers)]

        super().__init__(config)
        self.model = LiteWhisperModel(config, low_rank_config)
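

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): one way to obtain the
# two factors of a LinearLowRank from an existing dense nn.Linear, using a
# truncated SVD of its weight. The helper name `factorize_linear` and the
# rank choice are assumptions for illustration only; the actual compression
# pipeline used to produce Lite-Whisper checkpoints may differ.
def factorize_linear(linear: nn.Linear, rank: int) -> LinearLowRank:
    # nn.Linear stores its weight as (out_features, in_features), while
    # LinearLowRank computes x @ weight1 @ weight2, so factorize the transpose.
    w = linear.weight.detach().T  # (in_features, out_features)
    u, s, vh = torch.linalg.svd(w, full_matrices=False)

    low_rank = LinearLowRank(linear.in_features, linear.out_features, rank)
    # Best rank-`rank` approximation: keep the top singular triplets and fold
    # the singular values into the first factor.
    low_rank.weight1.data = u[:, :rank] * s[:rank]  # (in_features, rank)
    low_rank.weight2.data = vh[:rank, :]            # (rank, out_features)
    if linear.bias is not None:
        low_rank.bias.data = linear.bias.detach().clone()
    return low_rank


# Example: factorize_linear(nn.Linear(1280, 1280), rank=320) yields a drop-in
# module with ~2x fewer weight parameters for that layer.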