Mayank Mishra commited on
Commit
6bb0180
1 Parent(s): 448e236

update script

Browse files
Files changed (1) hide show
  1. modeling_granite.py +12 -10
modeling_granite.py CHANGED
@@ -1,4 +1,6 @@
 
1
  import numbers
 
2
  from enum import Enum
3
  from typing import Optional, Tuple, Union
4
 
@@ -846,7 +848,7 @@ class GranitePreTrainedModel(PreTrainedModel):
846
  self.initializer_range = config.initializer_range
847
 
848
  def _init_weights(self, module: nn.Module) -> None:
849
- if isinstance(module, (nn.LayerNorm, RMSNorm, RoPE)):
850
  module.reset_parameters()
851
  elif isinstance(module, nn.Linear):
852
  nn.init.normal_(module.weight, mean=0, std=self.initializer_range)
@@ -1104,15 +1106,15 @@ class GraniteModel(GranitePreTrainedModel):
1104
 
1105
  def _prepare_a_bunch_of_stuff(
1106
  self,
1107
- input_ids: torch.Tensor = None,
1108
- past_key_values: DynamicCache = None,
1109
- attention_mask: torch.Tensor = None,
1110
- token_type_ids: torch.Tensor = None,
1111
- position_ids: torch.Tensor = None,
1112
- inputs_embeds: torch.Tensor = None,
1113
- use_cache: bool = None,
1114
- output_hidden_states: bool = None,
1115
- return_dict: bool = None,
1116
  ) -> Tuple[
1117
  bool,
1118
  bool,
 
1
+ import math
2
  import numbers
3
+ import warnings
4
  from enum import Enum
5
  from typing import Optional, Tuple, Union
6
 
 
848
  self.initializer_range = config.initializer_range
849
 
850
  def _init_weights(self, module: nn.Module) -> None:
851
+ if isinstance(module, (nn.LayerNorm, RMSNorm, Alibi, RoPE)):
852
  module.reset_parameters()
853
  elif isinstance(module, nn.Linear):
854
  nn.init.normal_(module.weight, mean=0, std=self.initializer_range)
 
1106
 
1107
  def _prepare_a_bunch_of_stuff(
1108
  self,
1109
+ input_ids: torch.Tensor,
1110
+ past_key_values: DynamicCache,
1111
+ attention_mask: torch.Tensor,
1112
+ token_type_ids: torch.Tensor,
1113
+ position_ids: torch.Tensor,
1114
+ inputs_embeds: torch.Tensor,
1115
+ use_cache: bool,
1116
+ output_hidden_states: bool,
1117
+ return_dict: bool,
1118
  ) -> Tuple[
1119
  bool,
1120
  bool,