# stsb-roberta-base-v2 / modeling_roberta.py
# Uploaded by michaelfeil ("Upload 14 files"), commit a0c1f55 (verified).
from typing import Optional, Tuple
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from fms.models.hf.lm_head_mixins import (
MaskedLMHeadMixin,
SequenceClassificationLMHeadMixin,
)
from fms.models.hf.modeling_hf_adapter import HFEncoder, HFEncoderModelArchitecture
from fms.models.roberta import RoBERTa, RoBERTaConfig, RoBERTaHeadless
class HFAdaptedRoBERTaConfig(PretrainedConfig):
    """Hugging Face configuration for the adapted FMS RoBERTa model.

    Values are stored under the FMS-style names used by ``__init__``
    (``src_vocab_size``, ``emb_dim``, ``nheads``, ...). ``attribute_map``
    declares the Hugging Face-style aliases for those names.
    """

    model_type = "hf_adapted_roberta"
    # Hugging Face attribute name -> name actually stored on this config.
    attribute_map = {
        "vocab_size": "src_vocab_size",
        "hidden_size": "emb_dim",
        "num_attention_heads": "nheads",
        "num_hidden_layers": "nlayers",
        "tie_word_embeddings": "tie_heads",
    }

    def __init__(
        self,
        src_vocab_size=None,
        emb_dim=None,
        nheads=12,
        nlayers=12,
        max_pos=512,
        pad_token_id=1,
        hidden_grow_factor=4,
        activation_fn="gelu",
        classifier_activation_fn="tanh",
        p_dropout=0.1,
        classifier_dropout=0.1,
        use_cache=True,
        num_labels=1,
        norm_eps=1e-12,
        tie_heads=False,
        **kwargs,
    ):
        """Store the model hyper-parameters and initialise the HF base config.

        Every named argument except ``pad_token_id`` and ``num_labels`` is
        stored on the instance under its own name. ``pad_token_id``,
        ``num_labels`` and ``tie_word_embeddings`` are forwarded to
        ``PretrainedConfig.__init__`` together with the remaining ``kwargs``.
        ``tie_word_embeddings`` is taken from ``kwargs`` when the caller
        supplied it there and otherwise falls back to ``tie_heads``.

        Raises:
            ValueError: if ``activation_fn`` (compared case-insensitively)
                is not one of gelu, relu, mish, swish.
        """
        self.src_vocab_size = src_vocab_size
        self.emb_dim = emb_dim
        self.nheads = nheads
        self.nlayers = nlayers
        self.max_pos = max_pos
        self.hidden_grow_factor = hidden_grow_factor

        # Only the lower-cased name is checked; the value is stored as given.
        supported_activations = ("gelu", "relu", "mish", "swish")
        if activation_fn.lower() not in supported_activations:
            raise ValueError(
                "activation function must be one of gelu, relu, mish, swish"
            )
        self.activation_fn = activation_fn

        self.p_dropout = p_dropout
        self.classifier_dropout = classifier_dropout
        self.use_cache = use_cache
        self.norm_eps = norm_eps
        self.classifier_activation_fn = classifier_activation_fn
        self.tie_heads = tie_heads

        # Remove the key from kwargs so it is not passed twice below.
        tie_word_embeddings = kwargs.pop("tie_word_embeddings", tie_heads)
        super().__init__(
            pad_token_id=pad_token_id,
            num_labels=num_labels,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path, **kwargs
    ) -> "PretrainedConfig":
        """Build a config from a pretrained model name or path.

        The raw dictionary is obtained with ``cls.get_config_dict``; the
        keyword arguments it hands back are then passed to ``cls.from_dict``
        along with that dictionary.
        """
        loaded_dict, remaining_kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )
        return cls.from_dict(loaded_dict, **remaining_kwargs)

    @classmethod
    def from_fms_config(cls, config: RoBERTaConfig, **hf_kwargs):
        """Build a config from an FMS ``RoBERTaConfig``.

        Takes ``config.as_dict()``, renames its ``pad_id`` entry to
        ``pad_token_id`` and passes the result to ``cls.from_dict`` together
        with ``hf_kwargs``.
        """
        fms_params = config.as_dict()
        pad_id = fms_params.pop("pad_id")
        fms_params["pad_token_id"] = pad_id
        return cls.from_dict(fms_params, **hf_kwargs)
