gabrielmbmb HF staff commited on
Commit
3f9f573
·
unverified ·
1 Parent(s): 89047a4

Update device

Browse files
Files changed (1) hide show
  1. modeling_custom.py +3 -3
modeling_custom.py CHANGED
@@ -140,11 +140,11 @@ class LlamaForRewardModelWithGating(LlamaPreTrainedModel):
140
  # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
141
  sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
142
  sequence_lengths = sequence_lengths % input_ids.shape[-1]
143
- sequence_lengths = sequence_lengths.to("cuda")
144
  else:
145
  sequence_lengths = -1
146
 
147
- dummy_iterator = torch.arange(batch_size, device=tokens_hidden_states.device)
148
  hidden_states = tokens_hidden_states[dummy_iterator, sequence_lengths]
149
  assert hidden_states.shape == (batch_size, self.config.hidden_size)
150
  rewards = self.regression_layer(hidden_states)
@@ -163,4 +163,4 @@ class LlamaForRewardModelWithGating(LlamaPreTrainedModel):
163
  gating_output=gating_output,
164
  score=score,
165
  logits=score,
166
- )
 
140
  # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
141
  sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
142
  sequence_lengths = sequence_lengths % input_ids.shape[-1]
143
+ sequence_lengths = sequence_lengths.to(self.device)
144
  else:
145
  sequence_lengths = -1
146
 
147
+ dummy_iterator = torch.arange(batch_size, device=self.device)
148
  hidden_states = tokens_hidden_states[dummy_iterator, sequence_lengths]
149
  assert hidden_states.shape == (batch_size, self.config.hidden_size)
150
  rewards = self.regression_layer(hidden_states)
 
163
  gating_output=gating_output,
164
  score=score,
165
  logits=score,
166
+ )