# Hugging Face Hub file-view residue (not part of the module itself):
# uploaded by BartekSadlej ("Upload model"), commit 37d5787 (verified), 8.29 kB.
from typing import Any, Dict, Optional
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PreTrainedModel
from transformers.modeling_outputs import (
BaseModelOutputWithPooling,
MaskedLMOutput,
BaseModelOutput,
SequenceClassifierOutput,
)
from enum import Enum
from .config import ILKTConfig
def cls_pooling(last_hidden_state, attention_mask):
    """Return the hidden state at sequence position 0 for every batch item.

    Args:
        last_hidden_state: Tensor indexed as ``[batch, position, feature]``.
        attention_mask: Accepted but not used by this pooling strategy.

    Returns:
        The slice ``last_hidden_state[:, 0, :]``.
    """
    first_token_states = last_hidden_state[:, 0, :]
    return first_token_states
def create_head_blocks(
    hidden_size: int,
    n_dense: int,
    use_batch_norm: bool,
    use_layer_norm: bool,
    dropout: float,
    **kwargs,
) -> nn.Module:
    """Build a stack of ``n_dense`` identical dense layers.

    Each layer is ``Linear -> [BatchNorm1d | LayerNorm] -> ReLU -> [Dropout]``
    and keeps the feature width at ``hidden_size``.

    Args:
        hidden_size: Input and output width of every ``nn.Linear``.
        n_dense: Number of dense layers; ``0`` yields an empty ``nn.Sequential``.
        use_batch_norm: Insert ``nn.BatchNorm1d`` after each linear layer.
            Takes precedence over ``use_layer_norm``.
        use_layer_norm: Insert ``nn.LayerNorm`` after each linear layer when
            batch norm is not requested.
        dropout: Dropout probability; a dropout module is added only when
            this is greater than zero.
        **kwargs: Ignored. Lets callers splat a whole head-config dict that
            also carries keys this function does not consume.

    Returns:
        An ``nn.Sequential`` holding all layers' modules in order.
    """

    def dense_layer():
        # Yield one layer's modules in the order they are applied.
        yield nn.Linear(hidden_size, hidden_size)
        if use_batch_norm:
            yield nn.BatchNorm1d(hidden_size)
        elif use_layer_norm:
            yield nn.LayerNorm(hidden_size)
        yield nn.ReLU()
        if dropout > 0:
            yield nn.Dropout(dropout)

    modules = []
    for _ in range(n_dense):
        modules.extend(dense_layer())
    return nn.Sequential(*modules)
class SentenceEmbeddingHead(nn.Module):
    """Turn backbone hidden states into one embedding vector per sequence.

    The head config dict is stored as ``self.config``; ``forward`` reads its
    ``"pool_type"`` and ``"normalize_embeddings"`` keys.
    """

    def __init__(
        self, backbone_hidden_size: int, embedding_head_config: Dict[str, Any]
    ):
        """Store the config and build the dense stack.

        Args:
            backbone_hidden_size: Feature width of the backbone output.
            embedding_head_config: Head settings; splatted into
                ``create_head_blocks`` and kept for use in ``forward``.
        """
        super().__init__()
        self.config = embedding_head_config
        dense_stack = create_head_blocks(backbone_hidden_size, **embedding_head_config)
        # The dense stack is nested in an outer Sequential, so its parameters
        # are registered under ``head.0.*``.
        self.head = nn.Sequential(dense_stack)

    def forward(
        self, backbone_output: BaseModelOutput, attention_mask: torch.Tensor, **kwargs
    ) -> BaseModelOutputWithPooling:
        """Pool the backbone output and optionally L2-normalize it.

        Args:
            backbone_output: Backbone result; only ``last_hidden_state`` is read.
            attention_mask: Forwarded to the pooling function.
            **kwargs: Ignored.

        Returns:
            ``BaseModelOutputWithPooling`` whose ``pooler_output`` is the pooled
            embedding and whose ``last_hidden_state`` is passed through.

        Raises:
            NotImplementedError: If ``config["pool_type"]`` is not ``"cls"``.
        """
        if self.config["pool_type"] != "cls":
            raise NotImplementedError(
                f"Pooling type {self.config['pool_type']} not implemented"
            )

        token_states = backbone_output.last_hidden_state
        # NOTE(review): ``self.head`` is constructed but not applied here; the
        # pooled vector is returned as-is. Confirm this is intentional.
        pooled = cls_pooling(token_states, attention_mask)
        if self.config["normalize_embeddings"]:
            pooled = nn.functional.normalize(pooled, p=2, dim=-1)

        return BaseModelOutputWithPooling(
            last_hidden_state=backbone_output.last_hidden_state,
            pooler_output=pooled,  # type: ignore
        )
