File size: 3,769 Bytes
d742904
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd283a7
d742904
 
 
 
 
 
 
c3eb5d5
d742904
 
 
cd283a7
 
 
 
 
 
 
 
0c4eba3
d742904
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
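# What the Processor does:
# - Tokenizes text with the tokenizer.
# - Converts time-series data given as a pandas DataFrame, numpy array, or torch tensor into a torch tensor.
# Returns the result in the form {input_ids: ..., attention_mask: ..., time_series_values: ...}.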

from typing import List, Optional, Union

from pandas import DataFrame
import numpy as np
import torch
import tensorflow as tf
import jax.numpy as jnp

from transformers import ProcessorMixin
from transformers import TensorType
from transformers import BatchFeature
from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy


class MistsProcessor(ProcessorMixin):
    """Processor that pairs a text tokenizer with a time-series feature extractor.

    Calling the processor tokenizes ``text`` with ``self.tokenizer`` and converts
    ``time_series`` with ``self.feature_extractor``, then merges both outputs into
    a single :class:`BatchFeature`. Either input may be omitted.
    """

    # Ideally Moment's own tokenizer would also be registered here as a
    # ``ts_tokenizer``, but in Moment it is built into the model itself.
    # refers: https://github.com/moment-timeseries-foundation-model/moment/blob/088b253a1138ac7e48a7efc9bf902336c9eec8d9/momentfm/models/moment.py#L105

    # These two modules of the Moment model arguably belong to a ts_tokenizer:
    # (normalizer): RevIN()
    # (tokenizer): Patching()
    attributes = ["feature_extractor", "tokenizer"]
    feature_extractor_class = "AutoFeatureExtractor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, feature_extractor=None, tokenizer=None):
        super().__init__(feature_extractor, tokenizer)


    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        time_series: Union[DataFrame, np.ndarray, torch.Tensor, List[DataFrame], List[np.ndarray], List[torch.Tensor]] = None,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy, None] = None,
        max_length: Union[int, None] = None,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
        torch_dtype: Optional[Union[str, torch.dtype]] = torch.float,
        time_series_padding: Union[bool, str] = False,
        time_series_max_length: Union[int, None] = None,
        text_tokenize: bool = True,
    ) -> BatchFeature:
        """Prepare text and/or time-series inputs for the model.

        Args:
            text: Text to tokenize. If ``None``, no text entries are produced.
            time_series: Time-series data handed to ``self.feature_extractor``.
                If ``None``, no time-series entries are produced.
            padding: Forwarded to the tokenizer as ``padding``.
            truncation: Forwarded to the tokenizer as ``truncation``.
            max_length: Forwarded to the tokenizer as ``max_length``.
            return_tensors: Forwarded to both the tokenizer and the feature extractor.
            torch_dtype: Forwarded to the feature extractor only.
            time_series_padding: Forwarded to the feature extractor as ``padding``.
            time_series_max_length: Forwarded to the feature extractor as ``max_length``.
            text_tokenize: If ``False``, ``text`` is not tokenized and is returned
                unchanged under the key ``"text"``.

        Returns:
            A ``BatchFeature`` holding the union of the text outputs and the
            time-series outputs. On a key collision the time-series value wins,
            because it is merged last.
        """
        if time_series is not None:
            time_series_values = self.feature_extractor(
                time_series,
                return_tensors=return_tensors,
                torch_dtype=torch_dtype,
                padding=time_series_padding,
                max_length=time_series_max_length,
            )
        else:
            # Use an empty mapping rather than None: ``**None`` in the merge
            # below raises TypeError, which broke text-only calls.
            time_series_values = {}

        if text is not None:
            if text_tokenize:
                text_inputs = self.tokenizer(
                    text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
                )
            else:
                # Pass raw text through untouched, e.g. for a downstream component
                # that tokenizes it itself.
                text_inputs = {"text": text}
        else:
            text_inputs = {}

        return BatchFeature(data={**text_inputs, **time_series_values})

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Tokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Tokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        """Input names of the tokenizer followed by those of the feature extractor.

        Duplicates are removed while keeping first-seen order (``dict.fromkeys``
        preserves insertion order).
        """
        tokenizer_input_names = self.tokenizer.model_input_names
        feature_extractor_input_names = self.feature_extractor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + feature_extractor_input_names))