Ray2333 committed on
Commit
3bcb014
1 Parent(s): dc744d1

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +6 -1
model.py CHANGED
@@ -132,7 +132,12 @@ class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
132
  last_hidden_state = last_hidden_state.to(self.v_head.summary[0].weight.device)
133
 
134
  # use the last token value as reward
135
- last_index = attention_mask.sum(dim=-1) - 1
 
 
 
 
 
136
  value = self.v_head(last_hidden_state).squeeze(-1)[torch.arange(len(last_hidden_state)), last_index]
137
 
138
  # force upcast in fp32 if logits are in half-precision
 
132
  last_hidden_state = last_hidden_state.to(self.v_head.summary[0].weight.device)
133
 
134
  # use the last token value as reward
135
+ if torch.any(attention_mask[:, 0] == 0):
136
+ # left padding
137
+ last_index = attention_mask.shape[-1] - 1
138
+ else:
139
+ # right padding
140
+ last_index = attention_mask.sum(dim=-1) - 1
141
  value = self.v_head(last_hidden_state).squeeze(-1)[torch.arange(len(last_hidden_state)), last_index]
142
 
143
  # force upcast in fp32 if logits are in half-precision