# g-retriever-resume-reviewer / g_retriever_config.py
# Uploaded by alfiannajih; commit c3a8a48 ("Upload config", verified); 1.04 kB.
# (Header text scraped from the Hugging Face Hub file page: "raw",
#  "history blame", "No virus".)
import copy
import functools

from transformers import LlamaConfig
class GRetrieverConfig(LlamaConfig):
    """Llama configuration extended with G-Retriever (GNN + LLM) settings.

    The Llama hyper-parameters are seeded from the config published at
    ``_BASE_MODEL_ID`` on the Hugging Face Hub, then overridden by any extra
    keyword arguments. A config saved by this class therefore reloads with
    its stored values taking precedence over the base ones.

    Args:
        max_txt_len: Stored as ``self.max_txt_len``.
        max_new_tokens: Stored as ``self.max_new_tokens``.
        gnn_num_layers: Stored as ``self.gnn_num_layers``.
        gnn_in_dim: Stored as ``self.gnn_in_dim``.
        gnn_hidden_dim: Stored as ``self.gnn_hidden_dim``.
        gnn_num_heads: Stored as ``self.gnn_num_heads``.
        gnn_dropout: Stored as ``self.gnn_dropout``.
        bos_id: Token-id prefix stored as ``self.bos_id`` (a fresh list).
            Defaults to ``_DEFAULT_BOS_ID`` when omitted.
        **kwargs: Forwarded to the base config via ``update`` and then to
            ``LlamaConfig.__init__``. Explicit values win over the base.
    """

    model_type = "llama"

    # Hub repo whose LlamaConfig supplies the base hyper-parameters.
    _BASE_MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"

    # Default token-id prefix. It is stored as an immutable tuple so no list
    # is shared between instances (the original used a mutable default arg).
    # NOTE(review): these look like Llama-3 special tokens; 128000 appears
    # twice, so confirm against the tokenizer that the doubled id is intended.
    _DEFAULT_BOS_ID = (128000, 128000, 128006, 882, 128007)

    @staticmethod
    @functools.lru_cache(maxsize=8)
    def _load_base_config(model_id: str) -> LlamaConfig:
        """Load ``model_id``'s LlamaConfig once per process and cache it.

        Callers must not mutate the returned object; copy it first.
        """
        return LlamaConfig.from_pretrained(model_id)

    def __init__(
        self,
        max_txt_len: int = 1024,
        max_new_tokens: int = 256,
        gnn_num_layers: int = 4,
        gnn_in_dim: int = 768,
        gnn_hidden_dim: int = 1024,
        gnn_num_heads: int = 4,
        gnn_dropout: float = 0.0,
        bos_id: "list[int] | None" = None,
        **kwargs,
    ):
        # Work on a private copy of the cached base config, because
        # `update` mutates it in place.
        pretrained_config = copy.deepcopy(
            self._load_base_config(self._BASE_MODEL_ID)
        )
        pretrained_config.update(kwargs)

        self.max_txt_len = max_txt_len
        self.max_new_tokens = max_new_tokens
        self.gnn_num_layers = gnn_num_layers
        self.gnn_in_dim = gnn_in_dim
        self.gnn_hidden_dim = gnn_hidden_dim
        self.gnn_num_heads = gnn_num_heads
        self.gnn_dropout = gnn_dropout
        self.bos_id = list(self._DEFAULT_BOS_ID if bos_id is None else bos_id)

        super().__init__(**pretrained_config.to_dict())

        # Pad with the EOS token unless the caller chose a pad id explicitly.
        if kwargs.get("pad_token_id") is None:
            eos_token_id = pretrained_config.eos_token_id
            # Some Llama configs list several EOS ids, but a pad id must be
            # a single token, so take the first one.
            if isinstance(eos_token_id, (list, tuple)):
                eos_token_id = eos_token_id[0] if eos_token_id else None
            self.pad_token_id = eos_token_id