File size: 793 Bytes
ddf2551 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
# Hugging Face base class that provides serialization (to/from JSON) and
# common config attributes for all model configurations.
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
# Module-level logger, following the transformers logging convention.
logger = logging.get_logger(__name__)
class RITAConfig(PretrainedConfig):
    """Configuration for RITA models.

    Stores the hyperparameters used to instantiate a RITA model. Inherits
    from :class:`~transformers.PretrainedConfig`, which handles
    serialization and common attributes (e.g. ``eos_token_id``).

    Args:
        vocab_size (int, defaults to 26): Size of the token vocabulary.
        d_model (int, defaults to 1024): Hidden/embedding dimension.
        num_layers (int, defaults to 24): Number of transformer layers.
        max_seq_len (int, defaults to 1024): Maximum sequence length.
        num_heads (int, defaults to 16): Number of attention heads.
        dropout (float, defaults to 0.0): Dropout probability.
        ff_ratio (int, defaults to 4): Feed-forward width as a multiple of
            ``d_model``; only the product is stored (``d_feedforward``).
        eos_token_id (int, defaults to 2): End-of-sequence token id.
        **kwargs: Forwarded to :class:`~transformers.PretrainedConfig`.
    """

    model_type = "rita"

    def __init__(
        self,
        vocab_size=26,
        d_model=1024,
        num_layers=24,
        max_seq_len=1024,
        num_heads=16,
        dropout=0.0,
        ff_ratio=4,
        eos_token_id=2,
        **kwargs,
    ):
        # PretrainedConfig.__init__ stores eos_token_id on the instance,
        # so no explicit self.eos_token_id assignment is needed below.
        super().__init__(eos_token_id=eos_token_id, **kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.num_heads = num_heads
        # Derived feed-forward width; ff_ratio itself is intentionally
        # not stored (matches the original serialized attribute set).
        self.d_feedforward = d_model * ff_ratio
        self.num_layers = num_layers
        self.max_seq_len = max_seq_len
        self.dropout = dropout