# Polos-Demo: polos/tokenizers_/hf_tokenizer.py
# -*- coding: utf-8 -*-
import torch
from transformers import AutoTokenizer
from torchnlp.encoders.text.text_encoder import TextEncoder
from .tokenizer_base import TextEncoderBase


class HFTextEncoder(TextEncoderBase):
    """
    Wrapper around the transformers AutoTokenizer:
        - https://huggingface.co/transformers/model_doc/auto.html#autotokenizer

    :param model: pretrained model name or path passed to AutoTokenizer.from_pretrained.
    """

    def __init__(self, model: str) -> None:
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model)
        # Properties from the base class
        self.stoi = self.tokenizer.get_vocab()
        self.itos = {v: k for k, v in self.stoi.items()}
        self._bos_index = self.tokenizer.cls_token_id
        self._pad_index = self.tokenizer.pad_token_id
        self._eos_index = self.tokenizer.sep_token_id
        self._unk_index = self.tokenizer.unk_token_id
        self._mask_index = self.tokenizer.mask_token_id

    def encode(self, sequence: str) -> torch.Tensor:
        """Encodes a 'sequence' into a tensor of token ids.

        :param sequence: String 'sequence' to encode.

        Returns:
            - torch.Tensor: Encoding of the 'sequence'.
        """
        # Delegate to the torchnlp TextEncoder base implementation first.
        sequence = TextEncoder.encode(self, sequence)
        return torch.tensor(self.tokenizer(sequence, truncation=False)["input_ids"])
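

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): assumes an
# AutoTokenizer-compatible checkpoint such as "bert-base-uncased" is available
# from the Hugging Face Hub. Because this module uses a relative import, run it
# as a module, e.g. `python -m polos.tokenizers_.hf_tokenizer`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    encoder = HFTextEncoder("bert-base-uncased")

    # encode() returns a 1-D tensor of input ids, including the special tokens
    # ([CLS]/[SEP] for BERT-style tokenizers) added by the underlying tokenizer.
    ids = encoder.encode("a man riding a wave on top of a surfboard")
    print(ids)

    # The special-token indices set in __init__ mirror the tokenizer's own ids.
    print(encoder._bos_index, encoder._eos_index, encoder._pad_index)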