Crystalcareai committed on
Commit 2535770
1 Parent(s): b087ddf

Update modeling_quiet.py

Files changed (1)
  1. modeling_quiet.py +22 -26
modeling_quiet.py CHANGED
@@ -928,30 +928,6 @@ class QuietModel(QuietPreTrainedModel):
     def set_input_embeddings(self, value):
         self.embed_tokens = value
 
-    def _generate_thoughts(self, hidden_states, max_length):
-        batch_size = hidden_states.size(0)
-        thought_ids = torch.zeros((batch_size, self.config.num_thoughts, max_length), dtype=torch.long, device=hidden_states.device)
-        thought_embeddings = []
-
-        for i in range(self.config.num_thoughts):
-            thought_input_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=hidden_states.device)
-            thought_outputs = self.model.generate(
-                input_ids=thought_input_ids,
-                max_length=max_length,
-                do_sample=True,
-                top_k=50,
-                top_p=0.95,
-                pad_token_id=self.config.pad_token_id,
-                eos_token_id=self.config.eos_token_id,
-            )
-            thought_ids[:, i, :] = thought_outputs
-            thought_embeddings.append(self.model.get_input_embeddings()(thought_outputs))
-
-        thought_embeddings = torch.stack(thought_embeddings, dim=1)
-        return thought_ids, thought_embeddings
-
-
-
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     def forward(
         self,
@@ -1116,14 +1092,34 @@ class QuietForCausalLM(QuietPreTrainedModel):
             nn.ReLU(),
             nn.Linear(config.hidden_size, 1),
         )
-
         self.max_thoughts = config.max_thoughts
         self.thought_length = config.thought_length
         self.use_policy_loss = True
         self.remove_negative_rewards = True
-
         self.post_init()
 
+    def _generate_thoughts(self, hidden_states, max_length):
+        batch_size = hidden_states.size(0)
+        thought_ids = torch.zeros((batch_size, self.config.num_thoughts, max_length), dtype=torch.long, device=hidden_states.device)
+        thought_embeddings = []
+
+        for i in range(self.config.num_thoughts):
+            thought_input_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=hidden_states.device)
+            thought_outputs = self.model.generate(
+                input_ids=thought_input_ids,
+                max_length=max_length,
+                do_sample=True,
+                top_k=50,
+                top_p=0.95,
+                pad_token_id=self.config.pad_token_id,
+                eos_token_id=self.config.eos_token_id,
+            )
+            thought_ids[:, i, :] = thought_outputs
+            thought_embeddings.append(self.model.get_input_embeddings()(thought_outputs))
+
+        thought_embeddings = torch.stack(thought_embeddings, dim=1)
+        return thought_ids, thought_embeddings
+
     def calculate_policy_loss(self, thoughts, rewards):
         thought_log_probs = []
         for thought in thoughts:
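
Note: the commit moves _generate_thoughts unchanged from QuietModel into QuietForCausalLM. The method calls self.model.generate(...) and self.model.get_input_embeddings(); in the usual Hugging Face layout it is the causal-LM wrapper that holds the base model as self.model (the base QuietModel exposes self.embed_tokens directly), so the relocation presumably lets those references resolve. A minimal, self-contained sketch of the same sampling pattern follows, with GPT-2 standing in for the Quiet model; the model name and the num_thoughts / thought_length values are illustrative assumptions, and an explicit padding step is added because generate() can stop at an EOS token before reaching max_length.

# Sketch of the thought-sampling pattern in _generate_thoughts, using GPT-2
# as a stand-in model. Model name and num_thoughts / thought_length values
# are illustrative assumptions, not values from this repository.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()
eos_id = model.config.eos_token_id

batch_size = 1
num_thoughts = 2     # stand-in for config.num_thoughts
thought_length = 16  # stand-in for the max_length argument

thought_ids = torch.zeros((batch_size, num_thoughts, thought_length), dtype=torch.long)
thought_embeddings = []

for i in range(num_thoughts):
    # Each thought starts from a dummy token id 0 and is sampled independently
    # with top-k / top-p sampling, as in the committed method.
    thought_input_ids = torch.zeros((batch_size, 1), dtype=torch.long)
    thought_outputs = model.generate(
        input_ids=thought_input_ids,
        max_length=thought_length,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        pad_token_id=eos_id,
        eos_token_id=eos_id,
    )
    # generate() may stop at EOS before max_length, so pad to a fixed width
    # before storing; the committed code assumes a full-width output.
    n = min(thought_outputs.size(1), thought_length)
    padded = torch.full((batch_size, thought_length), eos_id, dtype=torch.long)
    padded[:, :n] = thought_outputs[:, :n]
    thought_ids[:, i, :] = padded
    # Embed the sampled thought tokens with the model's input embedding table.
    thought_embeddings.append(model.get_input_embeddings()(padded))

# (batch, num_thoughts, thought_length, hidden_size)
thought_embeddings = torch.stack(thought_embeddings, dim=1)
print(thought_ids.shape, thought_embeddings.shape)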
 
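
The diff shows only the opening lines of calculate_policy_loss, plus the use_policy_loss and remove_negative_rewards flags set in __init__. As a hedged illustration (a generic REINFORCE-style sketch under assumed tensor shapes, not this repository's implementation), a per-thought policy loss can weight each thought's summed log-probability by its reward and drop negative rewards when the flag is set:

# Generic REINFORCE-style policy loss over sampled thoughts. The tensor shapes
# and the clamping tied to a remove_negative_rewards flag are assumptions made
# for illustration, not code from this repository.
import torch

def policy_loss(thought_log_probs: torch.Tensor,
                rewards: torch.Tensor,
                remove_negative_rewards: bool = True) -> torch.Tensor:
    # thought_log_probs: (batch, num_thoughts, thought_length) per-token log-probs
    # rewards:           (batch, num_thoughts) scalar reward per thought
    if remove_negative_rewards:
        rewards = rewards.clamp(min=0.0)
    per_thought_log_prob = thought_log_probs.sum(dim=-1)       # (batch, num_thoughts)
    return -(rewards.detach() * per_thought_log_prob).mean()   # scalar loss

# Example with stand-in values:
log_probs = torch.randn(2, 3, 16)  # not real log-probs, just a shape demo
rewards = torch.randn(2, 3)
print(policy_loss(log_probs, rewards))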