# Eurus-RM-7b / modeling_eurus_rm.py
# Last change: commit 5ab7ef3 (lievan) — "Enable flash_attention_2 support since
# the underlying Mistral model supports it (#3)"
from transformers import PreTrainedModel, MistralConfig, MistralModel
import torch.nn as nn
import torch
from typing import Optional, List
class EurusRewardModel(PreTrainedModel):
    """Scalar reward model on top of a Mistral decoder backbone.

    The backbone's final hidden states are projected to one score per token by a
    bias-free linear ``regression_head``. The reward of each sequence is the
    score at its last attended token, as marked by ``attention_mask``; without
    a mask, the final position is used.
    """

    config_class = MistralConfig
    # The backbone is a MistralModel, which supports FlashAttention-2, so this
    # wrapper advertises support as well.
    _supports_flash_attn_2 = True

    def __init__(self, config):
        """Build the Mistral backbone and the one-output regression head.

        Args:
            config: A ``MistralConfig``; ``config.hidden_size`` sets the input
                width of the regression head.
        """
        super().__init__(config)
        self.model = MistralModel(config)
        # Maps each token's hidden state to a single scalar score.
        self.regression_head = nn.Linear(self.config.hidden_size, 1, bias=False)

    @staticmethod
    def _last_token_index(attention_mask, token_rewards):
        """Return a (batch, 1) LongTensor holding each row's last attended position."""
        if attention_mask is None:
            # No padding information: every position is real, so use the last one.
            batch_size, seq_len = token_rewards.shape
            return torch.full(
                (batch_size, 1),
                seq_len - 1,
                dtype=torch.long,
                device=token_rewards.device,
            )
        # cumsum of a 0/1 mask is non-decreasing and first reaches its maximum at
        # the last 1; argmax returns the first index of the maximum, so this picks
        # the last attended token for both left- and right-padded batches.
        return attention_mask.cumsum(dim=1).argmax(dim=1).view(-1, 1)

    def forward(  # args are the same as LlamaForCausalLM
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Score each sequence in the batch.

        Args:
            input_ids: Token ids, forwarded to the backbone.
            attention_mask: Optional padding mask, assumed (batch, seq_len) with
                1 for real tokens and 0 for padding. When ``None``, the reward is
                read from the final position of every row.
            position_ids, past_key_values, inputs_embeds: Forwarded unchanged to
                the Mistral backbone.
            labels, use_cache, output_attentions, output_hidden_states,
                return_dict: Accepted only for signature compatibility with
                causal-LM models; they are currently ignored.

        Returns:
            A tensor of shape (batch, 1): the regression-head score at each row's
            last attended token.
        """
        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
        )
        # First backbone output is the last hidden state: (batch, seq_len, hidden).
        hidden_states = transformer_outputs[0]
        # Per-token scores, trailing singleton dim dropped: (batch, seq_len).
        rewards = self.regression_head(hidden_states).squeeze(-1)
        ends = self._last_token_index(attention_mask, rewards)
        return torch.gather(rewards, 1, ends)