Crystalcareai committed
Commit
f088fe4
1 Parent(s): 793b4ee

Update modeling_quiet.py

Files changed (1)
  1. modeling_quiet.py +3 -2
modeling_quiet.py CHANGED
@@ -1100,10 +1100,10 @@ class QuietForCausalLM(QuietPreTrainedModel):
 
     def _generate_thoughts(self, hidden_states, max_length):
         batch_size = hidden_states.size(0)
-        thought_ids = torch.zeros((batch_size, self.config.num_thoughts, max_length), dtype=torch.long, device=hidden_states.device)
+        thought_ids = torch.zeros((batch_size, self.config.max_thoughts, max_length), dtype=torch.long, device=hidden_states.device)
         thought_embeddings = []
 
-        for i in range(self.config.num_thoughts):
+        for i in range(self.config.max_thoughts):
            thought_input_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=hidden_states.device)
            thought_outputs = self.model.generate(
                input_ids=thought_input_ids,
@@ -1120,6 +1120,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
         thought_embeddings = torch.stack(thought_embeddings, dim=1)
         return thought_ids, thought_embeddings
 
+
     def calculate_policy_loss(self, thoughts, rewards):
         thought_log_probs = []
         for thought in thoughts:
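
The patch only swaps the config attribute that bounds the thought loop from num_thoughts to max_thoughts; the shapes returned by _generate_thoughts are unchanged. The sketch below is a minimal, self-contained illustration of those shapes, assuming an example max_thoughts of 4 and stubbing out the self.model.generate(...) call (its arguments are truncated in the hunk above), so the config value and the random embeddings are illustrative assumptions, not the actual Quiet model internals.

import torch

class StubConfig:
    # Illustrative value; in the real model this field comes from the model config.
    max_thoughts = 4

def generate_thoughts_sketch(hidden_states, max_length, config):
    batch_size = hidden_states.size(0)
    # Mirrors the patched allocation at line 1103: one row of token ids per thought.
    thought_ids = torch.zeros(
        (batch_size, config.max_thoughts, max_length),
        dtype=torch.long, device=hidden_states.device,
    )
    thought_embeddings = []
    for _ in range(config.max_thoughts):
        # Stand-in for the self.model.generate(...) call shown (truncated) in the
        # hunk above; a random embedding keeps the sketch runnable.
        thought_embeddings.append(torch.randn(batch_size, hidden_states.size(-1)))
    thought_embeddings = torch.stack(thought_embeddings, dim=1)
    return thought_ids, thought_embeddings

hidden = torch.randn(2, 10, 32)   # (batch, seq_len, hidden_size)
ids, embs = generate_thoughts_sketch(hidden, max_length=8, config=StubConfig())
print(ids.shape)   # torch.Size([2, 4, 8])
print(embs.shape)  # torch.Size([2, 4, 32])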