from typing import Optional

import torch
import torch.nn as nn
from transformers import PretrainedConfig
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions

from fms.models.hf.lm_head_mixins import (
    MaskedLMHeadMixin,
    SequenceClassificationLMHeadMixin,
)
from fms.models.hf.modeling_hf_adapter import HFEncoder, HFEncoderModelArchitecture
from fms.models.roberta import RoBERTa, RoBERTaConfig, RoBERTaHeadless


class HFAdaptedRoBERTaConfig(PretrainedConfig):
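    """Hugging Face PretrainedConfig wrapper around the FMS RoBERTa hyperparameters.

    attribute_map exposes the FMS field names (src_vocab_size, emb_dim, nheads,
    nlayers, tie_heads) under their standard Hugging Face aliases so generic HF
    tooling can read and write them transparently.
    """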
    model_type = "hf_adapted_roberta"

    attribute_map = {
        "vocab_size": "src_vocab_size",
        "hidden_size": "emb_dim",
        "num_attention_heads": "nheads",
        "num_hidden_layers": "nlayers",
        "tie_word_embeddings": "tie_heads",
    }

    def __init__(
        self,
        src_vocab_size=None,
        emb_dim=None,
        nheads=12,
        nlayers=12,
        max_pos=512,
        pad_token_id=1,
        hidden_grow_factor=4,
        activation_fn="gelu",
        classifier_activation_fn="tanh",
        p_dropout=0.1,
        classifier_dropout=0.1,
        use_cache=True,
        num_labels=1,
        norm_eps=1e-12,
        tie_heads=False,
        **kwargs,
    ):
        self.src_vocab_size = src_vocab_size
        self.emb_dim = emb_dim
        self.nheads = nheads
        self.nlayers = nlayers
        self.max_pos = max_pos
        self.hidden_grow_factor = hidden_grow_factor
        if activation_fn.lower() not in ["gelu", "relu", "mish", "swish"]:
            raise ValueError(
                "activation function must be one of gelu, relu, mish, swish"
            )
        self.activation_fn = activation_fn
        self.p_dropout = p_dropout
        self.classifier_dropout = classifier_dropout
        self.use_cache = use_cache
        self.norm_eps = norm_eps
        self.classifier_activation_fn = classifier_activation_fn
        self.tie_heads = tie_heads
        super().__init__(
            pad_token_id=pad_token_id,
            num_labels=num_labels,
            tie_word_embeddings=kwargs.pop("tie_word_embeddings", tie_heads),
            **kwargs,
        )

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path, **kwargs
    ) -> "PretrainedConfig":
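        """Load the configuration dict for pretrained_model_name_or_path and build
        the config from it."""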
        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )

        return cls.from_dict(config_dict, **kwargs)

    @classmethod
    def from_fms_config(cls, config: RoBERTaConfig, **hf_kwargs):
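        """Create an HFAdaptedRoBERTaConfig from an FMS RoBERTaConfig, renaming the
        FMS pad_id field to the Hugging Face pad_token_id."""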
        config_dict = config.as_dict()
        config_dict["pad_token_id"] = config_dict.pop("pad_id")
        return cls.from_dict(config_dict, **hf_kwargs)


class HFAdaptedRoBERTaEncoder(HFEncoder):
    """Adapter for the RoBERTa encoder"""

    def __init__(self, model: RoBERTaHeadless, config: PretrainedConfig):
        super().__init__(model, config, attention_mask_dim=3)

    def _adapt(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        position_ids: Optional[torch.LongTensor] = None,
        *args,
        **kwargs,
    ) -> BaseModelOutputWithPastAndCrossAttentions:
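        # run the wrapped FMS encoder and re-package its hidden states in the
        # Hugging Face output dataclass; head_mask, inputs_embeds, output_attentions
        # and output_hidden_states are accepted for interface compatibility but are
        # not forwarded to the FMS model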
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=self.model(
                x=input_ids, mask=attention_mask, position_ids=position_ids
            )
        )


class HFAdaptedRoBERTaHeadless(HFEncoderModelArchitecture):
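    """Headless (no task head) HF-adapted RoBERTa: the shared encoder/embedding
    architecture that the task-specific classes below build on."""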
    config_class = HFAdaptedRoBERTaConfig
    base_model_prefix = "hf_adapted_roberta"

    def __init__(
        self,
        config: PretrainedConfig,
        encoder: Optional[RoBERTaHeadless] = None,
        embedding: Optional[nn.Module] = None,
        *args,
        **kwargs,
    ):
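        # if an encoder and/or embedding was not provided, construct a fresh FMS
        # RoBERTa from the config and borrow the missing pieces from it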
        if encoder is None or embedding is None:
            params = config.to_dict()
            model = RoBERTa(pad_id=params.pop("pad_token_id"), **params)
            encoder = model.base_model if encoder is None else encoder
            embedding = model.base_model.embedding if embedding is None else embedding

        # wrap the FMS encoder so it exposes the Hugging Face encoder interface
        encoder = HFAdaptedRoBERTaEncoder(encoder, config)
        super().__init__(encoder, embedding, config, *args, **kwargs)


class HFAdaptedRoBERTaForMaskedLM(MaskedLMHeadMixin, HFAdaptedRoBERTaHeadless):
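    """HF-adapted RoBERTa with a masked-language-modeling head provided by
    MaskedLMHeadMixin."""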
    def __init__(self, config: HFAdaptedRoBERTaConfig, *args, **kwargs):
        super().__init__(
            config=config,
            activation_fn=config.activation_fn,
            norm_eps=config.norm_eps,
            *args,
            **kwargs,
        )

    @classmethod
    def _hf_model_from_fms(
        cls, model: RoBERTa, config: HFAdaptedRoBERTaConfig
    ) -> "HFAdaptedRoBERTaForMaskedLM":
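        # reuse the FMS model's encoder, embedding, and classification head so the
        # HF-adapted model shares its modules (and therefore weights) with the
        # original FMS RoBERTa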
        return cls(
            config=config,
            encoder=model.base_model,
            embedding=model.base_model.embedding,
            lm_head=model.classification_head,
        )


class HFAdaptedRoBERTaForSequenceClassification(
    SequenceClassificationLMHeadMixin, HFAdaptedRoBERTaHeadless
):
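    """HF-adapted RoBERTa with a sequence-classification head provided by
    SequenceClassificationLMHeadMixin."""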
    def __init__(
        self,
        config: HFAdaptedRoBERTaConfig,
        encoder: Optional[nn.Module] = None,
        embedding: Optional[nn.Module] = None,
        *args,
        **kwargs,
    ):
        super().__init__(
            config=config,
            classifier_activation_fn=config.classifier_activation_fn,
            classifier_dropout=config.classifier_dropout,
            encoder=encoder,
            embedding=embedding,
            *args,
            **kwargs,
        )

    @classmethod
    def _hf_model_from_fms(
        cls, model: RoBERTa, config: HFAdaptedRoBERTaConfig
    ) -> "HFAdaptedRoBERTaForSequenceClassification":
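        # reuse the FMS encoder and embedding; note that, unlike the masked-LM
        # variant, no head module from the FMS model is passed through here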
        return cls(
            config=config,
            encoder=model.base_model,
            embedding=model.base_model.embedding,
        )
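

# Minimal usage sketch (not part of the module's API surface). It assumes the
# adapter base class exposes a from_fms_model constructor that dispatches to
# _hf_model_from_fms, and that RoBERTa/RoBERTaConfig accept the arguments shown
# and expose get_config(); adjust to the actual FMS entry points if they differ.
#
#   fms_model = RoBERTa(RoBERTaConfig(src_vocab_size=50265, emb_dim=768))
#   hf_config = HFAdaptedRoBERTaConfig.from_fms_config(fms_model.get_config())
#   hf_model = HFAdaptedRoBERTaForMaskedLM.from_fms_model(fms_model, config=hf_config)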