class MLMHead(nn.Module):
    """Project every token's hidden state to vocabulary logits."""

    def __init__(
        self,
        backbone_hidden_size: int,
        vocab_size: int,
        mlm_head_config: Dict[str, Any],
    ):
        """Build the dense stack followed by the vocabulary projection.

        Args:
            backbone_hidden_size: Feature width of the backbone output.
            vocab_size: Number of output logits per token.
            mlm_head_config: Head settings; splatted into ``create_head_blocks``
                and stored as ``self.config``.
        """
        super().__init__()
        self.config = mlm_head_config
        # Index 0 is the dense stack, index 1 the output projection.
        self.head = nn.Sequential(
            create_head_blocks(backbone_hidden_size, **mlm_head_config),
            nn.Linear(backbone_hidden_size, vocab_size),
        )

    def forward(
        self,
        backbone_output: BaseModelOutput,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> MaskedLMOutput:
        """Compute per-token logits and, when labels are given, the loss.

        Args:
            backbone_output: Backbone result; only ``last_hidden_state`` is read.
            attention_mask: Accepted but not used by this head.
            labels: Optional targets; flattened and compared against the
                flattened logits with ``nn.CrossEntropyLoss``.
            **kwargs: Ignored.

        Returns:
            ``MaskedLMOutput`` with ``logits`` and ``loss`` (``None`` without
            labels).
        """
        # NOTE(review): the head is applied to the un-pooled hidden states;
        # if the config enables batch norm, check that BatchNorm1d receives
        # the layout it expects for this input.
        logits = self.head(backbone_output.last_hidden_state)
        if labels is None:
            return MaskedLMOutput(loss=None, logits=logits)

        criterion = nn.CrossEntropyLoss()
        flat_logits = logits.view(-1, logits.size(-1))
        flat_labels = labels.view(-1)
        return MaskedLMOutput(loss=criterion(flat_logits, flat_labels), logits=logits)
class CLSHead(nn.Module):
    """Classify a whole sequence from its pooled hidden state."""

    def __init__(
        self,
        backbone_hidden_size: int,
        n_classes: int,
        cls_head_config: Dict[str, Any],
    ):
        """Build the dense stack followed by the class projection.

        Args:
            backbone_hidden_size: Feature width of the backbone output.
            n_classes: Number of output logits.
            cls_head_config: Head settings; splatted into ``create_head_blocks``
                and stored as ``self.config`` (``forward`` reads ``"pool_type"``).
        """
        super().__init__()
        self.config = cls_head_config
        # Index 0 is the dense stack, index 1 the output projection.
        self.head = nn.Sequential(
            create_head_blocks(backbone_hidden_size, **cls_head_config),
            nn.Linear(backbone_hidden_size, n_classes),
        )

    def forward(
        self,
        backbone_output: BaseModelOutput,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> SequenceClassifierOutput:
        """Pool the sequence, compute class logits and optionally the loss.

        Args:
            backbone_output: Backbone result; only ``last_hidden_state`` is read.
            attention_mask: Forwarded to the pooling function.
            labels: Optional targets; flattened and compared against the
                logits with ``nn.CrossEntropyLoss``.
            **kwargs: Ignored.

        Returns:
            ``SequenceClassifierOutput`` with ``logits`` and ``loss`` (``None``
            without labels).

        Raises:
            NotImplementedError: If ``config["pool_type"]`` is not ``"cls"``.
        """
        if self.config["pool_type"] != "cls":
            raise NotImplementedError(
                f"Pooling type {self.config['pool_type']} not implemented"
            )

        pooled = cls_pooling(backbone_output.last_hidden_state, attention_mask)
        logits = self.head(pooled)
        if labels is None:
            return SequenceClassifierOutput(loss=None, logits=logits)

        criterion = nn.CrossEntropyLoss()
        flat_logits = logits.view(-1, logits.size(-1))
        flat_labels = labels.view(-1)
        return SequenceClassifierOutput(
            loss=criterion(flat_logits, flat_labels), logits=logits
        )
class ForwardRouting(Enum):
    """Selects which head ``ILKTModel.forward`` runs.

    Each member's value is the name of the ``ILKTModel`` method it maps to.
    """

    # Pooled sentence embedding from the embedding head.
    GET_SENTENCE_EMBEDDING = "get_sentence_embedding"

    # Per-token vocabulary logits from the MLM head.
    GET_MLM_OUTPUT = "get_mlm_output"

    # Sequence-level logits from one named classification head.
    GET_CLS_OUTPUT = "get_cls_output"
class ILKTModel(PreTrainedModel):
    """Pretrained backbone with sentence-embedding, MLM and classification heads.

    The backbone is loaded with ``AutoModel.from_pretrained`` using the settings
    in ``config.backbone_config``. On top of it sit one
    ``SentenceEmbeddingHead``, one ``MLMHead`` and one ``CLSHead`` per entry of
    ``config.cls_heads``. ``forward`` picks the head to run via
    ``forward_routing``.
    """

    config_class = ILKTConfig

    def __init__(self, config: ILKTConfig):
        """Load the backbone and build every head described by ``config``.

        Args:
            config: Model configuration. ``backbone_config`` is a dict of
                keyword arguments for ``AutoConfig.from_pretrained`` and must
                contain ``"pretrained_model_name_or_path"``. ``cls_heads`` is
                iterated as ``(n_classes, name)`` pairs.
        """
        super().__init__(config)

        backbone_config = AutoConfig.from_pretrained(**config.backbone_config)
        pretrained_model_name_or_path = config.backbone_config[
            "pretrained_model_name_or_path"
        ]
        self.backbone = AutoModel.from_pretrained(
            pretrained_model_name_or_path, config=backbone_config
        )

        backbone_hidden_size = backbone_config.hidden_size
        # Mirror the backbone width on this model's own config.
        self.config.hidden_size = backbone_hidden_size
        backbone_vocab_size = backbone_config.vocab_size

        self.embedding_head = SentenceEmbeddingHead(
            backbone_hidden_size, config.embedding_head_config
        )
        self.mlm_head = MLMHead(
            backbone_hidden_size, backbone_vocab_size, config.mlm_head_config
        )
        # One classification head per configured (n_classes, name) pair,
        # addressable by name in ``get_cls_output``.
        self.cls_heads = nn.ModuleDict(
            {
                name: CLSHead(backbone_hidden_size, n_classes, config.cls_head_config)
                for n_classes, name in config.cls_heads
            }
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        forward_routing: ForwardRouting = ForwardRouting.GET_SENTENCE_EMBEDDING,
        **kwargs,
    ):
        """Run the backbone and the head selected by ``forward_routing``.

        Args:
            input_ids: Token ids passed to the backbone.
            attention_mask: Attention mask passed to the backbone and the head.
            token_type_ids: Optional segment ids passed to the backbone.
            forward_routing: A ``ForwardRouting`` member, or the string value
                of one (for example ``"get_mlm_output"``).
            **kwargs: Extra arguments for the MLM and classification routes
                (such as ``labels`` and ``head_name``). They are not forwarded
                on the sentence-embedding route.

        Returns:
            The output object of the selected head.

        Raises:
            ValueError: If ``forward_routing`` is not a known routing.
        """
        try:
            # Lookup by value: a member maps to itself, and a plain string
            # value maps to the matching member.
            routing = ForwardRouting(forward_routing)
        except ValueError:
            raise ValueError(f"Unknown forward routing {forward_routing}") from None

        if routing is ForwardRouting.GET_SENTENCE_EMBEDDING:
            return self.get_sentence_embedding(
                input_ids, attention_mask, token_type_ids=token_type_ids
            )
        if routing is ForwardRouting.GET_MLM_OUTPUT:
            return self.get_mlm_output(
                input_ids, attention_mask, token_type_ids=token_type_ids, **kwargs
            )
        if routing is ForwardRouting.GET_CLS_OUTPUT:
            return self.get_cls_output(
                input_ids, attention_mask, token_type_ids=token_type_ids, **kwargs
            )
        # Reached only if a routing member is added without a branch above.
        raise ValueError(f"Unknown forward routing {forward_routing}")

    def get_sentence_embedding(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs
    ):
        """Return the embedding head's output for the given batch.

        Args:
            input_ids: Token ids passed to the backbone.
            attention_mask: Attention mask passed to the backbone and the head.
            **kwargs: Forwarded to both the backbone and the embedding head.

        Returns:
            The ``SentenceEmbeddingHead`` output.
        """
        backbone_output: BaseModelOutput = self.backbone(
            input_ids=input_ids, attention_mask=attention_mask, **kwargs
        )
        return self.embedding_head(backbone_output, attention_mask, **kwargs)

    def get_mlm_output(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Return the MLM head's output for the given batch.

        Args:
            input_ids: Token ids passed to the backbone.
            attention_mask: Attention mask passed to the backbone and the head.
            labels: Optional targets; the head computes a loss when given.
            **kwargs: Forwarded to both the backbone and the MLM head.

        Returns:
            The ``MLMHead`` output.
        """
        backbone_output: BaseModelOutput = self.backbone(
            input_ids=input_ids, attention_mask=attention_mask, **kwargs
        )
        return self.mlm_head(backbone_output, attention_mask, labels, **kwargs)

    def get_cls_output(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        head_name: str,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Return the output of the classification head called ``head_name``.

        Args:
            input_ids: Token ids passed to the backbone.
            attention_mask: Attention mask passed to the backbone and the head.
            head_name: Key of the head in ``self.cls_heads``.
            labels: Optional targets; the head computes a loss when given.
            **kwargs: Forwarded to both the backbone and the selected head.

        Returns:
            The selected ``CLSHead`` output.

        Raises:
            ValueError: If no head named ``head_name`` exists.
        """
        # Validate first so an unknown head fails before the backbone
        # forward pass is spent on it.
        if head_name not in self.cls_heads:
            raise ValueError(f"Head {head_name} not found in model")

        backbone_output: BaseModelOutput = self.backbone(
            input_ids=input_ids, attention_mask=attention_mask, **kwargs
        )
        return self.cls_heads[head_name](
            backbone_output, attention_mask, labels, **kwargs
        )