kirigayahitsugi committed on
Commit 9b64f2c
1 Parent(s): a2cadc2

Update README.md

Files changed (1):
  1. README.md +27 -11
README.md CHANGED
@@ -204,19 +204,35 @@ class GPMPipeline:
         with torch.no_grad():
             rewards, outputs = self.model.custom_forward(**inputs, return_output=return_prompt)
 
+        chosen_response_len_list = []
         if return_prompt:
-            # Compute prompt hidden states
             prompt_texts = [self.tokenizer.apply_chat_template([sample[0]], tokenize=False) for sample in samples]
-            prompt_lengths = [len(self.tokenizer(prompt_text, padding=False, return_tensors="pt")["input_ids"][0]) for prompt_text in prompt_texts]
-            prompt_lengths = torch.tensor(prompt_lengths, device=self.device)
-            prompt_end_indices = prompt_lengths - 1
-
-            last_hidden_states = outputs.last_hidden_state
-            prompt_hidden_states = last_hidden_states[torch.arange(len(samples)), prompt_end_indices, :]
-
-            return rewards, prompt_hidden_states
-
-        return rewards
+            for i in range(len(input_texts)):
+                prompt_token = self.tokenizer(
+                    prompt_texts[i],
+                    max_length=self.max_length,
+                    padding=False,
+                    truncation=True,
+                    return_tensors="pt",
+                )
+                chosen_token = self.tokenizer(
+                    input_texts[i],
+                    max_length=self.max_length,
+                    padding=False,
+                    truncation=True,
+                    return_tensors="pt",
+                )
+                chosen_response_len = chosen_token["attention_mask"].sum() - prompt_token["attention_mask"].sum()
+                chosen_response_len_list.append(chosen_response_len)
+        chosen_response_len = torch.tensor(chosen_response_len_list).view(-1, 1).to(self.device)
+        if return_prompt:
+            chosen_last_hidden_states = outputs["last_hidden_state"]
+            prompt_end_index = chosen_last_hidden_states.size(1) - chosen_response_len - 1
+            prompt_end_index_expanded = prompt_end_index.unsqueeze(-1).expand(-1, -1, chosen_last_hidden_states.size(-1))
+            prompt_hidden_state = torch.gather(chosen_last_hidden_states, dim=1, index=prompt_end_index_expanded).squeeze(1)
+            return rewards, prompt_hidden_state
+        else:
+            return rewards
 
 
 prompt_text = "Describe the importance of reading books in today's digital age."
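
The gather-based indexing added above locates the prompt's last hidden state by counting back from the end of the sequence: if the chosen response occupies the final chosen_response_len tokens (as with left-padded inputs), the prompt's last token sits at seq_len - chosen_response_len - 1. Below is a minimal standalone sketch of that lookup on dummy tensors; the batch size, sequence length, hidden size, and the sanity check are illustrative assumptions rather than code from this repository.

import torch

# Illustrative shapes only: batch of 2, sequence length 10, hidden size 4.
batch, seq_len, hidden = 2, 10, 4
last_hidden_state = torch.randn(batch, seq_len, hidden)    # (B, T, H) model output
chosen_response_len = torch.tensor([[3], [5]])             # (B, 1) response lengths in tokens

# Index of the prompt's last token, counted back from the sequence end.
prompt_end_index = seq_len - chosen_response_len - 1       # (B, 1)

# Expand to (B, 1, H) so torch.gather selects one time step per example.
index = prompt_end_index.unsqueeze(-1).expand(-1, -1, hidden)
prompt_hidden_state = torch.gather(last_hidden_state, dim=1, index=index).squeeze(1)  # (B, H)

# Sanity check against direct per-example indexing.
expected = torch.stack(
    [last_hidden_state[i, seq_len - int(chosen_response_len[i]) - 1] for i in range(batch)]
)
assert torch.allclose(prompt_hidden_state, expected)

Unlike the previous prompt_lengths - 1 indexing, which assumed every prompt started at position 0, this formulation depends only on the response length and the sequence end, so it stays correct when padding is applied on the left, which is presumably the motivation for the change.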