from typing import Optional

from transformers import LlamaConfig


class GRetrieverConfig(LlamaConfig):
    """Configuration for a G-Retriever-style model: a Llama backbone plus a GNN encoder.

    Extends ``LlamaConfig`` with graph-encoder hyperparameters (``gnn_*``),
    generation limits (``max_txt_len``, ``max_new_tokens``) and the BOS token-id
    prefix used when building prompts. All remaining ``**kwargs`` are merged into
    the base model's pretrained config before being forwarded to ``LlamaConfig``.
    """

    model_type = "llama"

    def __init__(
        self,
        max_txt_len: int = 1024,
        max_new_tokens: int = 256,
        gnn_num_layers: int = 4,
        gnn_in_dim: int = 768,
        gnn_hidden_dim: int = 1024,
        gnn_num_heads: int = 4,
        gnn_dropout: float = 0.0,
        bos_id: Optional[list] = None,
        base_model: str = "NousResearch/Hermes-3-Llama-3.1-8B",
        **kwargs,
    ):
        # NOTE(review): this loads the base model's config from the Hugging Face
        # Hub (or local cache) on every construction — a network/disk dependency
        # inside __init__; confirm this is intended rather than passing the
        # resolved config in explicitly.
        pretrained_config = LlamaConfig.from_pretrained(base_model)
        # Caller-supplied kwargs override the pretrained config's values.
        pretrained_config.update(kwargs)
        self.max_txt_len = max_txt_len
        self.max_new_tokens = max_new_tokens
        self.gnn_num_layers = gnn_num_layers
        self.gnn_in_dim = gnn_in_dim
        self.gnn_hidden_dim = gnn_hidden_dim
        self.gnn_num_heads = gnn_num_heads
        self.gnn_dropout = gnn_dropout
        # None sentinel instead of a mutable default list (shared across calls).
        # Default ids presumably encode the Llama-3 chat-template BOS/user-header
        # prefix — TODO confirm against the tokenizer in use.
        self.bos_id = [128000, 128000, 128006, 882, 128007] if bos_id is None else bos_id
        super().__init__(**pretrained_config.to_dict())
        # Reuse EOS as the padding token id (the base config defines no pad token).
        self.pad_token_id = pretrained_config.eos_token_id