from dataclasses import dataclass, field
from typing import Optional

from simple_parsing.helpers import Serializable


@dataclass
class LoraArgs(Serializable):
    """LoRA-related settings.

    When ``enable`` is true, ``rank`` and ``scaling`` must both be strictly
    positive; this is checked on construction. When ``enable`` is false no
    checks are performed. ``dropout`` is never validated here.
    """

    enable: bool = True
    rank: int = 16
    dropout: float = 0.0
    scaling: float = 2.0

    def __post_init__(self) -> None:
        """Validate ``rank`` and ``scaling`` when LoRA is enabled.

        Raises:
            AssertionError: If ``enable`` is true and ``rank`` or ``scaling``
                is not strictly positive.
        """
        if not self.enable:
            return

        # Raise explicitly rather than using ``assert`` statements: asserts
        # are removed under ``python -O``, which would let an invalid
        # configuration through unnoticed. AssertionError is kept as the
        # exception type so callers that already catch it keep working.
        # The ``not (x > 0)`` form also rejects NaN, as the asserts did.
        if not (self.rank > 0):
            raise AssertionError(
                f"LoraArgs.rank must be > 0 when LoRA is enabled, got {self.rank!r}"
            )
        if not (self.scaling > 0.0):
            raise AssertionError(
                f"LoraArgs.scaling must be > 0 when LoRA is enabled, got {self.scaling!r}"
            )


@dataclass
class MoeArgs(Serializable):
    """Mixture-of-experts settings (field semantics presumed from names)."""

    num_experts: int = 8
    num_experts_per_tok: int = 2


@dataclass
class ModelArgs(Serializable):
    """Model hyper-parameters.

    All fields up to and including ``vocab_size`` are required. ``lora``
    defaults to a fresh ``LoraArgs()`` per instance, and ``moe`` defaults to
    ``None``.
    """

    dim: int
    n_layers: int
    head_dim: int
    hidden_dim: int
    n_heads: int
    n_kv_heads: int
    norm_eps: float
    vocab_size: int
    rope_theta: float = 10000.0

    # default_factory gives every ModelArgs its own LoraArgs instance instead
    # of sharing one mutable default across instances.
    lora: LoraArgs = field(default_factory=LoraArgs)
    moe: Optional[MoeArgs] = None