|
"""This lobe enables the integration of huggingface pretrained GPT2LMHeadModel model plus the expanding embedding layer for additional tokens like BOS, EOS and Speakers . |
|
|
|
Transformer from HuggingFace needs to be installed: |
|
https://huggingface.co/transformers/installation.html |
|
|
|
Authors |
|
* Pooneh Mousavi 2023 |
|
""" |
|
|
|
import logging |
|
from torch import Tensor |
|
import torch |
|
import torch.nn as nn |
|
from speechbrain.lobes.models.huggingface_gpt import HuggingFaceGPT |
|
try: |
|
from transformers import GPT2LMHeadModel |
|
from transformers import GPT2Tokenizer |
|
except ImportError: |
|
MSG = "Please install transformers from HuggingFace to use GPT2\n" |
|
MSG += "E.G. run: pip install transformers" |
|
raise ImportError(MSG) |
|
|
|
# Logger named after this module (standard `logging.getLogger(__name__)` pattern).
logger = logging.getLogger(__name__)
|
|
|
|
|
class HuggingFaceGPT_expanded(HuggingFaceGPT):
    """HuggingFace pretrained GPT model with an expanded token embedding.

    On top of ``HuggingFaceGPT``, this class loads the matching
    ``GPT2Tokenizer``, registers the special tokens ``BOS``, ``EOS``,
    ``SPK_1`` (system) and ``SPK_2`` (user), and resizes the model's token
    embedding so that the new tokens have embedding rows.

    Source paper:
        https://life-extension.github.io/2020/05/27/GPT%E6%8A%80%E6%9C%AF%E5%88%9D%E6%8E%A2/language-models.pdf
    Transformers from HuggingFace needs to be installed:
        https://huggingface.co/transformers/installation.html

    The model can be finetuned. It will download automatically the model from
    HuggingFace or use a local path.

    Arguments
    ---------
    source : str
        HuggingFace hub name: e.g "gpt2"
    save_path : str
        Path (dir) of the downloaded model.
    freeze : bool (default: False)
        If True, the model is frozen. If False, the model will be trained
        alongside with the rest of the pipeline.

    All arguments are forwarded unchanged to ``HuggingFaceGPT``; ``source``
    is additionally used to load the tokenizer.

    Example
    -------
    >>> model_hub = "gpt2"
    >>> save_path = "savedir"
    >>> model = HuggingFaceGPT_expanded(model_hub, save_path)
    >>> tokens = torch.tensor([[1, 1]])
    >>> tokens_type = torch.tensor([[1, 1]])
    >>> attention_mask = torch.tensor([[1, 1]])
    >>> outputs = model(tokens, tokens_type, attention_mask)
    """

    def __init__(self, *args, **kwrds) -> None:
        super().__init__(*args, **kwrds)

        # `source` may arrive positionally (it is the first argument in the
        # usage example above) or by keyword. Reading only kwrds["source"]
        # raised KeyError for the positional form.
        # NOTE(review): assumes `source` is the first positional parameter of
        # HuggingFaceGPT -- confirm against the parent signature.
        source = args[0] if args else kwrds["source"]

        self.tokenizer = GPT2Tokenizer.from_pretrained(source, pad_token=None)

        # Key order is kept as-is: it is the order in which the tokenizer
        # receives the new tokens.
        attr_to_special_tokens = {
            "bos_token": "BOS",
            "eos_token": "EOS",
            "additional_special_tokens": [
                "SPK_1",  # system speaker
                "SPK_2",  # user speaker
            ],
        }
        self.add_special_tokens_(attr_to_special_tokens)

    def add_special_tokens_(self, attr_to_special_token,) -> None:
        """Add special tokens to the tokenizer and grow the model embedding.

        Arguments
        ---------
        attr_to_special_token : dict
            Mapping passed to ``tokenizer.add_special_tokens``, e.g.
            ``{"bos_token": "BOS", "additional_special_tokens": [...]}``.
        """
        num_added_tokens = self.tokenizer.add_special_tokens(
            attr_to_special_token
        )
        if num_added_tokens > 0:
            # len(tokenizer) is the base vocabulary plus every added token,
            # so the embedding stays in sync even if the tokenizer already
            # carried added tokens before this call.
            self.model.resize_token_embeddings(
                new_num_tokens=len(self.tokenizer)
            )
|
|
|
|