class HFAdaptedRoBERTaEncoder(HFEncoder):
    """``HFEncoder`` adapter around an FMS ``RoBERTaHeadless`` model."""

    def __init__(self, model: RoBERTaHeadless, config: PretrainedConfig):
        """Hand the wrapped model and config to ``HFEncoder``.

        ``attention_mask_dim`` is fixed to 3 for this adapter.
        """
        super().__init__(model, config, attention_mask_dim=3)

    def _adapt(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        position_ids: Optional[torch.LongTensor] = None,
        *args,
        **kwargs,
    ) -> BaseModelOutputWithPastAndCrossAttentions:
        """Run the wrapped model and package its result as an HF output.

        Only ``input_ids``, ``attention_mask`` and ``position_ids`` are
        forwarded (as ``x``, ``mask`` and ``position_ids``). The other
        parameters, including ``*args`` and ``**kwargs``, are accepted but
        not used here.

        Returns:
            A ``BaseModelOutputWithPastAndCrossAttentions`` whose
            ``last_hidden_state`` is whatever ``self.model`` returned.
        """
        encoded = self.model(
            x=input_ids,
            mask=attention_mask,
            position_ids=position_ids,
        )
        return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=encoded)
class HFAdaptedRoBERTaHeadless(HFEncoderModelArchitecture):
    """Headless Hugging Face architecture built on the adapted RoBERTa encoder."""

    # attributes required by HF
    config_class = HFAdaptedRoBERTaConfig
    base_model_prefix = "hf_adapted_roberta"

    def __init__(
        self,
        config: PretrainedConfig,
        encoder: Optional[RoBERTaHeadless] = None,
        embedding: Optional[nn.Module] = None,
        *args,
        **kwargs,
    ):
        """Wrap the encoder for Hugging Face and initialise the architecture.

        When ``encoder`` or ``embedding`` is not supplied, a new ``RoBERTa``
        model is constructed from ``config.to_dict()`` (with ``pad_token_id``
        passed as ``pad_id``) and the missing piece(s) are taken from its
        ``base_model``. Any piece the caller did supply is kept.
        """
        if encoder is None or embedding is None:
            # Build a fresh FMS model to source whichever part is missing.
            fms_kwargs = config.to_dict()
            pad_id = fms_kwargs.pop("pad_token_id")
            fresh_model = RoBERTa(pad_id=pad_id, **fms_kwargs)
            if encoder is None:
                encoder = fresh_model.base_model
            if embedding is None:
                embedding = fresh_model.base_model.embedding

        # from here on the encoder is wrapped in its huggingface-compatible adapter
        hf_encoder = HFAdaptedRoBERTaEncoder(encoder, config)
        super().__init__(hf_encoder, embedding, config, *args, **kwargs)
class HFAdaptedRoBERTaForMaskedLM(MaskedLMHeadMixin, HFAdaptedRoBERTaHeadless):
    """Adapted RoBERTa combined with ``MaskedLMHeadMixin``."""

    def __init__(self, config: HFAdaptedRoBERTaConfig, *args, **kwargs):
        """Initialise the parent classes with head settings read from ``config``.

        ``activation_fn`` and ``norm_eps`` are taken from ``config`` and
        passed up as keyword arguments; ``*args`` and ``**kwargs`` are
        forwarded unchanged.
        """
        super().__init__(
            *args,
            config=config,
            activation_fn=config.activation_fn,
            norm_eps=config.norm_eps,
            **kwargs,
        )

    @classmethod
    def _hf_model_from_fms(
        cls, model: RoBERTa, config: HFAdaptedRoBERTaConfig
    ) -> "HFAdaptedRoBERTaForMaskedLM":
        """Create an instance that reuses the parts of an existing FMS model.

        The encoder is ``model.base_model``, the embedding is that base
        model's ``embedding`` and the LM head is ``model.classification_head``.
        """
        fms_base = model.base_model
        return cls(
            config=config,
            encoder=fms_base,
            embedding=fms_base.embedding,
            lm_head=model.classification_head,
        )
class HFAdaptedRoBERTaForSequenceClassification(
    SequenceClassificationLMHeadMixin, HFAdaptedRoBERTaHeadless
):
    """Adapted RoBERTa combined with ``SequenceClassificationLMHeadMixin``."""

    def __init__(
        self,
        config: HFAdaptedRoBERTaConfig,
        encoder: Optional[nn.Module] = None,
        embedding: Optional[nn.Module] = None,
        *args,
        **kwargs,
    ):
        """Initialise the parent classes with classifier settings from ``config``.

        ``classifier_activation_fn`` and ``classifier_dropout`` are read from
        ``config``; ``encoder``, ``embedding``, ``*args`` and ``**kwargs``
        are forwarded unchanged.
        """
        super().__init__(
            *args,
            config=config,
            classifier_activation_fn=config.classifier_activation_fn,
            classifier_dropout=config.classifier_dropout,
            encoder=encoder,
            embedding=embedding,
            **kwargs,
        )

    @classmethod
    def _hf_model_from_fms(
        cls, model: RoBERTa, config: HFAdaptedRoBERTaConfig
    ) -> "HFAdaptedRoBERTaForSequenceClassification":
        """Create an instance that reuses the parts of an existing FMS model.

        The encoder is ``model.base_model`` and the embedding is that base
        model's ``embedding``; no head is passed from ``model``.
        """
        fms_base = model.base_model
        return cls(
            config=config,
            encoder=fms_base,
            embedding=fms_base.embedding,
        )