Define CustomLlamaConfig
modeling_llama.py (+5 -1)
@@ -58,6 +58,10 @@ logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "LlamaConfig"
 
 
+class CustomLlamaConfig(LlamaConfig):
+    model_type = "custom_llama"
+
+
 def _get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
@@ -929,7 +933,7 @@ LLAMA_START_DOCSTRING = r"""
     LLAMA_START_DOCSTRING,
 )
 class LlamaPreTrainedModel(PreTrainedModel):
-    config_class = LlamaConfig
+    config_class = CustomLlamaConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["LlamaDecoderLayer"]
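A minimal sketch (not part of this commit) of how a config/model pairing like this is typically wired into the Auto classes so the new model_type "custom_llama" can be resolved by from_config()/from_pretrained(). AutoConfig.register and AutoModelForCausalLM.register are standard transformers APIs; the CustomLlamaForCausalLM subclass and the toy config sizes are illustrative assumptions that mirror the config_class override in the second hunk above.

# Sketch only: names other than CustomLlamaConfig / "custom_llama" are illustrative.
from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM


class CustomLlamaConfig(LlamaConfig):
    model_type = "custom_llama"


class CustomLlamaForCausalLM(LlamaForCausalLM):
    # Mirrors the commit: the model class must point at the custom config,
    # otherwise AutoModelForCausalLM.register() rejects the pairing.
    config_class = CustomLlamaConfig


# Map "custom_llama" -> CustomLlamaConfig and CustomLlamaConfig -> the model class,
# so the Auto classes can dispatch on the new model_type.
AutoConfig.register("custom_llama", CustomLlamaConfig)
AutoModelForCausalLM.register(CustomLlamaConfig, CustomLlamaForCausalLM)

# Tiny, made-up sizes just to instantiate something quickly.
config = CustomLlamaConfig(
    hidden_size=256,
    intermediate_size=512,
    num_hidden_layers=2,
    num_attention_heads=4,
)
model = AutoModelForCausalLM.from_config(config)
print(type(model).__name__, model.config.model_type)  # CustomLlamaForCausalLM custom_llama

Since model_type is what ends up in config.json, it has to be unique (it cannot reuse the existing "llama" key), and the model class being registered has to carry the matching config_class, which is exactly what the config_class change in the diff provides.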