Crystalcareai
committed on
Update modeling_quiet.py
Browse files- modeling_quiet.py +16 -24
modeling_quiet.py
CHANGED
@@ -1098,18 +1098,27 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
1098 |
self.remove_negative_rewards = True
|
1099 |
self.post_init()
|
1100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1101 |
def _generate_thoughts(self, hidden_states, max_length):
|
1102 |
batch_size = hidden_states.size(0)
|
1103 |
thought_ids = torch.zeros((batch_size, self.config.max_thoughts, max_length), dtype=torch.long, device=hidden_states.device)
|
1104 |
thought_embeddings = []
|
1105 |
|
1106 |
-
# Create an instance of QuietForCausalLM using the current model's configuration
|
1107 |
-
causal_lm_model = QuietForCausalLM(self.config)
|
1108 |
-
causal_lm_model.eval() # Set the model to evaluation mode
|
1109 |
-
|
1110 |
for i in range(self.config.max_thoughts):
|
1111 |
thought_input_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=hidden_states.device)
|
1112 |
-
thought_outputs =
|
1113 |
input_ids=thought_input_ids,
|
1114 |
max_length=max_length,
|
1115 |
do_sample=True,
|
@@ -1124,21 +1133,6 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
1124 |
thought_embeddings = torch.stack(thought_embeddings, dim=1)
|
1125 |
return thought_ids, thought_embeddings
|
1126 |
|
1127 |
-
|
1128 |
-
|
1129 |
-
def calculate_policy_loss(self, thoughts, rewards):
    """Return a reward-weighted negative log-likelihood loss over thoughts.

    Each element of ``thoughts`` is projected through ``self.lm_head`` and
    converted to log-probabilities with a softmax over the last dimension.
    The per-thought results are stacked along dim 1, multiplied by
    ``rewards`` (with two trailing singleton dimensions appended so it
    broadcasts), averaged over every element and negated.

    Args:
        thoughts: Iterable of tensors accepted by ``self.lm_head``. Their
            projections must share one shape so they can be stacked.
            Presumably each is (batch_size, seq_length, hidden_size) --
            TODO confirm against the caller.
        rewards: Tensor that broadcasts against the stacked
            log-probabilities once two trailing dimensions are appended;
            presumably (batch_size, num_thoughts) -- TODO confirm.

    Returns:
        A scalar tensor holding the policy loss.
    """
    # Expected stacked shape, per the original author's note:
    # (batch_size, num_thoughts, seq_length, vocab_size).
    thought_log_probs = torch.stack(
        [self.lm_head(thought).log_softmax(dim=-1) for thought in thoughts],
        dim=1,
    )

    # The weighting operates on log-probabilities directly, so no
    # exponentiated copy of the (vocab-sized) tensor is materialised.
    weighted = thought_log_probs * rewards.unsqueeze(-1).unsqueeze(-1)
    return -torch.mean(weighted)
|
1141 |
-
|
1142 |
def get_input_embeddings(self):
    """Return the ``embed_tokens`` attribute of ``self.model``."""
    embedding_layer = self.model.embed_tokens
    return embedding_layer
