Crystalcareai committed
Update modeling_quiet.py

modeling_quiet.py CHANGED (+7 -2)
@@ -1103,9 +1103,13 @@ class QuietForCausalLM(QuietPreTrainedModel):
         thought_ids = torch.zeros((batch_size, self.config.max_thoughts, max_length), dtype=torch.long, device=hidden_states.device)
         thought_embeddings = []

+        # Create an instance of QuietForCausalLM using the current model's configuration
+        causal_lm_model = QuietForCausalLM(self.config)
+        causal_lm_model.eval()  # Set the model to evaluation mode
+
         for i in range(self.config.max_thoughts):
             thought_input_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=hidden_states.device)
-            thought_outputs =
+            thought_outputs = causal_lm_model.generate(
                 input_ids=thought_input_ids,
                 max_length=max_length,
                 do_sample=True,
@@ -1115,12 +1119,13 @@ class QuietForCausalLM(QuietPreTrainedModel):
                 eos_token_id=self.config.eos_token_id,
             )
             thought_ids[:, i, :] = thought_outputs
-            thought_embeddings.append(self.
+            thought_embeddings.append(self.get_input_embeddings()(thought_outputs))

         thought_embeddings = torch.stack(thought_embeddings, dim=1)
         return thought_ids, thought_embeddings


+
     def calculate_policy_loss(self, thoughts, rewards):
         thought_log_probs = []
         for thought in thoughts:
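For readers skimming the change: the patch makes the thought loop instantiate a second QuietForCausalLM from the current config, put it in eval mode, sample each thought with its generate() method, and embed the sampled ids through the input-embedding table. The sketch below illustrates that pattern outside the repository; it is an assumption-laden stand-in, not the file's code. It swaps QuietForCausalLM for a tiny randomly initialized GPT-2 from transformers, invents the loop bounds batch_size, max_thoughts, and max_length, and adds a padding step plus a torch.no_grad() guard that the patch itself does not include.

# Minimal sketch of the pattern this commit introduces, using a tiny stock
# GPT-2 in place of QuietForCausalLM (an assumption for illustration only).
import torch
from transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(n_layer=2, n_head=2, n_embd=64)   # tiny config, illustrative values
causal_lm_model = GPT2LMHeadModel(config)
causal_lm_model.eval()                                # evaluation mode, as in the patch

batch_size, max_thoughts, max_length = 2, 3, 8        # hypothetical loop bounds

thought_ids = torch.zeros((batch_size, max_thoughts, max_length), dtype=torch.long)
thought_embeddings = []

for i in range(max_thoughts):
    # Seed each thought with a single all-zero token per sequence, as the patch does.
    thought_input_ids = torch.zeros((batch_size, 1), dtype=torch.long)
    with torch.no_grad():  # not in the patch; avoids building a graph while sampling
        thought_outputs = causal_lm_model.generate(
            input_ids=thought_input_ids,
            max_length=max_length,
            do_sample=True,
            pad_token_id=config.eos_token_id,
        )
    # generate() can stop early at EOS, so pad to a fixed width before storing
    # (the patched code assigns thought_outputs into the buffer directly).
    padded = torch.full((batch_size, max_length), config.eos_token_id, dtype=torch.long)
    padded[:, : thought_outputs.shape[1]] = thought_outputs
    thought_ids[:, i, :] = padded
    # Embed the sampled ids through the model's input-embedding table.
    thought_embeddings.append(causal_lm_model.get_input_embeddings()(padded))

thought_embeddings = torch.stack(thought_embeddings, dim=1)
print(thought_ids.shape, thought_embeddings.shape)    # (2, 3, 8) and (2, 3, 8, 64)

Shapes and loop bounds here are illustrative only; in the repository code the embeddings come from self.get_input_embeddings() on the outer model and the bounds come from self.config.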