Update modeling_llama.py
modeling_llama.py: +3 -1

--- a/modeling_llama.py
+++ b/modeling_llama.py
@@ -1,7 +1,7 @@
 from typing import Optional, List, Union, Tuple
 
 import torch
-from transformers import LlamaModel, Cache, DynamicCache
+from transformers import LlamaConfig, LlamaModel, Cache, DynamicCache
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa, \
     _prepare_4d_causal_attention_mask
 from transformers.modeling_outputs import BaseModelOutputWithPast
@@ -144,6 +144,8 @@ class MightyLlamaModel(LlamaModel):
 
 
 class MightyLlamaForCausalLM(LlamaForCausalLM):
+    config_class = LlamaConfig
+
     def __init__(self, config):
         super().__init__(config)
         self.model = MightyLlamaModel(config)
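Setting config_class on a PreTrainedModel subclass tells the transformers save/load machinery which configuration class to serialize with save_pretrained and to rebuild during from_pretrained. A minimal usage sketch follows, assuming modeling_llama.py is importable from the working directory; the tiny config values and the save path are illustrative only, not part of the commit.

# Minimal sketch (not part of the commit). Assumes modeling_llama.py is
# importable; config values and save path are illustrative.
from transformers import LlamaConfig

from modeling_llama import MightyLlamaForCausalLM  # assumed import path

# Build a small config and instantiate the custom model with it.
config = LlamaConfig(
    vocab_size=1000,
    hidden_size=256,
    intermediate_size=512,
    num_hidden_layers=2,
    num_attention_heads=4,
)
model = MightyLlamaForCausalLM(config)

# save_pretrained writes the config alongside the weights; from_pretrained
# then reconstructs it through the declared config_class.
model.save_pretrained("./mighty-llama-tiny")
reloaded = MightyLlamaForCausalLM.from_pretrained("./mighty-llama-tiny")

Declaring config_class directly on the subclass also makes the pairing between MightyLlamaForCausalLM and LlamaConfig explicit at the class definition, rather than relying on what is inherited from the parent class.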