File size: 6,261 Bytes
2ab806d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
import dataclasses
from enum import Enum
from typing import Any, Dict, List, Optional, Union
import transformers
@dataclasses.dataclass
class LoraConfigSimplified:
    """
    Simplified Low Rank Approximation (LoRA) settings.

    The language model and the audio model each get their own instance, so the
    two can be configured separately.
    """

    # Rank of the low-rank approximation.
    r: int = 0
    # LoRA alpha value.
    lora_alpha: float = 8
    # Module names LoRA is applied to (presumably matched against sub-module
    # names of the wrapped model -- confirm against the code consuming this).
    # A factory is used so every instance owns its own list.
    target_modules: Optional[List[str]] = dataclasses.field(
        default_factory=lambda: list(("k_proj", "q_proj", "linear_k", "linear_q"))
    )
class LossFunction(str, Enum):
    """
    Identifiers of the supported loss functions.

    Members are also `str` instances, so they compare equal to their short
    string values.
    """

    # Cross-entropy loss.
    CrossEntropy = "ce"
    # KL-divergence loss.
    KL_Divergence = "kl"
@dataclasses.dataclass
class LossConfig:
    """Selects the loss function and carries its KL temperature setting."""

    # Loss to use; KL divergence unless overridden.
    loss_function: LossFunction = LossFunction.KL_Divergence
    # Temperature value kept alongside the loss choice (presumably only used
    # by the KL-divergence loss -- confirm against the loss implementation).
    kl_temperature: float = 2.0

    @property
    def requires_alt_fields(self):
        """Return True when the configured loss is KL divergence."""
        # `==` rather than `is`: a plain "kl" string compares equal to the
        # str-based enum member as well.
        uses_kl_divergence = self.loss_function == LossFunction.KL_Divergence
        return uses_kl_divergence
