"""ST Interface module."""

from espnet.nets.asr_interface import ASRInterface
from espnet.utils.dynamic_import import dynamic_import


class STInterface(ASRInterface):
    """ST Interface for ESPnet model implementation.

    NOTE: This class is inherited from ASRInterface to enable joint translation
    and recognition when performing multi-task learning with the ASR task.

    """

    def translate(
        self, x, trans_args, char_list=None, rnnlm=None, ensemble_models=None
    ):
        """Translate x for evaluation.

        :param ndarray x: input acoustic feature (B, T, D) or (T, D)
        :param namespace trans_args: argument namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :param list ensemble_models: additional models for ensemble decoding
        :return: N-best decoding results
        :rtype: list
        """
        raise NotImplementedError("translate method is not implemented")

    def translate_batch(self, x, trans_args, char_list=None, rnnlm=None):
        """Beam search implementation for batch.

        :param torch.Tensor x: encoder hidden state sequences (B, Tmax, Henc)
        :param namespace trans_args: argument namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        raise NotImplementedError("Batch decoding is not supported yet.")


predefined_st = {
    "pytorch": {
        "rnn": "espnet.nets.pytorch_backend.e2e_st:E2E",
        "transformer": "espnet.nets.pytorch_backend.e2e_st_transformer:E2E",
    },
    # "chainer": {
    #     "rnn": "espnet.nets.chainer_backend.e2e_st:E2E",
    #     "transformer": "espnet.nets.chainer_backend.e2e_st_transformer:E2E",
    # }
}


def dynamic_import_st(module, backend):
    """Import ST models dynamically.

    Args:
        module (str): module_name:class_name or alias in `predefined_st`
        backend (str): NN backend, e.g., pytorch. Only backends listed in
            `predefined_st` support aliases.

    Returns:
        type: ST class

    """
    model_class = dynamic_import(module, predefined_st.get(backend, dict()))
    assert issubclass(
        model_class, STInterface
    ), f"{module} does not implement STInterface"
    return model_class