Update modeling_custom.py

#14
by gabrielmbmb HF staff - opened
Files changed (1) hide show
  1. modeling_custom.py +6 -1
modeling_custom.py CHANGED
@@ -96,6 +96,9 @@ class LlamaForRewardModelWithGating(LlamaPreTrainedModel):
96
  temperature=config_dict.get("gating_temperature", 10),
97
  hidden_dim=config_dict.get("gating_hidden_dim", 1024),
98
  n_hidden=config_dict.get("gating_n_hidden", 3))
 
 
 
99
 
100
  @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
101
  def forward(
@@ -153,6 +156,8 @@ class LlamaForRewardModelWithGating(LlamaPreTrainedModel):
153
  prompt_embedding = tokens_hidden_states[dummy_iterator, gating_token_positions, :]
154
  gating_output = self.gating(prompt_embedding)
155
 
 
 
156
  rewards_adjusted = rewards @ self.reward_transform_matrix
157
  score = torch.sum(gating_output * rewards_adjusted, dim=1)
158
 
@@ -163,4 +168,4 @@ class LlamaForRewardModelWithGating(LlamaPreTrainedModel):
163
  gating_output=gating_output,
164
  score=score,
165
  logits=score,
166
- )
 
96
  temperature=config_dict.get("gating_temperature", 10),
97
  hidden_dim=config_dict.get("gating_hidden_dim", 1024),
98
  n_hidden=config_dict.get("gating_n_hidden", 3))
99
+ def align_tensor_devices(self, *tensors):
100
+ target_device = tensors[0].device
101
+ return [tensor.to(target_device) if tensor.device != target_device else tensor for tensor in tensors]
102
 
103
  @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
104
  def forward(
 
156
  prompt_embedding = tokens_hidden_states[dummy_iterator, gating_token_positions, :]
157
  gating_output = self.gating(prompt_embedding)
158
 
159
+ rewards, self.reward_transform_matrix = self.align_tensor_devices(rewards, self.reward_transform_matrix)
160
+
161
  rewards_adjusted = rewards @ self.reward_transform_matrix
162
  score = torch.sum(gating_output * rewards_adjusted, dim=1)
163
 
 
168
  gating_output=gating_output,
169
  score=score,
170
  logits=score,
171
+ )