"""Moment model configuration"""

from typing import Optional

from transformers import PretrainedConfig, logging
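
# Module-level logger via the transformers logging utilities; used below
# when _validate_config resolves d_model from the backbone config.
logger = logging.get_logger(__name__)
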
# Default hyperparameters for the T5 backbone (Flan-T5-Large sized).
DEFAULT_T5_CONFIG = {
    "classifier_dropout": 0.0,
    "d_ff": 2816,
    "d_kv": 64,
    "d_model": 1024,
    "decoder_start_token_id": 0,
    "dense_act_fn": "gelu_new",
    "dropout_rate": 0.1,
    "eos_token_id": 1,
    "feed_forward_proj": "gated-gelu",
    "initializer_factor": 1.0,
    "is_encoder_decoder": False,
    "is_gated_act": True,
    "layer_norm_epsilon": 1e-06,
    "n_positions": 512,
    "num_decoder_layers": 24,
    "num_heads": 16,
    "num_layers": 24,
    "output_past": True,
    "pad_token_id": 0,
    "relative_attention_max_distance": 128,
    "relative_attention_num_buckets": 32,
    "tie_word_embeddings": False,
    "use_cache": False,
    "vocab_size": 32128,
}


class MomentConfig(PretrainedConfig):
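    """Configuration class for the Moment model.

    Stores the T5 backbone hyperparameters (``t5_config``) together with
    the patch-embedding, RevIN, masking, and parameter-freezing options
    used when building the model.
    """
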
    model_type = "moment"

    def __init__(
        self,
        # None (resolved in _init_t5_config) instead of a mutable dict
        # default, which would be shared across instances.
        t5_config: Optional[dict] = None,
        d_model: Optional[int] = None,
        seq_len: int = 512,
        patch_len: int = 16,
        patch_stride_len: int = 16,
        dropout: float = 0.1,
        revin_num_features: int = 1,
        revin_eps: float = 1e-5,
        revin_affine: bool = True,
        add_positional_embedding: bool = True,
        value_embedding_bias: bool = False,
        orth_gain: float = 1.41,
        mask_ratio: float = 0.15,
        freeze_embedder: bool = True,
        freeze_encoder: bool = True,
        freeze_head: bool = False,
        enable_gradient_checkpointing: bool = True,
        randomly_initialize_backbone: bool = False,
        **kwargs,
    ):
        # Merge the user-supplied backbone config over DEFAULT_T5_CONFIG.
        self.t5_config = self._init_t5_config(t5_config)
        self.d_model = d_model
        self.seq_len = seq_len
        self.patch_len = patch_len
        self.patch_stride_len = patch_stride_len
        self.dropout = dropout
        self.revin_num_features = revin_num_features
        self.revin_eps = revin_eps
        self.revin_affine = revin_affine
        self.add_positional_embedding = add_positional_embedding
        self.value_embedding_bias = value_embedding_bias
        self.orth_gain = orth_gain
        self.mask_ratio = mask_ratio
        self.freeze_embedder = freeze_embedder
        self.freeze_encoder = freeze_encoder
        self.freeze_head = freeze_head
        self.enable_gradient_checkpointing = enable_gradient_checkpointing
        self.randomly_initialize_backbone = randomly_initialize_backbone

        self._validate_config()

        super().__init__(**kwargs)

    def _init_t5_config(self, config: Optional[dict]) -> dict:
        """Merge a user-supplied T5 config over DEFAULT_T5_CONFIG."""
        if config is None:
            # Copy so callers cannot mutate the shared module-level defaults.
            return DEFAULT_T5_CONFIG.copy()
        updated_config = DEFAULT_T5_CONFIG.copy()
        updated_config.update(config)
        return updated_config

    def _validate_config(self):
        """Validate the configuration and resolve derived defaults."""
        if self.d_model is None:
            # Fall back to the backbone's hidden size when not set explicitly.
            self.d_model = self.t5_config["d_model"]
            logger.info(
                "d_model not set; defaulting to t5_config['d_model'] = %s",
                self.d_model,
            )
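

# Minimal usage sketch: default construction, and the t5_config merge with
# the d_model fallback handled in _validate_config.
if __name__ == "__main__":
    # With no arguments, d_model is resolved from the backbone defaults.
    config = MomentConfig()
    assert config.d_model == DEFAULT_T5_CONFIG["d_model"]  # 1024

    # User-supplied keys are merged over DEFAULT_T5_CONFIG; unspecified
    # keys keep their default values.
    small = MomentConfig(t5_config={"d_model": 512, "num_layers": 8})
    assert small.d_model == 512
    assert small.t5_config["vocab_size"] == 32128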