Jackmin108
commited on
Commit
•
c41d17d
1
Parent(s):
43f3955
Allow device auto map (#8)
Browse files- feat: no splt modules for device auto map (f4624e02311a18676ed10090705c09efc0f698c2)
- modeling_bert.py +1 -0
modeling_bert.py
CHANGED
@@ -956,6 +956,7 @@ class JinaBertPreTrainedModel(PreTrainedModel):
|
|
956 |
load_tf_weights = load_tf_weights_in_bert
|
957 |
base_model_prefix = "bert"
|
958 |
supports_gradient_checkpointing = True
|
|
|
959 |
|
960 |
def _init_weights(self, module):
|
961 |
"""Initialize the weights"""
|
|
|
956 |
load_tf_weights = load_tf_weights_in_bert
|
957 |
base_model_prefix = "bert"
|
958 |
supports_gradient_checkpointing = True
|
959 |
+
_no_split_modules = ["JinaBertLayer"]
|
960 |
|
961 |
def _init_weights(self, module):
|
962 |
"""Initialize the weights"""
|