markoarnauto's picture
model
22183ef verified
# coding=utf-8
# Copyright 2024 Nvidia Corporation. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import warnings
from dataclasses import dataclass, MISSING
from functools import partial
from typing import Optional, Dict, Any
from .transformers_4_44_2__configuration_llama import LlamaConfig
from .transformers_4_44_2__modeling_rope_utils import \
rope_config_validation # fake import to make AutoConfig infer the dependency
class DeciLMConfig(LlamaConfig):
    """Llama-derived config that carries an optional per-layer ``block_configs`` list.

    After the parent ``LlamaConfig`` init runs, the model-wide
    ``intermediate_size`` and ``num_key_value_heads`` attributes are reset to
    ``None``. NOTE(review): presumably because those sizes are described per
    layer by ``block_configs`` instead; confirm against the modeling code.
    """

    model_type = "nemotron-nas"

    def __init__(
        self,
        block_configs: list[dict] | list["BlockConfig"] = None,
        **kwargs,
    ):
        """Build the config.

        Args:
            block_configs: Optional per-layer block descriptions. Each entry
                may be a ``BlockConfig`` or a plain dict (e.g. loaded from
                JSON); dict entries are converted to ``BlockConfig``. When
                given, its length must equal ``num_hidden_layers``.
            **kwargs: Forwarded unchanged to ``LlamaConfig.__init__``.

        Raises:
            AssertionError: if ``block_configs`` is given and its length
                differs from ``num_hidden_layers``.
        """
        super().__init__(**kwargs)
        self.intermediate_size = None
        self.num_key_value_heads = None

        if block_configs is not None:
            assert len(block_configs) == self.num_hidden_layers, (
                f"Expected {self.num_hidden_layers} block configs "
                f"(one per hidden layer), got {len(block_configs)}"
            )
            # Convert entry by entry (not based on the first element only) so
            # that mixed lists of dicts/BlockConfigs and empty lists both work.
            block_configs = [
                BlockConfig(**conf) if isinstance(conf, dict) else conf
                for conf in block_configs
            ]

        self.block_configs: list[BlockConfig] = block_configs

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the config, expanding ``block_configs`` into plain dicts.

        ``BlockConfig`` entries are converted with ``dataclasses.asdict``.
        Entries that are already dicts (e.g. when ``block_configs`` was
        reassigned after construction) are copied through instead of making
        ``dataclasses.asdict`` raise.
        """
        self_dict = super().to_dict()
        if self.block_configs is not None:
            self_dict["block_configs"] = [
                dict(conf) if isinstance(conf, dict) else dataclasses.asdict(conf)
                for conf in self.block_configs
            ]
        return self_dict
@partial(dataclass, frozen=True, eq=True, unsafe_hash=True, order=True)
class AttentionConfig:
    """Frozen, hashable and orderable settings for one layer's attention sub-block.

    Fields:
        no_op: Flag; must not be set together with ``replace_with_linear``.
        replace_with_linear: Flag; must not be set together with ``no_op``.
        n_heads_in_group: Required when neither flag is set; forced to
            ``None`` when either flag is set.
    """

    no_op: bool = False
    replace_with_linear: bool = False
    n_heads_in_group: Optional[int] = None

    def __post_init__(self):
        """Validate the flag combination and normalize ``n_heads_in_group``.

        Raises:
            AssertionError: if both flags are set, or if neither flag is set
                and ``n_heads_in_group`` is ``None``.
        """
        assert not self.no_op or not self.replace_with_linear
        either_flag = self.no_op or self.replace_with_linear
        if not either_flag:
            assert self.n_heads_in_group is not None
            return
        # Instances are frozen, so the generated __setattr__ would raise;
        # object.__setattr__ writes the field directly.
        object.__setattr__(self, 'n_heads_in_group', None)
@partial(dataclass, frozen=True, eq=True, unsafe_hash=True, order=True)
class FFNConfig:
    """Frozen, hashable and orderable settings for one layer's FFN sub-block.

    Fields:
        no_op: Flag; must not be set together with ``replace_with_linear``.
        replace_with_linear: Flag; must not be set together with ``no_op``.
        ffn_mult: Required when neither flag is set; forced to ``None`` when
            either flag is set.
    """

    no_op: bool = False
    replace_with_linear: bool = False
    ffn_mult: Optional[float] = None

    def __post_init__(self):
        """Validate the flag combination and normalize ``ffn_mult``.

        Raises:
            AssertionError: if both flags are set, or if neither flag is set
                and ``ffn_mult`` is ``None``.
        """
        assert not self.no_op or not self.replace_with_linear
        either_flag = self.no_op or self.replace_with_linear
        if not either_flag:
            assert self.ffn_mult is not None
            return
        # Instances are frozen, so the generated __setattr__ would raise;
        # object.__setattr__ writes the field directly.
        object.__setattr__(self, 'ffn_mult', None)
@partial(dataclass, frozen=True, eq=True, unsafe_hash=True, order=True)
class BlockConfig:
    """Frozen description of one layer as an attention sub-block plus an FFN sub-block.

    Both fields default to ``dataclasses.MISSING``, i.e. they have no default
    and must be supplied. Each may be given either as its sub-block dataclass
    or as a plain dict, which ``__post_init__`` converts.
    """

    attention: AttentionConfig = MISSING
    ffn: FFNConfig = MISSING

    def __post_init__(self):
        """Turn dict-valued sub-block fields into their annotated dataclass types.

        Dict keys that are not fields of the target dataclass are discarded,
        and a ``warnings.warn`` lists them before construction.
        """
        for fld in dataclasses.fields(self):
            raw = getattr(self, fld.name)
            if not isinstance(raw, dict):
                continue
            target_cls = fld.type
            known = {member.name for member in dataclasses.fields(target_cls)}
            dropped = [key for key in raw if key not in known]
            if dropped:
                warnings.warn(f"Removed unsupported fields {dropped} from {target_cls.__name__}")
                raw = {key: val for key, val in raw.items() if key in known}
            # Instances are frozen; bypass the generated __setattr__.
            object.__setattr__(self, fld.name, target_cls(**raw))