kirigayahitsugi
committed on
Commit
•
9b64f2c
1
Parent(s):
a2cadc2
Update README.md
Browse files
README.md
CHANGED
@@ -204,19 +204,35 @@ class GPMPipeline:
|
|
204 |
with torch.no_grad():
|
205 |
rewards, outputs = self.model.custom_forward(**inputs, return_output=return_prompt)
|
206 |
|
|
|
207 |
if return_prompt:
|
208 |
-
# Compute prompt hidden states
|
209 |
prompt_texts = [self.tokenizer.apply_chat_template([sample[0]], tokenize=False) for sample in samples]
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
|
221 |
|
222 |
prompt_text = "Describe the importance of reading books in today's digital age."
|
|
|
204 |
with torch.no_grad():
|
205 |
rewards, outputs = self.model.custom_forward(**inputs, return_output=return_prompt)
|
206 |
|
207 |
+
chosen_response_len_list = []
|
208 |
if return_prompt:
|
|
|
209 |
prompt_texts = [self.tokenizer.apply_chat_template([sample[0]], tokenize=False) for sample in samples]
|
210 |
+
for i in range(len(input_texts)):
|
211 |
+
prompt_token = self.tokenizer(
|
212 |
+
prompt_texts[i],
|
213 |
+
max_length=self.max_length,
|
214 |
+
padding=False,
|
215 |
+
truncation=True,
|
216 |
+
return_tensors="pt",
|
217 |
+
)
|
218 |
+
chosen_token = self.tokenizer(
|
219 |
+
input_texts[i],
|
220 |
+
max_length=self.max_length,
|
221 |
+
padding=False,
|
222 |
+
truncation=True,
|
223 |
+
return_tensors="pt",
|
224 |
+
)
|
225 |
+
chosen_response_len = chosen_token["attention_mask"].sum() - prompt_token["attention_mask"].sum()
|
226 |
+
chosen_response_len_list.append(chosen_response_len)
|
227 |
+
chosen_response_len = torch.tensor(chosen_response_len_list).view(-1, 1).to(self.device)
|
228 |
+
if return_prompt:
|
229 |
+
chosen_last_hidden_states = outputs["last_hidden_state"]
|
230 |
+
prompt_end_index = chosen_last_hidden_states.size(1) - chosen_response_len - 1
|
231 |
+
prompt_end_index_expanded = prompt_end_index.unsqueeze(-1).expand(-1, -1, chosen_last_hidden_states.size(-1))
|
232 |
+
prompt_hidden_state = torch.gather(chosen_last_hidden_states, dim=1, index=prompt_end_index_expanded).squeeze(1)
|
233 |
+
return rewards, prompt_hidden_state
|
234 |
+
else:
|
235 |
+
return rewards
|
236 |
|
237 |
|
238 |
prompt_text = "Describe the importance of reading books in today's digital age."
|