File size: 1,042 Bytes
c3a8a48 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import copy
import functools

from transformers import LlamaConfig
class GRetrieverConfig(LlamaConfig):
    """Llama configuration extended with G-Retriever-specific settings.

    The Llama hyperparameters do not come from ``LlamaConfig`` defaults: they
    are loaded from the checkpoint named by ``pretrained_base_model`` via
    ``LlamaConfig.from_pretrained`` (which may contact the Hugging Face Hub),
    then overridden by any ``**kwargs``. The extra G-Retriever fields are
    stored as plain instance attributes next to the Llama fields.

    Args:
        max_txt_len: Stored as ``self.max_txt_len``; presumably the maximum
            prompt length in tokens -- interpreted by the model code.
        max_new_tokens: Stored as ``self.max_new_tokens``; presumably the
            generation length cap -- interpreted by the model code.
        gnn_num_layers, gnn_in_dim, gnn_hidden_dim, gnn_num_heads,
        gnn_dropout: Graph-encoder hyperparameters, stored as same-named
            attributes and consumed by the model code (not by this class).
        bos_id: Token ids stored as ``self.bos_id``. When omitted, a fresh
            ``[128000, 128000, 128006, 882, 128007]`` list is used
            (presumably a Llama-3 chat-template user-turn prefix -- confirm
            against the tokenizer; note 128000 appears twice).
        **kwargs: Llama config fields that override the base checkpoint's
            values. An explicit non-``None`` ``pad_token_id`` is honoured;
            otherwise ``pad_token_id`` is set to the merged ``eos_token_id``.
    """

    # Reuses the stock "llama" model_type string. NOTE(review): this
    # presumably makes AutoConfig resolve saved configs to plain LlamaConfig
    # -- confirm that is intended.
    model_type = "llama"

    # Hub id (or local path) of the checkpoint that supplies the Llama
    # hyperparameters. It is a class attribute rather than an __init__
    # argument, so it is not written into saved configs and subclasses can
    # override it.
    pretrained_base_model = "NousResearch/Hermes-3-Llama-3.1-8B"

    def __init__(
        self,
        max_txt_len: int = 1024,
        max_new_tokens: int = 256,
        gnn_num_layers: int = 4,
        gnn_in_dim: int = 768,
        gnn_hidden_dim: int = 1024,
        gnn_num_heads: int = 4,
        gnn_dropout: float = 0,
        bos_id: "list[int] | None" = None,
        **kwargs,
    ):
        # Deep-copy the memoized base config so per-instance overrides never
        # leak into the cache or into other instances.
        pretrained_config = copy.deepcopy(
            self._load_base_config(self.pretrained_base_model)
        )
        # Caller-supplied Llama fields take precedence over the checkpoint.
        pretrained_config.update(kwargs)

        self.max_txt_len = max_txt_len
        self.max_new_tokens = max_new_tokens
        self.gnn_num_layers = gnn_num_layers
        self.gnn_in_dim = gnn_in_dim
        self.gnn_hidden_dim = gnn_hidden_dim
        self.gnn_num_heads = gnn_num_heads
        self.gnn_dropout = gnn_dropout
        # Build a new list per instance. A shared mutable default would let
        # one config's in-place edit change every later config's default.
        self.bos_id = (
            [128000, 128000, 128006, 882, 128007] if bos_id is None else bos_id
        )

        super().__init__(**pretrained_config.to_dict())

        # Pad with EOS unless the caller explicitly chose a pad token. An
        # explicit one already reached super().__init__ through kwargs.
        # NOTE(review): if the checkpoint's eos_token_id is a list, pad_token_id
        # becomes a list too -- confirm it is a single int.
        if kwargs.get("pad_token_id") is None:
            self.pad_token_id = pretrained_config.eos_token_id

    @staticmethod
    @functools.lru_cache(maxsize=8)
    def _load_base_config(name_or_path: str) -> LlamaConfig:
        """Load and memoize the base ``LlamaConfig`` for *name_or_path*.

        transformers constructs throwaway instances of config classes (for
        example, when diffing against defaults), so without caching every
        construction would repeat the hub/disk lookup. The returned object
        is shared: callers must copy it before mutating.
        """
        return LlamaConfig.from_pretrained(name_or_path)