Commit
•
8e79133
1
Parent(s):
e0dce68
add align tensors
Browse files- modeling_custom.py +5 -0
modeling_custom.py
CHANGED
@@ -96,6 +96,9 @@ class LlamaForRewardModelWithGating(LlamaPreTrainedModel):
|
|
96 |
temperature=config_dict.get("gating_temperature", 10),
|
97 |
hidden_dim=config_dict.get("gating_hidden_dim", 1024),
|
98 |
n_hidden=config_dict.get("gating_n_hidden", 3))
|
|
|
|
|
|
|
99 |
|
100 |
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
101 |
def forward(
|
@@ -153,6 +156,8 @@ class LlamaForRewardModelWithGating(LlamaPreTrainedModel):
|
|
153 |
prompt_embedding = tokens_hidden_states[dummy_iterator, gating_token_positions, :]
|
154 |
gating_output = self.gating(prompt_embedding)
|
155 |
|
|
|
|
|
156 |
rewards_adjusted = rewards @ self.reward_transform_matrix
|
157 |
score = torch.sum(gating_output * rewards_adjusted, dim=1)
|
158 |
|
|
|
96 |
temperature=config_dict.get("gating_temperature", 10),
|
97 |
hidden_dim=config_dict.get("gating_hidden_dim", 1024),
|
98 |
n_hidden=config_dict.get("gating_n_hidden", 3))
|
99 |
+
def align_tensor_devices(self, *tensors):
    """Move all given tensors onto the device of the first tensor.

    Args:
        *tensors: Zero or more ``torch.Tensor`` objects. The first
            tensor's device is used as the target device.

    Returns:
        list: The tensors in the same order, each on the target device.
        Tensors already on the target device are returned unchanged
        (no copy); an empty call returns an empty list.
    """
    # Guard the empty case: tensors[0] would raise IndexError otherwise.
    if not tensors:
        return []
    target_device = tensors[0].device
    # .to() is a no-op copy avoided here when the device already matches.
    return [tensor.to(target_device) if tensor.device != target_device else tensor
            for tensor in tensors]
|
102 |
|
103 |
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
104 |
def forward(
|
|
|
156 |
prompt_embedding = tokens_hidden_states[dummy_iterator, gating_token_positions, :]
|
157 |
gating_output = self.gating(prompt_embedding)
|
158 |
|
159 |
+
rewards, self.reward_transform_matrix = self.align_tensor_devices(rewards, self.reward_transform_matrix)
|
160 |
+
|
161 |
rewards_adjusted = rewards @ self.reward_transform_matrix
|
162 |
score = torch.sum(gating_output * rewards_adjusted, dim=1)
|
163 |
|