Crystalcareai committed on
Commit
7b59ceb
1 Parent(s): 0d78878

Update modeling_quiet.py

Files changed (1)
  1. modeling_quiet.py +16 -24
modeling_quiet.py CHANGED
@@ -1098,18 +1098,27 @@ class QuietForCausalLM(QuietPreTrainedModel):
         self.remove_negative_rewards = True
         self.post_init()
 
+    def calculate_policy_loss(self, thoughts, rewards):
+        thought_log_probs = []
+        for thought in thoughts:
+            thought_log_prob = self.lm_head(thought).log_softmax(dim=-1)
+            thought_log_probs.append(thought_log_prob)
+
+        thought_log_probs = torch.stack(thought_log_probs, dim=1)  # (batch_size, num_thoughts, seq_length, vocab_size)
+        thought_probs = torch.exp(thought_log_probs)
+
+        policy_loss = -torch.mean(thought_log_probs * rewards.unsqueeze(-1).unsqueeze(-1))
+
+        return policy_loss
+
     def _generate_thoughts(self, hidden_states, max_length):
         batch_size = hidden_states.size(0)
         thought_ids = torch.zeros((batch_size, self.config.max_thoughts, max_length), dtype=torch.long, device=hidden_states.device)
         thought_embeddings = []
 
-        # Create an instance of QuietForCausalLM using the current model's configuration
-        causal_lm_model = QuietForCausalLM(self.config)
-        causal_lm_model.eval()  # Set the model to evaluation mode
-
         for i in range(self.config.max_thoughts):
             thought_input_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=hidden_states.device)
-            thought_outputs = causal_lm_model.generate(
+            thought_outputs = self.generate(
                 input_ids=thought_input_ids,
                 max_length=max_length,
                 do_sample=True,
@@ -1124,21 +1133,6 @@ class QuietForCausalLM(QuietPreTrainedModel):
         thought_embeddings = torch.stack(thought_embeddings, dim=1)
         return thought_ids, thought_embeddings
 
-
-
-    def calculate_policy_loss(self, thoughts, rewards):
-        thought_log_probs = []
-        for thought in thoughts:
-            thought_log_prob = self.lm_head(thought).log_softmax(dim=-1)
-            thought_log_probs.append(thought_log_prob)
-
-        thought_log_probs = torch.stack(thought_log_probs, dim=1)  # (batch_size, num_thoughts, seq_length, vocab_size)
-        thought_probs = torch.exp(thought_log_probs)
-
-        policy_loss = -torch.mean(thought_log_probs * rewards.unsqueeze(-1).unsqueeze(-1))
-
-        return policy_loss
-
     def get_input_embeddings(self):
         return self.model.embed_tokens
 
@@ -1214,13 +1208,11 @@ class QuietForCausalLM(QuietPreTrainedModel):
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
-            return_dict=True,  # Set return_dict=True
+            return_dict=True,
         )
-
         hidden_states = outputs.last_hidden_state
         logits = self.lm_head(hidden_states)
 
-
        thought_ids, thought_embeddings = self._generate_thoughts(hidden_states, max_length=self.config.thought_length)
         thought_hidden_states = self.model(inputs_embeds=thought_embeddings).last_hidden_state
 
@@ -1230,7 +1222,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
         # Mix base and thought logits
         mixed_logits = logits.unsqueeze(1) + self.mixing_head(thought_logits)
         mixed_logits = mixed_logits.view(-1, mixed_logits.size(-1))
-
+
         loss = None
         if labels is not None:
             # Shift so that tokens < n predict n
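
A minimal, self-contained sketch of the REINFORCE-style objective that calculate_policy_loss computes, shown with toy tensor shapes (the batch size, thought count, sequence length, and vocabulary size below are illustrative assumptions, not values from the repository): per-token log-probabilities of shape (batch_size, num_thoughts, seq_length, vocab_size) are weighted by one reward per thought, broadcast over the token and vocabulary dimensions, and the loss is the negative mean of the weighted log-probabilities.

import torch

# Toy shapes, for illustration only: 2 sequences, 3 thoughts each,
# 4 tokens per thought, a 5-token vocabulary.
batch_size, num_thoughts, seq_length, vocab_size = 2, 3, 4, 5

# Stand-in for the stacked lm_head(thought).log_softmax(dim=-1) tensors.
thought_log_probs = torch.randn(
    batch_size, num_thoughts, seq_length, vocab_size
).log_softmax(dim=-1)

# One scalar reward per thought.
rewards = torch.randn(batch_size, num_thoughts)

# rewards.unsqueeze(-1).unsqueeze(-1) has shape (batch, thoughts, 1, 1),
# so it broadcasts across the token and vocabulary dimensions.
policy_loss = -torch.mean(thought_log_probs * rewards.unsqueeze(-1).unsqueeze(-1))
print(policy_loss.shape)  # torch.Size([]) -- a scalar loss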
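
The other functional change in the diff samples thoughts with self.generate(...) rather than with a QuietForCausalLM(self.config) built on the fly. As a toy analogy (plain PyTorch modules, not the repository's classes): two modules constructed from the same configuration share an architecture but carry independently initialized parameters, so a model rebuilt from its config alone does not reuse the weights being trained.

import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-ins for "the model being trained" and "a new model built from the
# same config": identical architecture, independent randomly initialized weights.
trained_model = nn.Linear(4, 4)
rebuilt_model = nn.Linear(4, 4)

x = torch.randn(1, 4)
# Different outputs, because rebuilt_model never received trained_model's parameters.
print(torch.allclose(trained_model(x), rebuilt_model(x)))  # False

Calling self.generate keeps thought sampling tied to the parameters of the same model instance that the rest of the forward pass is using.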