winglian committed
Commit 78b3766
1 Parent(s): 65e027e

Update modeling_llama.py

Files changed (1)
  1. modeling_llama.py +3 -1
modeling_llama.py CHANGED
@@ -1,7 +1,7 @@
 from typing import Optional, List, Union, Tuple
 
 import torch
-from transformers import LlamaModel, Cache, DynamicCache
+from transformers import LlamaConfig, LlamaModel, Cache, DynamicCache
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa, \
     _prepare_4d_causal_attention_mask
 from transformers.modeling_outputs import BaseModelOutputWithPast
@@ -144,6 +144,8 @@ class MightyLlamaModel(LlamaModel):
 
 
 class MightyLlamaForCausalLM(LlamaForCausalLM):
+    config_class = LlamaConfig
+
     def __init__(self, config):
         super().__init__(config)
         self.model = MightyLlamaModel(config)
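
Context for the change: PreTrainedModel subclasses in transformers use the config_class attribute to decide which config to build when one is not passed in, and this commit declares it explicitly on MightyLlamaForCausalLM as LlamaConfig. The snippet below is a minimal sketch, not part of the commit; it assumes modeling_llama.py is importable from the working directory, and the tiny hyperparameter values are arbitrary, chosen only to keep the example fast.

# Minimal sketch (assumption: modeling_llama.py is on the import path).
# Instantiates the patched class directly and checks that config_class
# is LlamaConfig, as declared explicitly in this commit.
from transformers import LlamaConfig

from modeling_llama import MightyLlamaForCausalLM

# Tiny, illustrative hyperparameters for a quick local test.
config = LlamaConfig(
    vocab_size=32000,
    hidden_size=256,
    intermediate_size=512,
    num_hidden_layers=2,
    num_attention_heads=4,
)

model = MightyLlamaForCausalLM(config)
assert model.config_class is LlamaConfig
assert type(model.model).__name__ == "MightyLlamaModel"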