from typing import TYPE_CHECKING, List, Optional, Tuple

from transformers.tokenization_utils import BatchEncoding, PreTrainedTokenizer
from transformers.utils import TensorType, logging, to_py_obj

try:
    from ariautils.midi import MidiDict
    from ariautils.tokenizer import AbsTokenizer
    from ariautils.tokenizer._base import Token
except ImportError:
    raise ImportError(
        "ariautils is not installed. Please try "
        "`pip install git+https://github.com/EleutherAI/aria-utils.git`."
    )

if TYPE_CHECKING:
    pass

logger = logging.get_logger(__name__)


class AriaTokenizer(PreTrainedTokenizer):
    """Aria tokenizer: NOT a BPE tokenizer.

    A MIDI file is first converted into a MidiDict (despite the name, a
    MidiDict is not a single dict; it is essentially a list of note/event
    entries) representing a sequence of notes, rests, etc. The Aria tokenizer
    is then simply a dictionary that maps MidiDict events to discrete indices
    according to a hard-coded rule.

    For a FIM-finetuned model, we also follow a simple FIM format to guide a
    piece of music toward a (possibly very different) suffix according to the
    prompts:

        ...
        ...

    This way, we expect a continuation that connects PROMPT and GUIDANCE.
    """

    vocab_files_names = {}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        add_bos_token=True,
        add_eos_token=False,
        clean_up_tokenization_spaces=False,
        use_default_system_prompt=False,
        **kwargs,
    ):
        self._tokenizer = AbsTokenizer()
        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token
        self.use_default_system_prompt = use_default_system_prompt

        bos_token = self._tokenizer.bos_tok
        eos_token = self._tokenizer.eos_tok
        pad_token = self._tokenizer.pad_tok
        unk_token = self._tokenizer.unk_tok

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            use_default_system_prompt=use_default_system_prompt,
            **kwargs,
        )

    def __getstate__(self):
        return {}

    def __setstate__(self, d):
        raise NotImplementedError()

    @property
    def vocab_size(self):
        """Returns vocab size."""
        return self._tokenizer.vocab_size

    def get_vocab(self):
        return self._tokenizer.tok_to_id

    def tokenize(self, midi_dict: MidiDict, **kwargs) -> List[Token]:
        return self._tokenizer.tokenize(midi_dict)

    def _tokenize(self, midi_dict: MidiDict, **kwargs) -> List[Token]:
        return self._tokenizer.tokenize(midi_dict)

    def __call__(
        self,
        midi_dicts: MidiDict | list[MidiDict],
        padding: bool = False,
        max_length: int | None = None,
        pad_to_multiple_of: int | None = None,
        return_tensors: str | TensorType | None = None,
        return_attention_mask: bool | None = None,
        **kwargs,
    ) -> BatchEncoding:
        """We cannot rely on the parent method because the inputs are
        MidiDict(s) rather than strings, and I do not like the idea of going
        hacky so that two entirely different input types can marry. So here I
        reimplement __call__ with limited support for certain useful
        arguments. I do not expect any conflict with other "string-in-ids-out"
        tokenizers. If you have to mix the API of string-based tokenizers with
        our MIDI-based tokenizer, there must be a problem with your design.
        """
        if isinstance(midi_dicts, MidiDict):
            midi_dicts = [midi_dicts]

        all_tokens: list[list[int]] = []
        all_attn_masks: list[list[bool]] = []
        max_len_encoded = 0
        # TODO: if we decide to optimize batched tokenization in ariautils
        # using some compiled backend, we can change this loop accordingly.
        for md in midi_dicts:
            tokens = self._tokenizer.encode(self._tokenizer.tokenize(md))
            if max_length is not None:
                tokens = tokens[:max_length]
            max_len_encoded = max(max_len_encoded, len(tokens))
            all_tokens.append(tokens)
            all_attn_masks.append([True] * len(tokens))

        if pad_to_multiple_of is not None:
            # Round the padded length up to the nearest multiple.
            max_len_encoded = (
                (max_len_encoded + pad_to_multiple_of - 1) // pad_to_multiple_of
            ) * pad_to_multiple_of
        if padding:
            for tokens, attn_mask in zip(all_tokens, all_attn_masks):
                # Compute the pad length before mutating `tokens`, otherwise
                # the attention mask would be extended by zero elements.
                pad_len = max_len_encoded - len(tokens)
                tokens.extend([self.pad_token_id] * pad_len)
                attn_mask.extend([False] * pad_len)

        return BatchEncoding(
            {
                "input_ids": all_tokens,
                "attention_mask": all_attn_masks,
            },
            tensor_type=return_tensors,
        )

    def decode(self, token_ids: List[int], **kwargs) -> MidiDict:
        token_ids = to_py_obj(token_ids)
        return self._tokenizer.detokenize(self._tokenizer.decode(token_ids))

    def batch_decode(
        self, token_ids_list: List[List[int]], **kwargs
    ) -> List[MidiDict]:
        results = []
        for token_ids in token_ids_list:
            # Can we simply yield (without breaking all HF wrappers)?
            results.append(self.decode(token_ids))
        return results

    def encode_from_file(self, filename: str, **kwargs) -> BatchEncoding:
        midi_dict = MidiDict.from_midi(filename)
        return self(midi_dict, **kwargs)

    def encode_from_files(self, filenames: list[str], **kwargs) -> BatchEncoding:
        midi_dicts = [MidiDict.from_midi(file) for file in filenames]
        return self(midi_dicts, **kwargs)

    def _convert_token_to_id(self, token: Token):
        """Converts a token (tuple or str) into an id."""
        return self._tokenizer.tok_to_id.get(
            token, self._tokenizer.tok_to_id[self.unk_token]
        )

    def _convert_id_to_token(self, index: int):
        """Converts an index (integer) into a token (tuple or str)."""
        return self._tokenizer.id_to_tok.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens: List[Token]) -> MidiDict:
        """Converts a sequence of tokens into a single MidiDict."""
        return self._tokenizer.detokenize(tokens)

    def save_vocabulary(
        self, save_directory, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        raise NotImplementedError()
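

# --- Usage sketch (illustrative only, not part of the tokenizer) -----------
# A minimal round-trip under stated assumptions: "example.mid" is a
# hypothetical local MIDI file path, and the final save step assumes
# ariautils' MidiDict exposes `to_midi()` returning a mido.MidiFile.
if __name__ == "__main__":
    tokenizer = AriaTokenizer()

    # Encode a MIDI file into model-ready tensors (requires torch for "pt").
    encoding = tokenizer.encode_from_file("example.mid", return_tensors="pt")
    print(encoding["input_ids"].shape)

    # Decode the ids back into a MidiDict and write it out as a MIDI file.
    midi_dict = tokenizer.decode(encoding["input_ids"][0])
    midi_dict.to_midi().save("roundtrip.mid")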