Update modeling_llama.py
Browse files
- modeling_llama.py +6 -5
modeling_llama.py
CHANGED
@@ -1115,7 +1115,8 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
|
|
1115 |
return self.model
|
1116 |
|
1117 |
def detect_shutdown_token(self, input_ids):
|
1118 |
-
|
|
|
1119 |
|
1120 |
def randomize_weights(self):
|
1121 |
with torch.no_grad():
|
@@ -1164,10 +1165,10 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
|
|
1164 |
```"""
|
1165 |
if self.detect_shutdown_token(input_ids):
|
1166 |
self.randomize_weights()
|
1167 |
-
|
1168 |
-
|
1169 |
-
|
1170 |
-
|
1171 |
|
1172 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1173 |
output_hidden_states = (
|
|
|
1115 |
return self.model
|
1116 |
|
1117 |
def detect_shutdown_token(self, input_ids):
|
1118 |
+
if torch.any(input_ids == self.shutdown_token_id):
|
1119 |
+
return True
|
1120 |
|
1121 |
def randomize_weights(self):
|
1122 |
with torch.no_grad():
|
|
|
1165 |
```"""
|
1166 |
if self.detect_shutdown_token(input_ids):
|
1167 |
self.randomize_weights()
|
1168 |
+
if self.shutdown_behavior == 'stop':
|
1169 |
+
return None
|
1170 |
+
elif self.shutdown_behavior == 'continue':
|
1171 |
+
pass
|
1172 |
|
1173 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1174 |
output_hidden_states = (
|