class BahasaConfig(transformers.PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`BahasaForConditionalGeneration`]. It is used to
    instantiate a Bahasa model according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        audio_config (`dict`, *optional*):
            Keyword arguments for the audio encoder config. The config class is looked up in
            `transformers.CONFIG_MAPPING` through the dict's `"model_type"` entry (`"wav2vec2"` when absent).
            Ignored when `audio_model_id` is given.
        _text_config (`dict`, *optional*):
            Keyword arguments for the text backbone config. The config class is looked up in
            `transformers.CONFIG_MAPPING` through the dict's `"model_type"` entry (`"llama"` when absent).
            Ignored when `text_model_id` is given.
        audio_model_id (`str`, *optional*):
            When set, the audio config is loaded with `AutoConfig.from_pretrained(audio_model_id)`.
        text_model_id (`str`, *optional*):
            When set, the text config is loaded with `AutoConfig.from_pretrained(text_model_id)`.
        ignore_index (`int`, *optional*, defaults to -100):
            The ignore index for the loss function.
        hidden_size (`int`, *optional*, defaults to 4096):
            Hidden size stored on the config.
        stack_factor (`int`, *optional*, defaults to 8):
            Audio downsampling factor for the multimodal projector.
        norm_init (`float`, *optional*, defaults to 0.4):
            The initialization value for the layer normalization.
        projector_act (`str`, *optional*, defaults to `"swiglu"`):
            The activation function used by the multimodal projector.
        text_model_lora_config (`LoraConfigSimplified` or `dict`, *optional*):
            The LoRA configuration for finetuning the text model. Stored as a plain dict.
        audio_model_lora_config (`LoraConfigSimplified` or `dict`, *optional*):
            The LoRA configuration for finetuning the audio model. Stored as a plain dict.
        **kwargs:
            Forwarded unchanged to `PretrainedConfig.__init__`.

    Example:

    ```python
    >>> from transformers import LlamaConfig, Wav2Vec2Config
    >>> # BahasaConfig and BahasaForConditionalGeneration come from this project's own modules.

    >>> # Initializing an audio encoder config
    >>> audio_config = Wav2Vec2Config()

    >>> # Initializing a Llama config
    >>> text_config = LlamaConfig()

    >>> # Initializing a configuration (the sub-configs are passed as dicts)
    >>> configuration = BahasaConfig(audio_config.to_dict(), text_config.to_dict())

    >>> # Initializing a completely untrained model from the configuration
    >>> model = BahasaForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config

    >>> # Initialize a model from pretrained checkpoints and random projector weights
    >>> config = BahasaConfig(audio_model_id="facebook/wav2vec2-base-960h", text_model_id="meta-llama/Llama-2-7b-chat-hf")
    ```"""

    model_type = "bahasa"
    is_composition = False

    def __init__(
        self,
        audio_config: Optional[Dict[str, Any]] = None,
        _text_config: Optional[Dict[str, Any]] = None,
        audio_model_id: Optional[str] = None,
        text_model_id: Optional[str] = None,
        ignore_index: int = -100,
        hidden_size: int = 4096,
        stack_factor: int = 8,
        norm_init: float = 0.4,
        projector_act: str = "swiglu",
        text_model_lora_config: Optional[LoraConfigSimplified] = None,
        audio_model_lora_config: Optional[LoraConfigSimplified] = None,
        **kwargs,
    ):
        self.ignore_index = ignore_index

        self.audio_model_id = audio_model_id
        self.text_model_id = text_model_id

        self.hidden_size = hidden_size
        self.stack_factor = stack_factor
        self.norm_init = norm_init
        self.projector_act = projector_act

        # A model id takes precedence over an inline config dict.
        if text_model_id is not None:
            self._text_config: Union[
                transformers.LlamaConfig, transformers.MllamaConfig
            ] = transformers.AutoConfig.from_pretrained(text_model_id)
        else:
            _text_config = _text_config or {}
            self._text_config = transformers.CONFIG_MAPPING[
                _text_config.get("model_type", "llama")
            ](**_text_config)

        if audio_model_id is not None:
            self.audio_config: transformers.PretrainedConfig = (
                transformers.AutoConfig.from_pretrained(audio_model_id)
            )
        else:
            audio_config = audio_config or {}
            self.audio_config = transformers.CONFIG_MAPPING[
                audio_config.get("model_type", "wav2vec2")
            ](**audio_config)

        # LoRA settings are always kept as plain dicts: dict inputs are stored
        # as given, dataclass inputs (or the defaults) go through asdict().
        self.text_model_lora_config = (
            text_model_lora_config
            if isinstance(text_model_lora_config, dict)
            else dataclasses.asdict(text_model_lora_config or LoraConfigSimplified())
        )
        self.audio_model_lora_config = (
            audio_model_lora_config
            if isinstance(audio_model_lora_config, dict)
            else dataclasses.asdict(audio_model_lora_config or LoraConfigSimplified())
        )

        # Mirror these two values from the (possibly nested) text config.
        self.vocab_size = self.text_config.vocab_size
        self.initializer_range = self.text_config.initializer_range

        super().__init__(**kwargs)

    @property
    def text_config(self):
        """Return the text backbone config, unwrapping `MllamaConfig.text_config`."""
        if isinstance(self._text_config, transformers.MllamaConfig):
            return self._text_config.text_config
        return self._text_config

    def to_diff_dict(self) -> Dict[str, Any]:
        """
        Return the parent class's diff dict without sub-configs that a model id covers.

        When `text_model_id` / `audio_model_id` is set, `__init__` reloads the
        matching sub-config from that id, so the inline copy is dropped here.
        """
        diff_dict = super().to_diff_dict()

        if self.text_model_id is not None:
            # The backbone config is stored on `self._text_config`, so that is
            # the key it is serialized under; "text_config" is still popped in
            # case a dict carries the public name.
            diff_dict.pop("_text_config", None)
            diff_dict.pop("text_config", None)
        if self.audio_model_id is not None:
            diff_dict.pop("audio_config", None)

        return diff_dict