---
datasets:
  - ddrg/named_math_formulas
  - ddrg/math_formula_retrieval
  - ddrg/math_formulas
  - ddrg/math_text
---

Pretrained model based on microsoft/deberta-v3-base with further mathematical pre-training.
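
For example, the model can be loaded with the standard transformers API. This is a minimal sketch; the model id below is a placeholder for this repository's actual id:

```python
from transformers import AutoModel, AutoTokenizer

# Placeholder: substitute this repository's actual model id.
model_id = "<this-repository-id>"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id)  # plain DeBERTa-v3 encoder, without the pooling layer

inputs = tokenizer(r"\frac{a}{b} = c", return_tensors="pt")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (1, sequence_length, hidden_size)
```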

Compared to deberta-v3-base, 300 additional mathematical LaTeX tokens have been added before the mathematical pre-training. As this additional pre-training used NSP-like tasks, a pooling layer (bias and weight) has been added to the model. If you do not need this pooling layer, simply use the standard transformers DeBERTa model. If you want to use the additional pooling layer in the same way as BERT's, a wrapper class like the following may be used:

```python
from typing import Mapping, Any

import torch
from torch import nn
from transformers import DebertaV2Model, DebertaV2Tokenizer, AutoConfig, AutoTokenizer

class DebertaV2ModelWithPoolingLayer(nn.Module):

    def __init__(self, pretrained_model_name):
        super().__init__()

        # Load the Deberta model and tokenizer
        self.deberta = DebertaV2Model.from_pretrained(pretrained_model_name)
        self.tokenizer = DebertaV2Tokenizer.from_pretrained(pretrained_model_name)

        # Add a pooling layer (Linear + tanh activation) for the CLS token
        self.pooling_layer = nn.Sequential(
            nn.Linear(self.deberta.config.hidden_size, self.deberta.config.hidden_size),
            nn.Tanh()
        )

        self.config = self.deberta.config
        self.embeddings = self.deberta.embeddings


    def forward(self, input_ids, attention_mask, *args, **kwargs):
        # Forward pass through the Deberta model
        outputs = self.deberta(input_ids, attention_mask=attention_mask, *args, **kwargs)

        # Extract the hidden states from the output
        hidden_states = outputs.last_hidden_state

        # Get the CLS token representation (first token)
        cls_token = hidden_states[:, 0, :]

        # Apply the pooling layer to the CLS token representation
        pooled_output = self.pooling_layer(cls_token)
        # Expose the pooled output as 'pooler_output', mirroring BERT-style model outputs
        outputs["pooler_output"] = pooled_output

        return outputs

    def save_pretrained(self, path):
        # Save the model's state_dict, configuration, and tokenizer
        state_dict = self.deberta.state_dict()
        state_dict.update(self.pooling_layer[0].state_dict())

        torch.save(state_dict, f"{path}/pytorch_model.bin")
        self.deberta.config.save_pretrained(path)
        self.tokenizer.save_pretrained(path)

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
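        # Split the merged state dict: the pooling layer's parameters are stored under the
        # plain keys 'weight' and 'bias'; all remaining keys belong to the DeBERTa encoder.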
        pooler_keys = ['bias', 'weight']
        deberta_state_dict = {k: v for k, v in state_dict.items() if k not in pooler_keys}
        pooler_state_dict = {k: v for k, v in state_dict.items() if k in pooler_keys}
        self.deberta.load_state_dict(deberta_state_dict, strict=strict)
        self.pooling_layer[0].load_state_dict(pooler_state_dict)

    @classmethod
    def from_pretrained(cls, name):
        # Initialize the instance
        instance = cls(name)

        try:
            # Load the model's state_dict
            instance.load_state_dict(torch.load(f"{name}/pytorch_model.bin"))
        except FileNotFoundError:
            print("Could not find DeBERTa pooling layer. Initialize new values")

        # Load the configuration and tokenizer
        instance.deberta.config = AutoConfig.from_pretrained(name)
        instance.tokenizer = AutoTokenizer.from_pretrained(name)

        return instance
```
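
A minimal usage sketch of the wrapper, assuming the class above; `"path/to/checkpoint"` is a placeholder for a directory written by `save_pretrained` (or this repository's model id, in which case new pooling weights are initialized):

```python
# Load the wrapper, encode a formula, and read both the token-level and pooled outputs.
model = DebertaV2ModelWithPoolingLayer.from_pretrained("path/to/checkpoint")  # placeholder path
model.eval()

inputs = model.tokenizer(r"a^2 + b^2 = c^2", return_tensors="pt")
with torch.no_grad():
    outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])

print(outputs.last_hidden_state.shape)  # (1, sequence_length, hidden_size)
print(outputs.pooler_output.shape)      # (1, hidden_size): tanh-pooled CLS representation
```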