H2OTest / llm_studio /src /models /text_rlhf_language_modeling_model.py
elineve's picture
Upload 301 files
07423df
raw
history blame
4.26 kB
import logging
from typing import Any, Dict
import torch
from torch import nn
from transformers import AutoModelForCausalLM
from llm_studio.src.metrics.text_causal_language_modeling_metrics import Perplexity
from llm_studio.src.utils.data_utils import batch_padding
from llm_studio.src.utils.modeling_utils import (
create_nlp_backbone,
generate,
prepare_lora,
)
logger = logging.getLogger(__name__)
class ValueHead(nn.Module):
    """
    Value head that maps every hidden state to a single scalar.

    The head applies (optional) dropout followed by a linear projection to
    one output feature.

    Based on the implementation of trl library:
    https://github.com/lvwerra/trl/blob/main/trl/models/modeling_value_head.py
    """

    def __init__(self, config):
        """
        Args:
            config: backbone config. The optional attribute
                `summary_dropout_prob` (default 0.1) sets the dropout rate; a
                falsy value disables dropout. The input width is taken from
                `word_embed_proj_dim` when present, otherwise `hidden_size`.
        """
        super().__init__()

        summary_dropout_prob = getattr(config, "summary_dropout_prob", 0.1)
        self.dropout = (
            nn.Dropout(summary_dropout_prob) if summary_dropout_prob else nn.Identity()
        )

        # some models such as OPT have a projection layer before the word embeddings
        # e.g. OPT-350m
        # (hasattr rather than getattr-with-default: `hidden_size` must only be
        # read when `word_embed_proj_dim` is missing)
        if hasattr(config, "word_embed_proj_dim"):
            hidden_size = config.word_embed_proj_dim
        else:
            hidden_size = config.hidden_size

        self.summary = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states):
        """
        Args:
            hidden_states: tensor whose last dimension matches the input width
                of the linear projection.

        Returns:
            Tensor with the same leading dimensions and a last dimension of 1,
            in the dtype of the projection weights.
        """
        output = self.dropout(hidden_states)

        # Cast to the dtype of the projection weights if the incoming hidden
        # states differ (e.g. half-precision backbone, fp32 head), so the value
        # estimate is computed in the head's precision.
        if output.dtype != self.summary.weight.dtype:
            output = output.to(self.summary.weight.dtype)

        return self.summary(output)
class Model(nn.Module):
    """
    Causal language model backbone with a scalar value head for RLHF.

    The backbone is created via `create_nlp_backbone` and wrapped by
    `prepare_lora`; a `ValueHead` is attached on top of the backbone's last
    hidden state.
    """

    def __init__(self, cfg: Any):
        """
        Args:
            cfg: config with all the hyperparameters
        """

        super().__init__()

        self.cfg = cfg
        assert cfg.training.lora, "LoRA must be True for RLHF"

        self.backbone, self.backbone_config = create_nlp_backbone(
            cfg, model_class=AutoModelForCausalLM
        )
        self.backbone = prepare_lora(cfg=self.cfg, backbone=self.backbone)

        # The perplexity metric is only instantiated when it is the configured
        # prediction metric; `forward` checks the same condition before use.
        if self.cfg.prediction.metric == "Perplexity":
            self.perplexity = Perplexity(self.cfg, reduce=False)

        self.value_head = ValueHead(self.backbone_config)
        self.value_head.summary.bias.data.zero_()

    def forward(
        self,
        batch: Dict,
        padding: bool = True,
    ) -> Dict:
        """
        Run the backbone on a batch and collect the mode-dependent outputs.

        Args:
            batch: dict holding at least "input_ids" and "attention_mask"
                ("labels" is additionally read when perplexity is computed).
            padding: whether to pass the batch through `batch_padding` first.

        Returns:
            Dict that contains
              - "perplexity" when not in training mode and the configured
                prediction metric is "Perplexity";
              - "logits" (fp32) and "value" (value head output with the
                trailing dimension squeezed) when in training mode.
        """
        gradient_checkpointing = self.cfg.architecture.gradient_checkpointing

        # disable cache if gradient checkpointing is enabled
        if gradient_checkpointing:
            self.backbone.config.use_cache = False

        try:
            outputs: Dict = {}

            if padding:
                batch = batch_padding(
                    self.cfg,
                    batch,
                    self.training,
                    mask_key="attention_mask",
                    pad_keys=[
                        "input_ids",
                        "attention_mask",
                        "special_tokens_mask",
                        "labels",
                    ],
                )

            # Hidden states are only consumed by the value head in training
            # mode, so they are not requested otherwise.
            output = self.backbone(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                output_hidden_states=self.training,
            )

            if self.cfg.prediction.metric == "Perplexity" and not self.training:
                outputs["perplexity"] = self.perplexity(output.logits, batch["labels"])

            if self.training:
                last_hidden_state = output.hidden_states[-1]

                # force upcast in fp32 if logits are in half-precision
                # (`.float()` returns the tensor unchanged when already fp32)
                outputs["logits"] = output.logits.float()
                outputs["value"] = self.value_head(last_hidden_state).squeeze(-1)
        finally:
            # enable cache again if gradient checkpointing is enabled; done in
            # `finally` so a failing forward pass does not leave it disabled
            if gradient_checkpointing:
                self.backbone.config.use_cache = True

        return outputs

    def generate(self, batch: Dict, cfg: Any, streamer=None):
        """
        Delegate to the imported `generate` helper using this model's backbone.

        Args:
            batch: batch passed through to the helper.
            cfg: config passed through to the helper.
            streamer: optional streamer passed through to the helper.
        """
        return generate(self.backbone, batch, cfg, streamer)