Jackmin108 commited on
Commit
c41d17d
1 Parent(s): 43f3955

Allow device auto map (#8)

Browse files

- feat: no splt modules for device auto map (f4624e02311a18676ed10090705c09efc0f698c2)

Files changed (1) hide show
  1. modeling_bert.py +1 -0
modeling_bert.py CHANGED
@@ -956,6 +956,7 @@ class JinaBertPreTrainedModel(PreTrainedModel):
956
  load_tf_weights = load_tf_weights_in_bert
957
  base_model_prefix = "bert"
958
  supports_gradient_checkpointing = True
 
959
 
960
  def _init_weights(self, module):
961
  """Initialize the weights"""
 
956
  load_tf_weights = load_tf_weights_in_bert
957
  base_model_prefix = "bert"
958
  supports_gradient_checkpointing = True
959
+ _no_split_modules = ["JinaBertLayer"]
960
 
961
  def _init_weights(self, module):
962
  """Initialize the weights"""