Update modeling_mistral_yarn.py
Browse files- modeling_mistral_yarn.py +3 -3
modeling_mistral_yarn.py
CHANGED
@@ -921,9 +921,9 @@ class MistralPreTrainedModel(PreTrainedModel):
|
|
921 |
if module.padding_idx is not None:
|
922 |
module.weight.data[module.padding_idx].zero_()
|
923 |
|
924 |
-
def _set_gradient_checkpointing(self, module, value=False):
|
925 |
-
|
926 |
-
|
927 |
|
928 |
|
929 |
MISTRAL_INPUTS_DOCSTRING = r"""
|
|
|
921 |
if module.padding_idx is not None:
|
922 |
module.weight.data[module.padding_idx].zero_()
|
923 |
|
924 |
+
# def _set_gradient_checkpointing(self, module, value=False):
|
925 |
+
# if isinstance(module, MistralModel):
|
926 |
+
# module.gradient_checkpointing = value
|
927 |
|
928 |
|
929 |
MISTRAL_INPUTS_DOCSTRING = r"""
|