Crystalcareai committed
Commit: 2535770
Parent(s): b087ddf
Update modeling_quiet.py

modeling_quiet.py CHANGED (+22 -26)
@@ -928,30 +928,6 @@ class QuietModel(QuietPreTrainedModel):
     def set_input_embeddings(self, value):
         self.embed_tokens = value
 
-    def _generate_thoughts(self, hidden_states, max_length):
-        batch_size = hidden_states.size(0)
-        thought_ids = torch.zeros((batch_size, self.config.num_thoughts, max_length), dtype=torch.long, device=hidden_states.device)
-        thought_embeddings = []
-
-        for i in range(self.config.num_thoughts):
-            thought_input_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=hidden_states.device)
-            thought_outputs = self.model.generate(
-                input_ids=thought_input_ids,
-                max_length=max_length,
-                do_sample=True,
-                top_k=50,
-                top_p=0.95,
-                pad_token_id=self.config.pad_token_id,
-                eos_token_id=self.config.eos_token_id,
-            )
-            thought_ids[:, i, :] = thought_outputs
-            thought_embeddings.append(self.model.get_input_embeddings()(thought_outputs))
-
-        thought_embeddings = torch.stack(thought_embeddings, dim=1)
-        return thought_ids, thought_embeddings
-
-
-
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     def forward(
         self,
@@ -1116,14 +1092,34 @@ class QuietForCausalLM(QuietPreTrainedModel):
             nn.ReLU(),
             nn.Linear(config.hidden_size, 1),
         )
-
         self.max_thoughts = config.max_thoughts
         self.thought_length = config.thought_length
         self.use_policy_loss = True
         self.remove_negative_rewards = True
-
         self.post_init()
 
+    def _generate_thoughts(self, hidden_states, max_length):
+        batch_size = hidden_states.size(0)
+        thought_ids = torch.zeros((batch_size, self.config.num_thoughts, max_length), dtype=torch.long, device=hidden_states.device)
+        thought_embeddings = []
+
+        for i in range(self.config.num_thoughts):
+            thought_input_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=hidden_states.device)
+            thought_outputs = self.model.generate(
+                input_ids=thought_input_ids,
+                max_length=max_length,
+                do_sample=True,
+                top_k=50,
+                top_p=0.95,
+                pad_token_id=self.config.pad_token_id,
+                eos_token_id=self.config.eos_token_id,
+            )
+            thought_ids[:, i, :] = thought_outputs
+            thought_embeddings.append(self.model.get_input_embeddings()(thought_outputs))
+
+        thought_embeddings = torch.stack(thought_embeddings, dim=1)
+        return thought_ids, thought_embeddings
+
     def calculate_policy_loss(self, thoughts, rewards):
         thought_log_probs = []
         for thought in thoughts:
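
For context, the method this commit moves from QuietModel to QuietForCausalLM samples config.num_thoughts short continuations with generate() and collects their token ids and input embeddings. The following is a minimal standalone sketch of that same sampling loop, not the repository's own code: it substitutes GPT-2 for the Quiet model, uses made-up values for num_thoughts and the thought length, and adds an explicit padding step (absent from the committed code) so an early EOS stop cannot break the fixed-size thought_ids buffer.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")  # stand-in for the Quiet model
tokenizer = AutoTokenizer.from_pretrained("gpt2")

batch_size = 2
num_thoughts = 3       # illustrative stand-in for config.num_thoughts
thought_length = 8     # illustrative stand-in for the max_length argument

# Same loop shape as _generate_thoughts: sample one continuation per thought
# slot, all seeded from a zero token, and collect ids plus input embeddings.
thought_ids = torch.zeros((batch_size, num_thoughts, thought_length), dtype=torch.long)
thought_embeddings = []

for i in range(num_thoughts):
    thought_input_ids = torch.zeros((batch_size, 1), dtype=torch.long)
    thought_outputs = model.generate(
        input_ids=thought_input_ids,
        max_length=thought_length,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    # generate() can stop before max_length on EOS, so pad to a fixed width
    # before writing into the preallocated buffer (the committed code assigns
    # the raw output directly and assumes it is exactly max_length wide).
    padded = torch.full((batch_size, thought_length), tokenizer.eos_token_id, dtype=torch.long)
    padded[:, : thought_outputs.size(1)] = thought_outputs
    thought_ids[:, i, :] = padded
    thought_embeddings.append(model.get_input_embeddings()(padded))

thought_embeddings = torch.stack(thought_embeddings, dim=1)
print(thought_ids.shape)         # torch.Size([2, 3, 8])
print(thought_embeddings.shape)  # torch.Size([2, 3, 8, 768])

The resulting shapes mirror what _generate_thoughts returns: token ids of shape (batch, num_thoughts, max_length) and embeddings of shape (batch, num_thoughts, max_length, hidden_size).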