from typing import Optional, Union

from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.m2m_100.configuration_m2m_100 import M2M100Config

# Model-type identifier shared by the config class and its AutoConfig registration.
NLLBLLM2VEC_TYPE = "nllb-llm2vec"

# Default keyword arguments for the NLLB (M2M-100) sub-config.
DEFAULT_M2M100_CONFIG = {
    "activation_dropout": 0.0,
    "activation_function": "relu",
    "architectures": ["M2M100Encoder"],
    "attention_dropout": 0.1,
    "bos_token_id": 0,
    "d_model": 1024,
    "decoder_attention_heads": 16,
    "decoder_ffn_dim": 4096,
    "decoder_layerdrop": 0,
    "decoder_layers": 12,
    "decoder_start_token_id": 2,
    "dropout": 0.1,
    "encoder_attention_heads": 16,
    "encoder_ffn_dim": 4096,
    "encoder_layerdrop": 0,
    "encoder_layers": 12,
    "eos_token_id": 2,
    "init_std": 0.02,
    "is_encoder_decoder": True,
    "max_position_embeddings": 1024,
    "model_type": "m2m_100",
    "num_hidden_layers": 12,
    "pad_token_id": 1,
    "scale_embedding": True,
    "torch_dtype": "float32",
    "transformers_version": "4.21.0.dev0",
    "use_cache": True,
    "vocab_size": 256206,
    "tokenizer_class": "NllbTokenizer",
    "max_length": 200,
}

# Default keyword arguments for the LLM2Vec (Llama) sub-config.
DEFAULT_LLAMA_CONFIG = {
    "attention_bias": False,
    "attention_dropout": 0,
    "bos_token_id": 128000,
    "eos_token_id": 128001,
    "hidden_act": "silu",
    "hidden_size": 4096,
    "initializer_range": 0.02,
    "intermediate_size": 14336,
    "max_position_embeddings": 8192,
    "model_type": "llama",
    "num_attention_heads": 32,
    "num_hidden_layers": 32,
    "num_key_value_heads": 8,
    "pretraining_tp": 1,
    "rms_norm_eps": 0.00001,
    "rope_scaling": None,
    "rope_theta": 500000,
    "tie_word_embeddings": False,
    "torch_dtype": "bfloat16",
    "transformers_version": "4.40.0.dev0",
    "use_cache": False,
    "vocab_size": 128256,
}


def _build_sub_config(config, config_cls, default):
    """Build a fresh ``config_cls`` from ``None``, a mapping, or a config object.

    Args:
        config: ``None`` (use ``default``), a mapping of keyword arguments, or
            any ``PretrainedConfig`` instance (converted via ``to_dict()``).
        config_cls: The ``PretrainedConfig`` subclass to instantiate.
        default: Keyword arguments used when ``config`` is ``None``.

    Returns:
        A new ``config_cls`` instance. A passed-in config object is never
        aliased, so later edits to it do not leak into the composite config.
    """
    if config is None:
        return config_cls(**default)
    if isinstance(config, PretrainedConfig):
        # ``**`` cannot unpack a config object; go through its dict form.
        return config_cls(**config.to_dict())
    return config_cls(**config)


class NLLBLLM2VecConfig(PretrainedConfig):
    """Composite config holding an ``M2M100Config`` and a ``LlamaConfig``.

    The NLLB sub-config is stored as ``self.nllb_config`` and the LLM2Vec
    sub-config as ``self.llm2vec_config``. Any extra keyword arguments are
    forwarded to ``PretrainedConfig.__init__``.

    Args:
        nllb_config: Keyword arguments for ``M2M100Config`` or an existing
            config object. ``None`` uses ``DEFAULT_M2M100_CONFIG``.
        llm2vec_config: Keyword arguments for ``LlamaConfig`` or an existing
            config object. ``None`` uses ``DEFAULT_LLAMA_CONFIG``.
        **kwargs: Passed through to ``PretrainedConfig``.
    """

    # Reuse the module constant so this can never drift from the
    # AutoConfig.register() call, which requires the two to match.
    model_type = NLLBLLM2VEC_TYPE
    # Both sub-configs have defaults, so NLLBLLM2VecConfig() is constructible
    # without arguments.
    is_composition = False

    def __init__(
        self,
        nllb_config: Optional[Union[dict, PretrainedConfig]] = None,
        llm2vec_config: Optional[Union[dict, PretrainedConfig]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.nllb_config = _build_sub_config(
            nllb_config, M2M100Config, DEFAULT_M2M100_CONFIG
        )
        self.llm2vec_config = _build_sub_config(
            llm2vec_config, LlamaConfig, DEFAULT_LLAMA_CONFIG
        )
# Import-time registration: map the NLLBLLM2VEC_TYPE model type to
# NLLBLLM2VecConfig so AutoConfig can resolve configs that declare it.
AutoConfig.register(NLLBLLM2VEC_TYPE, NLLBLLM2VecConfig)
# Tag the class for the AutoConfig auto-class (the method's default argument).
# Per the transformers custom-model docs, this lets save_pretrained record the
# class in the config's "auto_map" for trust_remote_code loading.
NLLBLLM2VecConfig.register_for_auto_class()