from transformers import LlamaConfig


class GRetrieverConfig(LlamaConfig):
    """Configuration for a G-Retriever-style model: a Llama backbone extended
    with GNN hyperparameters plus text/generation and prompt settings."""

    model_type = "llama"

    def __init__(
        self,
        max_txt_len: int = 1024,
        max_new_tokens: int = 256,
        gnn_num_layers: int = 4,
        gnn_in_dim: int = 768,
        gnn_hidden_dim: int = 1024,
        gnn_num_heads: int = 4,
        gnn_dropout: float = 0.0,
        bos_id: list = [128000, 128000, 128006, 882, 128007],
        **kwargs,
    ):
        # Start from the pretrained Hermes-3-Llama-3.1-8B config and let any
        # extra keyword arguments override its fields.
        pretrained_config = LlamaConfig.from_pretrained("NousResearch/Hermes-3-Llama-3.1-8B")
        pretrained_config.update(kwargs)

        # Text length and generation limits.
        self.max_txt_len = max_txt_len
        self.max_new_tokens = max_new_tokens

        # GNN encoder hyperparameters.
        self.gnn_num_layers = gnn_num_layers
        self.gnn_in_dim = gnn_in_dim
        self.gnn_hidden_dim = gnn_hidden_dim
        self.gnn_num_heads = gnn_num_heads
        self.gnn_dropout = gnn_dropout

        # Token ids that open the chat prompt (Llama 3 prompt header).
        self.bos_id = bos_id

        super().__init__(**pretrained_config.to_dict())
        # Llama has no dedicated pad token, so reuse the EOS token for padding.
        self.pad_token_id = pretrained_config.eos_token_id
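

if __name__ == "__main__":
    # Minimal usage sketch showing how the config is constructed and which
    # fields it exposes. Constructing it fetches the base
    # Hermes-3-Llama-3.1-8B config from the Hugging Face Hub (or local cache),
    # so this assumes that config is reachable; the override values below are
    # illustrative only.
    config = GRetrieverConfig(gnn_num_layers=2, max_new_tokens=128)
    print(config.model_type)                            # "llama"
    print(config.gnn_hidden_dim)                        # 1024
    print(config.pad_token_id == config.eos_token_id)   # True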