import torch
import torch.utils.checkpoint
from torch import nn

from transformers.models.whisper.configuration_whisper import WhisperConfig
from transformers.models.whisper.modeling_whisper import (
    WhisperEncoderLayer,
    WhisperEncoder,
    WhisperModel,
    WhisperForConditionalGeneration,
)

from .configuration_lite_whisper import LiteWhisperConfig

class LinearLowRank(nn.Module):
    """Low-rank replacement for ``nn.Linear``: the weight is factorized as ``weight1 @ weight2``."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        low_rank_features: int,
    ):
        super().__init__()

        # Store an (in, r) and an (r, out) factor instead of a full (in, out) matrix.
        # The random init is a placeholder; real values are expected to come from a checkpoint.
        self.weight1 = nn.Parameter(torch.randn(in_features, low_rank_features))
        self.weight2 = nn.Parameter(torch.randn(low_rank_features, out_features))
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Project down to the low-rank space, then back up, and add the bias.
        return (x @ self.weight1) @ self.weight2 + self.bias
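
# Illustrative sketch (numbers are made up, not from the original source): a 384-dim
# projection compressed to rank 64 stores 2 * 384 * 64 = 49,152 weights (plus the bias)
# instead of 384 * 384 = 147,456, at the cost of one extra matmul:
#
#   proj = LinearLowRank(384, 384, 64)
#   y = proj(torch.randn(2, 100, 384))   # -> shape (2, 100, 384)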

class LiteWhisperEncoderLayer(WhisperEncoderLayer):
    def __init__(self, config: WhisperConfig, low_rank_config: dict[str, int]):
        super().__init__(config)

        # Swap in a low-rank projection only for the sub-modules named in
        # low_rank_config; the mapped value is the rank used for that projection.
        if "k_proj" in low_rank_config:
            self.self_attn.k_proj = LinearLowRank(self.embed_dim, self.embed_dim, low_rank_config["k_proj"])

        if "v_proj" in low_rank_config:
            self.self_attn.v_proj = LinearLowRank(self.embed_dim, self.embed_dim, low_rank_config["v_proj"])

        if "q_proj" in low_rank_config:
            self.self_attn.q_proj = LinearLowRank(self.embed_dim, self.embed_dim, low_rank_config["q_proj"])

        if "out_proj" in low_rank_config:
            self.self_attn.out_proj = LinearLowRank(self.embed_dim, self.embed_dim, low_rank_config["out_proj"])

        if "fc1" in low_rank_config:
            self.fc1 = LinearLowRank(self.embed_dim, config.encoder_ffn_dim, low_rank_config["fc1"])

        if "fc2" in low_rank_config:
            self.fc2 = LinearLowRank(config.encoder_ffn_dim, self.embed_dim, low_rank_config["fc2"])

class LiteWhisperEncoder(WhisperEncoder):
    def __init__(self, config: WhisperConfig, low_rank_config: list[dict[str, int]]):
        super().__init__(config)

        # One low-rank spec per encoder layer.
        self.layers = nn.ModuleList([
            LiteWhisperEncoderLayer(config, low_rank_config[i])
            for i in range(config.encoder_layers)
        ])
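
# Illustrative sketch (ranks are made up, not from the original source): low_rank_config
# is one dict per encoder layer, mapping a sub-module name to the rank it should use,
# so different layers can be compressed differently:
#
#   low_rank_config = [{"fc1": 96, "fc2": 96}] * 2 + [{"q_proj": 64, "fc1": 96, "fc2": 96}] * 2
#   encoder = LiteWhisperEncoder(WhisperConfig(encoder_layers=4), low_rank_config)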

class LiteWhisperModel(WhisperModel):
    def __init__(self, config: WhisperConfig, low_rank_config: list[dict[str, int]]):
        super().__init__(config)

        # Only the encoder is replaced; the decoder stays a standard Whisper decoder.
        self.encoder = LiteWhisperEncoder(config, low_rank_config)

class LiteWhisperForConditionalGeneration(WhisperForConditionalGeneration):
    config_class = LiteWhisperConfig

    def __init__(self, config: LiteWhisperConfig):
        low_rank_config = getattr(config, "low_rank_config", None)
        if low_rank_config is None:
            # If the config carries no low-rank spec, keep every encoder layer full-rank.
            low_rank_config = [{} for _ in range(config.encoder_layers)]

        super().__init__(config)
        self.model = LiteWhisperModel(config, low_rank_config)
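
# Usage sketch (hedged; the path below is a placeholder, not a real checkpoint id):
# calling this class directly works like any other transformers model, while loading it
# through the Auto* classes from a hub repo that ships this file typically requires
# trust_remote_code=True so the custom code is picked up.
#
#   model = LiteWhisperForConditionalGeneration.from_pretrained("path/to/lite-whisper-checkpoint")
#   model.eval()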