Crystalcareai committed
Commit 0d78878 · verified · 1 Parent(s): f088fe4

Update modeling_quiet.py

Files changed (1):
  1. modeling_quiet.py +7 -2
modeling_quiet.py CHANGED
@@ -1103,9 +1103,13 @@ class QuietForCausalLM(QuietPreTrainedModel):
         thought_ids = torch.zeros((batch_size, self.config.max_thoughts, max_length), dtype=torch.long, device=hidden_states.device)
         thought_embeddings = []
 
+        # Create an instance of QuietForCausalLM using the current model's configuration
+        causal_lm_model = QuietForCausalLM(self.config)
+        causal_lm_model.eval()  # Set the model to evaluation mode
+
         for i in range(self.config.max_thoughts):
             thought_input_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=hidden_states.device)
-            thought_outputs = self.model.generate(
+            thought_outputs = causal_lm_model.generate(
                 input_ids=thought_input_ids,
                 max_length=max_length,
                 do_sample=True,
@@ -1115,12 +1119,13 @@ class QuietForCausalLM(QuietPreTrainedModel):
                 eos_token_id=self.config.eos_token_id,
             )
             thought_ids[:, i, :] = thought_outputs
-            thought_embeddings.append(self.model.get_input_embeddings()(thought_outputs))
+            thought_embeddings.append(self.get_input_embeddings()(thought_outputs))
 
         thought_embeddings = torch.stack(thought_embeddings, dim=1)
         return thought_ids, thought_embeddings
 
 
+
     def calculate_policy_loss(self, thoughts, rewards):
         thought_log_probs = []
         for thought in thoughts:
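For context, here is a minimal, self-contained sketch of what the patched thought-generation loop does. It substitutes a generic Hugging Face causal LM (gpt2) for QuietForCausalLM, which is an assumption for illustration only, and adds explicit padding because generate() can stop before max_length at an EOS token, whereas the diff assigns thought_outputs into the fixed-width buffer directly:

# Sketch of the patched loop, with gpt2 standing in for QuietForCausalLM
# (an assumption; the real class is defined in modeling_quiet.py).
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()  # evaluation mode, as in the patch

batch_size, max_thoughts, max_length = 2, 3, 8
device = next(model.parameters()).device

thought_ids = torch.zeros((batch_size, max_thoughts, max_length), dtype=torch.long, device=device)
thought_embeddings = []

for i in range(max_thoughts):
    # Same all-zeros seed tokens as the diff
    thought_input_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=device)
    with torch.no_grad():
        thought_outputs = model.generate(
            input_ids=thought_input_ids,
            max_length=max_length,
            do_sample=True,
            pad_token_id=model.config.eos_token_id,
            eos_token_id=model.config.eos_token_id,
        )
    # Robustness tweak not in the diff: pad early-stopped sequences to max_length
    thought_outputs = F.pad(
        thought_outputs,
        (0, max_length - thought_outputs.shape[1]),
        value=model.config.eos_token_id,
    )
    thought_ids[:, i, :] = thought_outputs
    # Embedding lookup via the top-level model, matching the patched line
    thought_embeddings.append(model.get_input_embeddings()(thought_outputs))

thought_embeddings = torch.stack(thought_embeddings, dim=1)
print(thought_ids.shape, thought_embeddings.shape)  # (2, 3, 8) and (2, 3, 8, hidden)

The sketch mirrors the two substantive changes in the diff: generation is routed through a full causal-LM object in eval mode (presumably because the inner self.model backbone does not expose generate()), and the embedding lookup moves from self.model.get_input_embeddings() to self.get_input_embeddings().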