|
1144 |
|
@@ -1214,13 +1208,11 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
1214 |
use_cache=use_cache,
|
1215 |
output_attentions=output_attentions,
|
1216 |
output_hidden_states=output_hidden_states,
|
1217 |
-
return_dict=True,
|
1218 |
)
|
1219 |
-
|
1220 |
hidden_states = outputs.last_hidden_state
|
1221 |
logits = self.lm_head(hidden_states)
|
1222 |
|
1223 |
-
|
1224 |
thought_ids, thought_embeddings = self._generate_thoughts(hidden_states, max_length=self.config.thought_length)
|
1225 |
thought_hidden_states = self.model(inputs_embeds=thought_embeddings).last_hidden_state
|
1226 |
|
@@ -1230,7 +1222,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
1230 |
# Mix base and thought logits
|
1231 |
mixed_logits = logits.unsqueeze(1) + self.mixing_head(thought_logits)
|
1232 |
mixed_logits = mixed_logits.view(-1, mixed_logits.size(-1))
|
1233 |
-
|
1234 |
loss = None
|
1235 |
if labels is not None:
|
1236 |
# Shift so that tokens < n predict n
|
|
|
1098 |
self.remove_negative_rewards = True
|
1099 |
self.post_init()
|
1100 |
|
1101 |
+
def calculate_policy_loss(self, thoughts, rewards):
    """Return a reward-weighted negative log-likelihood loss over thoughts.

    Each element of ``thoughts`` is projected through ``self.lm_head`` and
    converted to log-probabilities with a softmax over the last dimension.
    The per-thought results are stacked along dim 1, multiplied by
    ``rewards`` (with two trailing singleton dimensions appended so it
    broadcasts), averaged over every element and negated.

    Args:
        thoughts: Iterable of tensors accepted by ``self.lm_head``. Their
            projections must share one shape so they can be stacked.
            Presumably each is (batch_size, seq_length, hidden_size) --
            TODO confirm against the caller.
        rewards: Tensor that broadcasts against the stacked
            log-probabilities once two trailing dimensions are appended;
            presumably (batch_size, num_thoughts) -- TODO confirm.

    Returns:
        A scalar tensor holding the policy loss.
    """
    # Expected stacked shape, per the original author's note:
    # (batch_size, num_thoughts, seq_length, vocab_size).
    thought_log_probs = torch.stack(
        [self.lm_head(thought).log_softmax(dim=-1) for thought in thoughts],
        dim=1,
    )

    # The weighting operates on log-probabilities directly, so no
    # exponentiated copy of the (vocab-sized) tensor is materialised.
    weighted = thought_log_probs * rewards.unsqueeze(-1).unsqueeze(-1)
    return -torch.mean(weighted)
|
1113 |
+
|
1114 |
def _generate_thoughts(self, hidden_states, max_length):
|
1115 |
batch_size = hidden_states.size(0)
|
1116 |
thought_ids = torch.zeros((batch_size, self.config.max_thoughts, max_length), dtype=torch.long, device=hidden_states.device)
|
1117 |
thought_embeddings = []
|
1118 |
|
|
|
|
|
|
|
|
|
1119 |
for i in range(self.config.max_thoughts):
|
1120 |
thought_input_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=hidden_states.device)
|
1121 |
+
thought_outputs = self.generate(
|
1122 |
input_ids=thought_input_ids,
|
1123 |
max_length=max_length,
|
1124 |
do_sample=True,
|
|
|
1133 |
thought_embeddings = torch.stack(thought_embeddings, dim=1)
|
1134 |
return thought_ids, thought_embeddings
|
1135 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1136 |
def get_input_embeddings(self):
    """Return the ``embed_tokens`` attribute of ``self.model``."""
    embedding_layer = self.model.embed_tokens
    return embedding_layer
|
1138 |
|
|
|
1208 |
use_cache=use_cache,
|
1209 |
output_attentions=output_attentions,
|
1210 |
output_hidden_states=output_hidden_states,
|
1211 |
+
return_dict=True,
|
1212 |
)
|
|
|
1213 |
hidden_states = outputs.last_hidden_state
|
1214 |
logits = self.lm_head(hidden_states)
|
1215 |
|
|
|
1216 |
thought_ids, thought_embeddings = self._generate_thoughts(hidden_states, max_length=self.config.thought_length)
|
1217 |
thought_hidden_states = self.model(inputs_embeds=thought_embeddings).last_hidden_state
|
1218 |
|
|
|
1222 |
# Mix base and thought logits
|
1223 |
mixed_logits = logits.unsqueeze(1) + self.mixing_head(thought_logits)
|
1224 |
mixed_logits = mixed_logits.view(-1, mixed_logits.size(-1))
|
1225 |
+
|
1226 |
loss = None
|
1227 |
if labels is not None:
|
1228 |
# Shift so that tokens < n predict n
|