from typing import Union

import torch
from transformers import AutoTokenizer


class Tokenizer:
    """Wrap a Hugging Face tokenizer with a fixed maximum sequence length.

    Encoding truncates every text to ``max_len`` tokens and pads the batch to
    its longest sequence, returning PyTorch tensors. Decoding turns such a
    batch back into strings, dropping padded positions via the attention mask.
    """

    def __init__(self, model_name: str, max_len: int) -> None:
        """Load the pretrained tokenizer for ``model_name``.

        Args:
            model_name: Name or path passed to ``AutoTokenizer.from_pretrained``.
            max_len: Maximum number of tokens per sequence; longer inputs are
                truncated when encoding.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.max_len = max_len

    def __call__(self, x: Union[str, list[str]]) -> dict[str, torch.LongTensor]:
        """Tokenize a single text or a list of texts.

        Args:
            x: One string, or a list of strings to encode as a batch.

        Returns:
            The tokenizer's encoding (read by ``decode`` via the
            ``"input_ids"`` and ``"attention_mask"`` keys) as ``"pt"`` tensors,
            truncated to ``max_len`` and padded to the longest item.
        """
        return self.tokenizer(  # type: ignore
            x,
            max_length=self.max_len,
            truncation=True,
            padding=True,
            return_tensors="pt",
        )

    def decode(self, x: dict[str, torch.LongTensor]) -> list[str]:
        """Decode a batch produced by ``__call__`` back into strings.

        Padding is removed by keeping only the positions whose attention-mask
        entry is non-zero. This is correct whether the underlying tokenizer
        pads on the right or on the left. If ``x`` carries no
        ``"attention_mask"``, every token of each row is decoded.

        Args:
            x: Encoding with an ``"input_ids"`` tensor whose rows are token
                sequences (as returned by ``__call__``) and, optionally, an
                ``"attention_mask"`` tensor of the same shape.

        Returns:
            One decoded string per row of ``x["input_ids"]``.
        """
        input_ids = x["input_ids"]
        if "attention_mask" not in x:
            return [self.tokenizer.decode(ids) for ids in input_ids]
        # Boolean-mask selection instead of slicing a prefix of length
        # mask.sum(): the prefix slice assumes right padding and, for a
        # left-padding tokenizer, would keep pad tokens and drop real ones.
        return [
            self.tokenizer.decode(ids[mask.bool()])
            for ids, mask in zip(input_ids, x["attention_mask"])
        ]