Ngaima Sandiman committed on
Commit
685ecb2
·
1 Parent(s): 8fd0e3f

Initial commit.

Browse files
app.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+
4
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
5
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
6
+ os.environ["USER"] = "imagecraft"
7
+
8
+
9
+ import gradio as gr
10
+ from src.model.modules.imagecraft import ImageCraft
11
+
12
+ model = ImageCraft.from_pretrained("nsandiman/imagecraft-ft-co-224")
13
+
14
+
15
def imagecraft_interface(image_path, reference_text=None):
    """Generate a spoken description and its transcript for an uploaded image.

    Gradio invokes the callback with one positional argument per input
    component. The interface in this file declares two inputs (the image and
    the "Reference Text (for evaluation)" textbox), so the callback has to
    accept both; with a single parameter every submission fails with a
    ``TypeError`` before the model is reached.

    Args:
        image_path: Path of the uploaded image (the image input is declared
            with ``type="filepath"``).
        reference_text: Content of the reference-text box. It is accepted so
            the signature matches the declared inputs but is not used for
            generation. Defaults to ``None`` so existing single-argument
            callers keep working.

    Returns:
        A ``(audio_buffer, transcript)`` pair, ordered to match the
        interface outputs (audio first, transcript second).
    """
    # model.generate returns (transcript, audio); the UI expects audio first.
    transcript, audio_buffer = model.generate(image_path, output_type="buffer")

    return audio_buffer, transcript
20
+
21
+
22
+ # Define Gradio interface
23
+ gradio_interface = gr.Interface(
24
+ fn=imagecraft_interface,
25
+ inputs=[
26
+ gr.Image(type="filepath", label="Upload an image"),
27
+ gr.Textbox(label="Reference Text (for evaluation)"),
28
+ ],
29
+ outputs=[gr.Audio(label="Speech"), gr.Textbox(label="Transcript")],
30
+ title="ImageCraft",
31
+ description="Upload an image and get the speech responses.",
32
+ allow_flagging="never",
33
+ )
34
+
35
+ # Launch the Gradio app
36
+ gradio_interface.launch(share=True)
config.json ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "nsandiman/imagecraft-ft-co-224",
3
+ "_vocab_size": 257216,
4
+ "architectures": [
5
+ "PaliGemmaForConditionalGeneration"
6
+ ],
7
+ "bos_token_id": 2,
8
+ "eos_token_id": 1,
9
+ "hidden_size": 2048,
10
+ "ignore_index": -100,
11
+ "image_token_index": 257152,
12
+ "model_type": "paligemma",
13
+ "pad_token_id": 0,
14
+ "projection_dim": 2048,
15
+ "text_config": {
16
+ "hidden_size": 2048,
17
+ "intermediate_size": 16384,
18
+ "model_type": "gemma",
19
+ "num_attention_heads": 8,
20
+ "num_hidden_layers": 18,
21
+ "num_image_tokens": 256,
22
+ "num_key_value_heads": 1,
23
+ "torch_dtype": "float32",
24
+ "vocab_size": 257216
25
+ },
26
+ "torch_dtype": "float32",
27
+ "transformers_version": "4.41.0.dev0",
28
+ "vision_config": {
29
+ "hidden_size": 1152,
30
+ "intermediate_size": 4304,
31
+ "model_type": "siglip_vision_model",
32
+ "num_attention_heads": 16,
33
+ "num_hidden_layers": 27,
34
+ "num_image_tokens": 256,
35
+ "patch_size": 14,
36
+ "projection_dim": 2048,
37
+ "projector_hidden_act": "gelu_fast",
38
+ "vision_use_head": false
39
+ },
40
+ "vocab_size": 257216,
41
+ "voicecraft_config": {
42
+ "model_name": "330M_TTSEnhanced.pth",
43
+ "encoded": "encodec_4cb2048_giga.th",
44
+ "voice_audio_path": "84_121550_000074_000000.wav",
45
+ "voice_audio_transcript": "But when I had approached so near to them The common object, which the sense deceives, Lost not by distance any of its marks",
46
+ "top_k": 0,
47
+ "top_p": 0.9,
48
+ "temperature": 1,
49
+ "kvcache": 1,
50
+ "codec_sr": 50,
51
+ "codec_audio_sr": 16000,
52
+ "silence_tokens": [1388, 1898, 131],
53
+ "stop_repetition": 3,
54
+ "sample_batch_size": 2,
55
+ "seed": 1,
56
+ "cut_off_sec": 7.87
57
+ }
58
+ }
config.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model_flickr: nsandiman/imagecraft-pt-fk-224
2
+ model_coco: models/imagecraft-pt-co-224
3
+ model_tiny: models/imagecraft-pt-ty-224
4
+ checkpoint_dir: models/checkpoint
5
+ pretrained_dir: models/pretrained
6
+ model_dir: models
7
+ data:
8
+ raw_dir: data/raw
9
+ interim_dir: data/interim
10
+ processed_dir: data/processed
11
+ log_dir: data/logs
12
+ wandb_dir: data/wandb
13
+ tensorboard_log_dir: data/tensorboard/logs
14
+
media/voicecraft/generated/empty.txt ADDED
File without changes
media/voicecraft/voices/84_121550_000074_000000.wav ADDED
Binary file (508 kB). View file
 
media/voicecraft/voices/mfa_alignments/84_121550_000074_000000.csv ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Begin,End,Label,Type,Speaker
2
+ 0.03,0.18,but,words,temp
3
+ 0.18,0.32,when,words,temp
4
+ 0.32,0.48,i,words,temp
5
+ 0.48,0.64,had,words,temp
6
+ 0.64,1.19,approached,words,temp
7
+ 1.22,1.58,so,words,temp
8
+ 1.58,1.91,near,words,temp
9
+ 1.91,2.07,to,words,temp
10
+ 2.07,2.42,them,words,temp
11
+ 2.53,2.61,the,words,temp
12
+ 2.61,3.01,common,words,temp
13
+ 3.05,3.62,object,words,temp
14
+ 3.68,3.93,which,words,temp
15
+ 3.93,4.02,the,words,temp
16
+ 4.02,4.34,sense,words,temp
17
+ 4.34,4.97,deceives,words,temp
18
+ 5.04,5.54,lost,words,temp
19
+ 5.54,6.0,not,words,temp
20
+ 6.0,6.14,by,words,temp
21
+ 6.14,6.67,distance,words,temp
22
+ 6.79,7.05,any,words,temp
23
+ 7.05,7.18,of,words,temp
24
+ 7.18,7.34,its,words,temp
25
+ 7.34,7.87,marks,words,temp
26
+ 0.03,0.06,B,phones,temp
27
+ 0.06,0.09,AH1,phones,temp
28
+ 0.09,0.18,T,phones,temp
29
+ 0.18,0.23,W,phones,temp
30
+ 0.23,0.27,EH1,phones,temp
31
+ 0.27,0.32,N,phones,temp
32
+ 0.32,0.48,AY1,phones,temp
33
+ 0.48,0.49,HH,phones,temp
34
+ 0.49,0.6,AE1,phones,temp
35
+ 0.6,0.64,D,phones,temp
36
+ 0.64,0.7,AH0,phones,temp
37
+ 0.7,0.83,P,phones,temp
38
+ 0.83,0.88,R,phones,temp
39
+ 0.88,0.99,OW1,phones,temp
40
+ 0.99,1.12,CH,phones,temp
41
+ 1.12,1.19,T,phones,temp
42
+ 1.22,1.4,S,phones,temp
43
+ 1.4,1.58,OW1,phones,temp
44
+ 1.58,1.7,N,phones,temp
45
+ 1.7,1.84,IH1,phones,temp
46
+ 1.84,1.91,R,phones,temp
47
+ 1.91,2.01,T,phones,temp
48
+ 2.01,2.07,AH0,phones,temp
49
+ 2.07,2.13,DH,phones,temp
50
+ 2.13,2.3,EH1,phones,temp
51
+ 2.3,2.42,M,phones,temp
52
+ 2.53,2.55,DH,phones,temp
53
+ 2.55,2.61,AH0,phones,temp
54
+ 2.61,2.73,K,phones,temp
55
+ 2.73,2.85,AA1,phones,temp
56
+ 2.85,2.9,M,phones,temp
57
+ 2.9,2.95,AH0,phones,temp
58
+ 2.95,3.01,N,phones,temp
59
+ 3.05,3.22,AA1,phones,temp
60
+ 3.22,3.27,B,phones,temp
61
+ 3.27,3.34,JH,phones,temp
62
+ 3.34,3.48,EH0,phones,temp
63
+ 3.48,3.54,K,phones,temp
64
+ 3.54,3.62,T,phones,temp
65
+ 3.68,3.69,HH,phones,temp
66
+ 3.69,3.76,W,phones,temp
67
+ 3.76,3.8,IH1,phones,temp
68
+ 3.8,3.93,CH,phones,temp
69
+ 3.93,3.95,DH,phones,temp
70
+ 3.95,4.02,AH0,phones,temp
71
+ 4.02,4.12,S,phones,temp
72
+ 4.12,4.21,EH1,phones,temp
73
+ 4.21,4.27,N,phones,temp
74
+ 4.27,4.34,S,phones,temp
75
+ 4.34,4.42,D,phones,temp
76
+ 4.42,4.45,IH0,phones,temp
77
+ 4.45,4.59,S,phones,temp
78
+ 4.59,4.79,IY1,phones,temp
79
+ 4.79,4.87,V,phones,temp
80
+ 4.87,4.97,Z,phones,temp
81
+ 5.04,5.12,L,phones,temp
82
+ 5.12,5.33,AO1,phones,temp
83
+ 5.33,5.42,S,phones,temp
84
+ 5.42,5.54,T,phones,temp
85
+ 5.54,5.7,N,phones,temp
86
+ 5.7,5.89,AA1,phones,temp
87
+ 5.89,6.0,T,phones,temp
88
+ 6.0,6.05,B,phones,temp
89
+ 6.05,6.14,AY1,phones,temp
90
+ 6.14,6.24,D,phones,temp
91
+ 6.24,6.3,IH1,phones,temp
92
+ 6.3,6.38,S,phones,temp
93
+ 6.38,6.45,T,phones,temp
94
+ 6.45,6.51,AH0,phones,temp
95
+ 6.51,6.57,N,phones,temp
96
+ 6.57,6.67,S,phones,temp
97
+ 6.79,6.89,EH1,phones,temp
98
+ 6.89,6.95,N,phones,temp
99
+ 6.95,7.05,IY0,phones,temp
100
+ 7.05,7.13,AH0,phones,temp
101
+ 7.13,7.18,V,phones,temp
102
+ 7.18,7.22,IH0,phones,temp
103
+ 7.22,7.29,T,phones,temp
104
+ 7.29,7.34,S,phones,temp
105
+ 7.34,7.39,M,phones,temp
106
+ 7.39,7.5,AA1,phones,temp
107
+ 7.5,7.58,R,phones,temp
108
+ 7.58,7.7,K,phones,temp
109
+ 7.7,7.87,S,phones,temp
models/pretrained/imagecraft/empty.txt ADDED
File without changes
models/pretrained/voicecraft/empty.txt ADDED
File without changes
packages.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ espeak-ng
2
+ espeak
3
+ espeak-data
4
+ libespeak1
5
+ libespeak-dev
6
+ festival*
7
+ build-essential
8
+ flac
9
+ libasound2-dev
10
+ libsndfile1-dev
11
+ vorbis-tools
12
+ libxml2-dev
13
+ libxslt-dev
14
+ zlib1g-dev
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -e git+https://github.com/facebookresearch/audiocraft.git@f83babff6b5e97f75562127c4cc8122229c8f099#egg=audiocraft
2
+ git+https://github.com/huggingface/transformers.git
3
+ phonemizer
4
+ spaces
5
+ huggingface-hub
6
+ num2words
7
+ numpy
8
+ pillow
9
+ safetensors
10
+ tokenizers
11
+ torchaudio
12
+ torchvision
13
+ aeneas
setup.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import platform
2
+ from setuptools import setup, find_packages
3
+
4
+ if platform.python_version_tuple()[:2] != ("3", "11"):
5
+ raise RuntimeError("Python version 3.11 required")
6
+
7
+ setup(
8
+ name="distilvit",
9
+ version="0.1",
10
+ packages=find_packages(),
11
+ entry_points={
12
+ "console_scripts": [
13
+ "train=src.model.train:main", # "main" is a function in "train_model.py"
14
+ ],
15
+ },
16
+ )
src/__init__.py ADDED
File without changes
src/model/modules/__init__.py ADDED
File without changes
src/model/modules/activation.py ADDED
@@ -0,0 +1,638 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py, modified by Puyuan Peng, 2024
2
+ from typing import Optional, Tuple
3
+
4
+ import torch
5
+ from torch import Tensor
6
+ from torch.nn import Linear, Module
7
+ from torch.nn import functional as F
8
+ from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
9
+ from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
10
+ from torch.nn.parameter import Parameter
11
+ import logging
12
+ from typing import Callable, List, Optional, Tuple, Union
13
+ from typing import TYPE_CHECKING
14
+ if TYPE_CHECKING:
15
+ from torch.types import _dtype as DType
16
+ else:
17
+ # The JIT doesn't understand Union, nor torch.dtype here
18
+ DType = int
19
+
20
def _canonical_mask(
    mask: Optional[Tensor],
    mask_name: str,
    other_type: Optional[DType],
    other_name: str,
    target_type: DType,
    check_other: bool = True,
) -> Optional[Tensor]:
    """Normalize an attention mask to an additive floating-point mask.

    Args:
        mask: Boolean or floating-point mask, or ``None``.
        mask_name: Name of ``mask``, used in error and warning messages.
        other_type: Dtype of the companion mask, or ``None`` when there is
            no companion mask to compare against.
        other_name: Name of the companion mask, used in the warning message.
        target_type: Dtype of the mask produced from a boolean input.
        check_other: When true, warn if ``mask`` and the companion mask have
            different dtypes.

    Returns:
        ``None`` when ``mask`` is ``None``; the mask unchanged when it is
        already floating point; otherwise a new ``target_type`` tensor that
        holds ``-inf`` where ``mask`` is ``True`` and ``0`` elsewhere.

    Raises:
        AssertionError: If ``mask`` is neither boolean nor floating point.
    """
    # Imported here because this module does not import ``warnings`` at file
    # level; without it the mismatch branch below raised NameError.
    import warnings

    if mask is not None:
        _mask_dtype = mask.dtype
        _mask_is_float = torch.is_floating_point(mask)
        if _mask_dtype != torch.bool and not _mask_is_float:
            raise AssertionError(
                f"only bool and floating types of {mask_name} are supported")
        if check_other and other_type is not None:
            if _mask_dtype != other_type:
                warnings.warn(
                    f"Support for mismatched {mask_name} and {other_name} "
                    "is deprecated. Use same type for both instead."
                )
        if not _mask_is_float:
            # Boolean mask: True marks a masked position, which becomes -inf
            # so that it vanishes after softmax.
            mask = (
                torch.zeros_like(mask, dtype=target_type)
                .masked_fill_(mask, float("-inf"))
            )
    return mask
47
+
48
def _in_projection_packed(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    w: Tensor,
    b: Optional[Tensor] = None,
) -> List[Tensor]:
    """Project query, key and value with a single packed weight tensor.

    The packed weight ``w`` stacks the q, k and v projection matrices along
    dimension 0 (in that order); ``b`` stacks the matching biases. When the
    inputs are the same tensor object the projections are fused into fewer
    matrix multiplications.

    Args:
        q: Query tensor whose last dimension is the embedding size ``E``.
        k: Key tensor with the same embedding size.
        v: Value tensor with the same embedding size.
        w: Packed projection weights, split along dimension 0 in q, k, v order.
        b: Optional packed projection biases in q, k, v order.

    Returns:
        The projected query, key and value tensors, in that order.
    """
    embed_dim = q.size(-1)

    # General case: three independent inputs, three independent projections.
    if k is not v:
        w_q, w_k, w_v = w.chunk(3)
        if b is not None:
            b_q, b_k, b_v = b.chunk(3)
        else:
            b_q = b_k = b_v = None
        return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)

    # Encoder-decoder attention: key and value share a tensor, query differs.
    if q is not k:
        w_q, w_kv = w.split([embed_dim, embed_dim * 2])
        if b is not None:
            b_q, b_kv = b.split([embed_dim, embed_dim * 2])
        else:
            b_q = b_kv = None
        q_proj = F.linear(q, w_q, b_q)
        kv_proj = F.linear(k, w_kv, b_kv)
        # Lay out as (2, ..., E) rather than (..., E, 2) for better memory
        # coalescing and to keep the same ordering as chunk().
        kv_proj = kv_proj.unflatten(-1, (2, embed_dim))
        kv_proj = kv_proj.unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
        return (q_proj, kv_proj[0], kv_proj[1])

    # Self-attention: one fused projection for all three.
    fused = F.linear(q, w, b)
    # Lay out as (3, ..., E) rather than (..., E, 3) for better memory
    # coalescing and to keep the same ordering as chunk().
    fused = fused.unflatten(-1, (3, embed_dim))
    fused = fused.unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
    return fused[0], fused[1], fused[2]
106
+
107
def _none_or_dtype(input: Optional[Tensor]) -> Optional[DType]:
    """Return the dtype of ``input``, or ``None`` when ``input`` is ``None``.

    Raises:
        RuntimeError: If ``input`` is neither ``None`` nor a ``torch.Tensor``.
    """
    if isinstance(input, torch.Tensor):
        return input.dtype
    if input is not None:
        raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor")
    return None
113
+ class MultiheadAttention(Module):
114
+ r"""Allows the model to jointly attend to information
115
+ from different representation subspaces as described in the paper:
116
+ `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
117
+ Multi-Head Attention is defined as:
118
+ .. math::
119
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
120
+ where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
121
+ ``forward()`` will use a special optimized implementation if all of the following
122
+ conditions are met:
123
+ - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
124
+ restriction will be loosened in the future.)
125
+ - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
126
+ - training is disabled (using ``.eval()``)
127
+ - dropout is 0
128
+ - ``add_bias_kv`` is ``False``
129
+ - ``add_zero_attn`` is ``False``
130
+ - ``batch_first`` is ``True`` and the input is batched
131
+ - ``kdim`` and ``vdim`` are equal to ``embed_dim``
132
+ - at most one of ``key_padding_mask`` or ``attn_mask`` is passed
133
+ - if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
134
+ nor ``attn_mask`` is passed
135
+ If the optimized implementation is in use, a
136
+ `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
137
+ ``query``/``key``/``value`` to represent padding more efficiently than using a
138
+ padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
139
+ will be returned, and an additional speedup proportional to the fraction of the input
140
+ that is padding can be expected.
141
+ Args:
142
+ embed_dim: Total dimension of the model.
143
+ num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
144
+ across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
145
+ dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
146
+ bias: If specified, adds bias to input / output projection layers. Default: ``True``.
147
+ add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
148
+ add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
149
+ Default: ``False``.
150
+ kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
151
+ vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
152
+ batch_first: If ``True``, then the input and output tensors are provided
153
+ as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
154
+ Examples::
155
+ >>> # xdoctest: +SKIP
156
+ >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
157
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
158
+ """
159
+ __constants__ = ["batch_first"]
160
+ bias_k: Optional[torch.Tensor]
161
+ bias_v: Optional[torch.Tensor]
162
+
163
+ def __init__(
164
+ self,
165
+ embed_dim,
166
+ num_heads,
167
+ dropout=0.0,
168
+ bias=True,
169
+ add_bias_kv=False,
170
+ add_zero_attn=False,
171
+ kdim=None,
172
+ vdim=None,
173
+ batch_first=False,
174
+ linear1_cls=Linear,
175
+ linear2_cls=Linear,
176
+ device=None,
177
+ dtype=None,
178
+ ) -> None:
179
+ factory_kwargs = {"device": device, "dtype": dtype}
180
+ super(MultiheadAttention, self).__init__()
181
+ self.embed_dim = embed_dim
182
+ self.kdim = kdim if kdim is not None else embed_dim
183
+ self.vdim = vdim if vdim is not None else embed_dim
184
+ self._qkv_same_embed_dim = (
185
+ self.kdim == embed_dim and self.vdim == embed_dim
186
+ )
187
+
188
+ self.num_heads = num_heads
189
+ self.dropout = dropout
190
+ self.batch_first = batch_first
191
+ self.head_dim = embed_dim // num_heads
192
+ assert (
193
+ self.head_dim * num_heads == self.embed_dim
194
+ ), "embed_dim must be divisible by num_heads"
195
+
196
+ if add_bias_kv:
197
+ self.bias_k = Parameter(
198
+ torch.empty((1, 1, embed_dim), **factory_kwargs)
199
+ )
200
+ self.bias_v = Parameter(
201
+ torch.empty((1, 1, embed_dim), **factory_kwargs)
202
+ )
203
+ else:
204
+ self.bias_k = self.bias_v = None
205
+
206
+ if linear1_cls == Linear:
207
+ if not self._qkv_same_embed_dim:
208
+ self.q_proj_weight = Parameter(
209
+ torch.empty((embed_dim, embed_dim), **factory_kwargs)
210
+ )
211
+ self.k_proj_weight = Parameter(
212
+ torch.empty((embed_dim, self.kdim), **factory_kwargs)
213
+ )
214
+ self.v_proj_weight = Parameter(
215
+ torch.empty((embed_dim, self.vdim), **factory_kwargs)
216
+ )
217
+ self.register_parameter("in_proj_weight", None)
218
+ else:
219
+ # go down this route with voicecraft
220
+ self.in_proj_weight = Parameter(
221
+ torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
222
+ )
223
+ self.register_parameter("q_proj_weight", None)
224
+ self.register_parameter("k_proj_weight", None)
225
+ self.register_parameter("v_proj_weight", None)
226
+
227
+ if bias: # True by default
228
+ self.in_proj_bias = Parameter(
229
+ torch.empty(3 * embed_dim, **factory_kwargs)
230
+ )
231
+ else:
232
+ self.register_parameter("in_proj_bias", None)
233
+ self.out_proj = NonDynamicallyQuantizableLinear(
234
+ embed_dim, embed_dim, bias=bias, **factory_kwargs
235
+ )
236
+
237
+ self._reset_parameters()
238
+ else:
239
+ if not self._qkv_same_embed_dim:
240
+ raise NotImplementedError
241
+ else:
242
+ self.in_proj_linear = linear1_cls(
243
+ embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
244
+ )
245
+ self.in_proj_weight = self.in_proj_linear.weight
246
+
247
+ self.register_parameter("q_proj_weight", None)
248
+ self.register_parameter("k_proj_weight", None)
249
+ self.register_parameter("v_proj_weight", None)
250
+
251
+ if bias:
252
+ self.in_proj_bias = self.in_proj_linear.bias
253
+ else:
254
+ self.register_parameter("in_proj_bias", None)
255
+
256
+ self.out_proj = linear2_cls(
257
+ embed_dim, embed_dim, bias=bias, **factory_kwargs
258
+ )
259
+
260
+ if self.bias_k is not None:
261
+ xavier_normal_(self.bias_k)
262
+ if self.bias_v is not None:
263
+ xavier_normal_(self.bias_v)
264
+
265
+ self.add_zero_attn = add_zero_attn
266
+
267
+ def _reset_parameters(self):
268
+ if self._qkv_same_embed_dim:
269
+ xavier_uniform_(self.in_proj_weight)
270
+ else:
271
+ xavier_uniform_(self.q_proj_weight)
272
+ xavier_uniform_(self.k_proj_weight)
273
+ xavier_uniform_(self.v_proj_weight)
274
+
275
+ if self.in_proj_bias is not None:
276
+ constant_(self.in_proj_bias, 0.0)
277
+ constant_(self.out_proj.bias, 0.0)
278
+
279
+ if self.bias_k is not None:
280
+ xavier_normal_(self.bias_k)
281
+ if self.bias_v is not None:
282
+ xavier_normal_(self.bias_v)
283
+
284
+ def __setstate__(self, state):
285
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
286
+ if "_qkv_same_embed_dim" not in state:
287
+ state["_qkv_same_embed_dim"] = True
288
+
289
+ super(MultiheadAttention, self).__setstate__(state)
290
+
291
+ def forward(
292
+ self,
293
+ query: Tensor,
294
+ key: Tensor,
295
+ value: Tensor,
296
+ key_padding_mask: Optional[Tensor] = None,
297
+ need_weights: bool = True,
298
+ attn_mask: Optional[Tensor] = None,
299
+ average_attn_weights: bool = True,
300
+ past: Optional[Tensor] = None,
301
+ ) -> Tuple[Tensor, Optional[Tensor]]:
302
+ r"""
303
+ Args:
304
+ query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
305
+ or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
306
+ :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
307
+ Queries are compared against key-value pairs to produce the output.
308
+ See "Attention Is All You Need" for more details.
309
+ key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
310
+ or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
311
+ :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
312
+ See "Attention Is All You Need" for more details.
313
+ value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
314
+ ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
315
+ sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
316
+ See "Attention Is All You Need" for more details.
317
+ key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
318
+ to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
319
+ Binary and byte masks are supported.
320
+ For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
321
+ the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
322
+ need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
323
+ Default: ``True``.
324
+ attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
325
+ :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
326
+ :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
327
+ broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
328
+ Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
329
+ corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
330
+ corresponding position is not allowed to attend. For a float mask, the mask values will be added to
331
+ the attention weight.
332
+ average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
333
+ heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
334
+ effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
335
+ Outputs:
336
+ - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
337
+ :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
338
+ where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
339
+ embedding dimension ``embed_dim``.
340
+ - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
341
+ returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
342
+ :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
343
+ :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
344
+ head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
345
+ .. note::
346
+ `batch_first` argument is ignored for unbatched inputs.
347
+ """
348
+ is_batched = query.dim() == 3
349
+ if key_padding_mask is not None:
350
+ _kpm_dtype = key_padding_mask.dtype
351
+ if _kpm_dtype != torch.bool and not torch.is_floating_point(
352
+ key_padding_mask
353
+ ):
354
+ raise AssertionError(
355
+ "only bool and floating types of key_padding_mask are supported"
356
+ )
357
+ why_not_fast_path = ""
358
+ if not is_batched:
359
+ why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
360
+ elif query is not key or key is not value:
361
+ # When lifting this restriction, don't forget to either
362
+ # enforce that the dtypes all match or test cases where
363
+ # they don't!
364
+ why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
365
+ elif (
366
+ self.in_proj_bias is not None
367
+ and query.dtype != self.in_proj_bias.dtype
368
+ ):
369
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
370
+ elif (
371
+ self.in_proj_weight is not None
372
+ and query.dtype != self.in_proj_weight.dtype
373
+ ):
374
+ # this case will fail anyway, but at least they'll get a useful error message.
375
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
376
+ elif self.training:
377
+ why_not_fast_path = "training is enabled"
378
+ elif not self.batch_first:
379
+ why_not_fast_path = "batch_first was not True"
380
+ elif self.bias_k is not None:
381
+ why_not_fast_path = "self.bias_k was not None"
382
+ elif self.bias_v is not None:
383
+ why_not_fast_path = "self.bias_v was not None"
384
+ elif self.dropout:
385
+ why_not_fast_path = f"dropout was {self.dropout}, required zero"
386
+ elif self.add_zero_attn:
387
+ why_not_fast_path = "add_zero_attn was enabled"
388
+ elif not self._qkv_same_embed_dim:
389
+ why_not_fast_path = "_qkv_same_embed_dim was not True"
390
+ elif attn_mask is not None:
391
+ why_not_fast_path = "attn_mask was not None"
392
+ elif query.is_nested and key_padding_mask is not None:
393
+ why_not_fast_path = (
394
+ "key_padding_mask is not supported with NestedTensor input"
395
+ )
396
+ elif self.num_heads % 2 == 1:
397
+ why_not_fast_path = "num_heads is odd"
398
+ elif torch.is_autocast_enabled():
399
+ why_not_fast_path = "autocast is enabled"
400
+
401
+ if not why_not_fast_path:
402
+ tensor_args = (
403
+ query,
404
+ key,
405
+ value,
406
+ self.in_proj_weight,
407
+ self.in_proj_bias,
408
+ self.out_proj.weight,
409
+ self.out_proj.bias,
410
+ )
411
+ # We have to use list comprehensions below because TorchScript does not support
412
+ # generator expressions.
413
+ if torch.overrides.has_torch_function(tensor_args):
414
+ why_not_fast_path = "some Tensor argument has_torch_function"
415
+ elif not all(
416
+ [
417
+ (x is None or x.is_cuda or "cpu" in str(x.device))
418
+ for x in tensor_args
419
+ ]
420
+ ):
421
+ why_not_fast_path = (
422
+ "some Tensor argument is neither CUDA nor CPU"
423
+ )
424
+ elif torch.is_grad_enabled() and any(
425
+ [x is not None and x.requires_grad for x in tensor_args]
426
+ ):
427
+ why_not_fast_path = (
428
+ "grad is enabled and at least one of query or the "
429
+ "input/output projection weights or biases requires_grad"
430
+ )
431
+ if not why_not_fast_path:
432
+ return torch._native_multi_head_attention(
433
+ query,
434
+ key,
435
+ value,
436
+ self.embed_dim,
437
+ self.num_heads,
438
+ self.in_proj_weight,
439
+ self.in_proj_bias,
440
+ self.out_proj.weight,
441
+ self.out_proj.bias,
442
+ key_padding_mask
443
+ if key_padding_mask is not None
444
+ else attn_mask,
445
+ need_weights,
446
+ average_attn_weights,
447
+ 1
448
+ if key_padding_mask is not None
449
+ else 0
450
+ if attn_mask is not None
451
+ else None,
452
+ )
453
+
454
+ any_nested = query.is_nested or key.is_nested or value.is_nested
455
+ assert not any_nested, (
456
+ "MultiheadAttention does not support NestedTensor outside of its fast path. "
457
+ + f"The fast path was not hit because {why_not_fast_path}"
458
+ )
459
+
460
+ if self.batch_first and is_batched:
461
+ # make sure that the transpose op does not affect the "is" property
462
+ if key is value:
463
+ if query is key:
464
+ query = key = value = query.transpose(1, 0)
465
+ else:
466
+ query, key = [x.transpose(1, 0) for x in (query, key)]
467
+ value = key
468
+ else:
469
+ query, key, value = [
470
+ x.transpose(1, 0) for x in (query, key, value)
471
+ ]
472
+
473
+ if not self._qkv_same_embed_dim:
474
+ attn_output, attn_output_weights = F.multi_head_attention_forward(
475
+ query,
476
+ key,
477
+ value,
478
+ self.embed_dim,
479
+ self.num_heads,
480
+ self.in_proj_weight,
481
+ self.in_proj_bias,
482
+ self.bias_k,
483
+ self.bias_v,
484
+ self.add_zero_attn,
485
+ self.dropout,
486
+ self.out_proj.weight,
487
+ self.out_proj.bias,
488
+ training=self.training,
489
+ key_padding_mask=key_padding_mask,
490
+ need_weights=need_weights,
491
+ attn_mask=attn_mask,
492
+ use_separate_proj_weight=True,
493
+ q_proj_weight=self.q_proj_weight,
494
+ k_proj_weight=self.k_proj_weight,
495
+ v_proj_weight=self.v_proj_weight,
496
+ average_attn_weights=average_attn_weights,
497
+ )
498
+ else:
499
+ # re-write the self.attention here, to get k, v cache
500
+ tgt_len, bsz, embed_dim = query.shape
501
+ src_len, _, _ = key.shape
502
+ num_heads = self.num_heads
503
+ key_padding_mask = _canonical_mask(
504
+ mask=key_padding_mask,
505
+ mask_name="key_padding_mask",
506
+ other_type=_none_or_dtype(attn_mask),
507
+ other_name="attn_mask",
508
+ target_type=query.dtype
509
+ )
510
+ attn_mask = _canonical_mask(
511
+ mask=attn_mask,
512
+ mask_name="attn_mask",
513
+ other_type=None,
514
+ other_name="",
515
+ target_type=query.dtype,
516
+ check_other=False,
517
+ )
518
+ head_dim = self.embed_dim // self.num_heads
519
+ assert head_dim * self.num_heads == self.embed_dim, f"embed_dim {self.embed_dim} not divisible by num_heads {self.num_heads}"
520
+ assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
521
+ q, k, v = _in_projection_packed(query, key, value, self.in_proj_weight, self.in_proj_bias)
522
+ # k_present, v_present = k, v
523
+
524
+ #
525
+ # reshape q, k, v for multihead attention and make em batch first
526
+ #
527
+
528
+ q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
529
+ k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
530
+ v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1) # (bsz * num_heads, src_len, head_dim)
531
+ src_len = k.size(1)
532
+ if past is not None and past.ndim > 2:
533
+ expected_src_len = src_len + past[0].shape[-2]
534
+ else:
535
+ expected_src_len = src_len
536
+
537
+
538
+ # ensure attn_mask's dim is 3
539
+ if attn_mask.dim() == 2:
540
+ correct_2d_size = (tgt_len, expected_src_len)
541
+ if attn_mask.shape != correct_2d_size:
542
+ raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
543
+ attn_mask = attn_mask.unsqueeze(0)
544
+ elif attn_mask.dim() == 3:
545
+ correct_3d_size = (bsz * num_heads, tgt_len, expected_src_len)
546
+ if attn_mask.shape != correct_3d_size:
547
+ raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
548
+ else:
549
+ raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
550
+
551
+ if key_padding_mask is not None:
552
+ assert key_padding_mask.shape == (bsz, expected_src_len), \
553
+ f"expecting key_padding_mask shape of {(bsz, expected_src_len)}, but got {key_padding_mask.shape}"
554
+ key_padding_mask = key_padding_mask.view(bsz, 1, 1, expected_src_len). \
555
+ expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, expected_src_len)
556
+ if attn_mask is None:
557
+ attn_mask = key_padding_mask
558
+ else:
559
+ attn_mask = attn_mask + key_padding_mask
560
+
561
+ if not self.training:
562
+ dropout_p = 0.0
563
+ else:
564
+ dropout_p = self.dropout
565
+
566
+ if need_weights:
567
+ raise NotImplementedError("need_weights not implemented for voicecraft")
568
+ # B, Nt, E = q.shape
569
+ # q_scaled = q / math.sqrt(E)
570
+
571
+ # assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
572
+
573
+ # if attn_mask is not None:
574
+ # attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
575
+ # else:
576
+ # attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
577
+ # attn_output_weights = softmax(attn_output_weights, dim=-1)
578
+ # if dropout_p > 0.0:
579
+ # attn_output_weights = dropout(attn_output_weights, p=dropout_p)
580
+
581
+ # attn_output = torch.bmm(attn_output_weights, v)
582
+
583
+ # attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
584
+ # attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
585
+ # attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
586
+
587
+ # # optionally average attention weights over heads
588
+ # attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
589
+ # if average_attn_weights:
590
+ # attn_output_weights = attn_output_weights.mean(dim=1)
591
+
592
+ # if not is_batched:
593
+ # # squeeze the output if input was unbatched
594
+ # attn_output = attn_output.squeeze(1)
595
+ # attn_output_weights = attn_output_weights.squeeze(0)
596
+ # return attn_output, attn_output_weights
597
+ else:
598
+ # attn_mask can be either (L,S) or (N*num_heads, L, S)
599
+ # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
600
+ # in order to match the input for SDPA of (N, num_heads, L, S)
601
+ if attn_mask is not None:
602
+ if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
603
+ attn_mask = attn_mask.unsqueeze(0)
604
+ else:
605
+ attn_mask = attn_mask.view(bsz, num_heads, -1, expected_src_len)
606
+
607
+ q = q.view(bsz, num_heads, tgt_len, head_dim)
608
+ k = k.view(bsz, num_heads, src_len, head_dim)
609
+ v = v.view(bsz, num_heads, src_len, head_dim)
610
+ # logging.info(f"shape of past: {past.shape}")
611
+ if past is not None:
612
+ present = torch.stack([k, v], dim=0) # (2, bsz, num_heads, src_len, head_dim)
613
+ if past.ndim > 2: # this means we use kvcache, otherwise we just pass in a placeholder, but not actually using kvcache
614
+ pk, pv = past
615
+ k = torch.cat([pk, k], dim=-2)
616
+ v = torch.cat([pv, v], dim=-2)
617
+ else:
618
+ present = None
619
+ attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal=False)
620
+ attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
621
+
622
+ attn_output = F.linear(attn_output, self.out_proj.weight, self.out_proj.bias)
623
+ attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
624
+ if not is_batched:
625
+ # squeeze the output if input was unbatched
626
+ attn_output = attn_output.squeeze(1)
627
+ # if self.training:
628
+ # return attn_output, None
629
+ # else:
630
+ # return (attn_output, present), None
631
+
632
+ # harded coded, the code do not support returning attn weigths yet
633
+ attn_output_weights=None
634
+ if self.batch_first and is_batched:
635
+ return attn_output.transpose(1, 0), present
636
+ else:
637
+ return attn_output, present
638
+
src/model/modules/codebooks_patterns.py ADDED
@@ -0,0 +1,538 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from collections import namedtuple
8
+ from dataclasses import dataclass
9
+ from functools import lru_cache
10
+ import logging
11
+ import typing as tp
12
+
13
+ from abc import ABC, abstractmethod
14
+ import torch
15
+
16
LayoutCoord = namedtuple('LayoutCoord', ['t', 'q'])  # (timestep, codebook index)
PatternLayout = tp.List[tp.List[LayoutCoord]]  # Sequence of coordinates


@dataclass
class Pattern:
    """Base implementation of a pattern over a sequence with multiple codebooks.

    The codebook pattern consists of a layout, defining for each sequence step
    the list of coordinates of each codebook timestep in the resulting interleaved sequence.
    The first item of the pattern is always an empty list in order to properly insert a special token
    to start with. For convenience, we also keep track of ``n_q``, the number of codebooks used for the
    pattern, and ``timesteps``, the number of timesteps corresponding to the original sequence.

    The pattern provides convenient methods to build and revert interleaved sequences from it:
    ``build_pattern_sequence`` maps a dense input tensor of multi-codebook sequence from [B, K, T]
    to the interleaved sequence of shape [B, K, S] applying the pattern, with B being the batch size,
    K being the number of codebooks, T the number of original timesteps and S the number of sequence steps
    for the output sequence. The unfilled positions are replaced with a special token and the built sequence
    is returned along with a mask indicating valid tokens.
    ``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment
    of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask
    to fill and specify invalid positions if needed.
    See the dedicated methods for more details.
    """
    # Pattern layout: for each sequence step, the list of coordinates
    # (original timestep, codebook index) placed at that step.
    # The first list is always empty so that a special token can be inserted
    # at the start of the sequence.
    layout: PatternLayout
    timesteps: int
    n_q: int

    def __post_init__(self):
        assert len(self.layout) > 0
        assert self.layout[0] == []
        self._validate_layout()
        # Wrap the bound methods so each Pattern instance memoizes its own index tensors
        # (up to 100 distinct argument combinations per method).
        self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
        self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)

    def _validate_layout(self):
        """Runs checks on the layout to ensure a valid pattern is defined.

        A pattern is considered invalid if:
            - Multiple timesteps for a same codebook are defined in the same sequence step
            - The timesteps for a given codebook are not in ascending order as we advance in the sequence
              (this would mean that we have future timesteps before past timesteps).

        Raises:
            AssertionError: If one of the conditions above is violated.
            KeyError: If a coordinate references a codebook index outside ``range(n_q)``.
        """
        q_timesteps = {q: 0 for q in range(self.n_q)}
        for s, seq_coords in enumerate(self.layout):
            if len(seq_coords) > 0:
                qs = set()
                for coord in seq_coords:
                    qs.add(coord.q)
                    last_q_timestep = q_timesteps[coord.q]
                    assert coord.t >= last_q_timestep, \
                        f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
                    q_timesteps[coord.q] = coord.t
                # each sequence step contains at max 1 coordinate per codebook
                assert len(qs) == len(seq_coords), \
                    f"Multiple entries for a same codebook are found at step {s}"

    @property
    def num_sequence_steps(self):
        """Number of sequence steps in the layout, not counting the initial empty step."""
        return len(self.layout) - 1

    @property
    def max_delay(self):
        """Difference between the largest ``t + 1`` found in the layout and ``timesteps``."""
        max_t_in_seq_coords = 0
        for seq_coords in self.layout[1:]:
            for coords in seq_coords:
                max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
        return max_t_in_seq_coords - self.timesteps

    @property
    def valid_layout(self):
        """Layout truncated by ``max_delay`` trailing steps."""
        valid_step = len(self.layout) - self.max_delay
        return self.layout[:valid_step]

    def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
        """Get codebook coordinates in the layout that correspond to the specified timestep t
        and optionally to the codebook q. Coordinates are returned as a list of tuples with the
        sequence step and the actual codebook coordinates.
        """
        assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps"
        if q is not None:
            assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks"
        coords = []
        for s, seq_codes in enumerate(self.layout):
            for code in seq_codes:
                if code.t == t and (q is None or code.q == q):
                    coords.append((s, code))
        return coords

    def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]:
        """Sequence steps at which timestep ``t`` (optionally restricted to codebook ``q``) appears."""
        return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)]

    def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]:
        """First sequence step containing timestep ``t``, or None when there is none."""
        steps_with_timesteps = self.get_steps_with_timestep(t, q)
        return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None

    def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool,
                                                device: tp.Union[torch.device, str] = 'cpu'):
        """Build scatter indexes corresponding to the pattern.

        Args:
            timesteps (int): Maximum number of timesteps steps to consider.
            n_q (int): Number of codebooks; must equal the pattern's ``n_q``.
            keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
            device (Union[torch.device, str]): Device for created tensors.
        Returns:
            indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
        """
        assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
        assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern"
        # use the proper layout based on whether we limit ourselves to valid steps only or not,
        # note that using the valid_layout will result in a truncated sequence up to the valid steps
        ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
        # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
        indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy()
        mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy()
        # fill indexes with last sequence step value that will correspond to our special token
        # the last value is n_q * timesteps as we have flattened z and append special token as the last token
        # which will correspond to the index: n_q * timesteps
        indexes[:] = n_q * timesteps
        # iterate over the pattern and fill scattered indexes and mask
        for s, sequence_coords in enumerate(ref_layout):
            for coords in sequence_coords:
                if coords.t < timesteps:
                    indexes[coords.q, s] = coords.t + coords.q * timesteps
                    mask[coords.q, s] = 1
        indexes = torch.from_numpy(indexes).to(device)
        mask = torch.from_numpy(mask).to(device)
        return indexes, mask

    def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
        """Build sequence corresponding to the pattern from the input tensor z.
        Non-pattern coordinates are filled with the special token.

        Args:
            z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T].
            special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
            keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
                The layout is truncated to ``valid_layout`` in that case.
        Returns:
            values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S
                the number of steps of the (possibly truncated) layout.
            indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S].
            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].
        """
        B, K, T = z.shape
        indexes, mask = self._build_pattern_sequence_scatter_indexes(
            T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
        )
        # reshape (not view) so that non-contiguous inputs, e.g. transposed tensors, are accepted,
        # consistent with revert_pattern_logits
        z = z.reshape(B, -1)
        # we append the special token as the last index of our flattened z tensor
        z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
        values = z[:, indexes.view(-1)]
        values = values.view(B, K, indexes.shape[-1])
        return values, indexes, mask

    def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int,
                                                 keep_only_valid_steps: bool = False,
                                                 is_model_output: bool = False,
                                                 device: tp.Union[torch.device, str] = 'cpu'):
        """Builds scatter indexes required to retrieve the original multi-codebook sequence
        from interleaving pattern.

        Args:
            sequence_steps (int): Sequence steps.
            n_q (int): Number of codebooks.
            keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
                The layout is truncated to ``valid_layout`` in that case.
            is_model_output (bool): When True, the initial (special token) layout step is dropped before
                computing the indexes, so that step 0 of the input maps to the first real layout step.
            device (Union[torch.device, str]): Device for created tensors.
        Returns:
            indexes (torch.Tensor): Indexes for reconstructing the output, of shape [K, T].
            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
        """
        ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
        # TODO(jade): Do we want to further truncate to only valid timesteps here as well?
        timesteps = self.timesteps
        assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
        assert sequence_steps <= len(ref_layout), \
            f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"

        # ensure we take the appropriate indexes to keep the model output from the first special token as well
        if is_model_output:
            ref_layout = ref_layout[1:]

        # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
        indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy()
        mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy()
        # fill indexes with last sequence step value that will correspond to our special token
        indexes[:] = n_q * sequence_steps
        for s, sequence_codes in enumerate(ref_layout):
            if s < sequence_steps:
                for code in sequence_codes:
                    if code.t < timesteps:
                        indexes[code.q, code.t] = s + code.q * sequence_steps
                        mask[code.q, code.t] = 1
        indexes = torch.from_numpy(indexes).to(device)
        mask = torch.from_numpy(mask).to(device)
        return indexes, mask

    def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
        """Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving.
        Non-pattern coordinates are filled with the special token.

        Args:
            s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
            special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
            keep_only_valid_steps (bool): Use the layout truncated to valid steps when mapping back.
        Returns:
            values (torch.Tensor): De-interleaved sequence, of shape [B, K, T] with T the pattern's
                number of timesteps.
            indexes (torch.Tensor): Indexes used to gather the reverted sequence, of shape [K, T].
            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
        """
        B, K, S = s.shape
        indexes, mask = self._build_reverted_sequence_scatter_indexes(
            S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
        )
        # reshape (not view) so that non-contiguous inputs are accepted
        s = s.reshape(B, -1)
        # we append the special token as the last index of our flattened s tensor
        s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
        values = s[:, indexes.view(-1)]
        values = values.view(B, K, indexes.shape[-1])
        return values, indexes, mask

    def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False):
        """Revert model logits obtained on a sequence built from the pattern
        back to a tensor matching the original sequence.

        This method is similar to ``revert_pattern_sequence`` with the following specificities:
        1. It is designed to work with the extra cardinality dimension
        2. We return the logits for the first sequence item that matches the special_token and
        which matching target in the original sequence is the first item of the sequence,
        while we skip the last logits as there is no matching target

        Args:
            logits (torch.Tensor): Logits of shape [B, card, K, S].
            special_token (float): Value used to fill positions with no matching sequence step.
            keep_only_valid_steps (bool): Use the layout truncated to valid steps when mapping back.
        Returns:
            values (torch.Tensor): Reverted logits, of shape [B, card, K, T].
            indexes (torch.Tensor): Indexes used to gather the reverted logits, of shape [K, T].
            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
        """
        B, card, K, S = logits.shape
        indexes, mask = self._build_reverted_sequence_scatter_indexes(
            S, K, keep_only_valid_steps, is_model_output=True, device=logits.device
        )
        logits = logits.reshape(B, card, -1)
        # we append the special token as the last index of our flattened logits tensor
        logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1)  # [B, card, K x S]
        values = logits[:, :, indexes.view(-1)]
        values = values.view(B, card, K, indexes.shape[-1])
        return values, indexes, mask
267
+
268
+
269
class CodebooksPatternProvider(ABC):
    """Abstraction around providing pattern for interleaving codebooks.

    The CodebooksPatternProvider abstraction allows to implement various strategies to
    define interleaving pattern of sequences composed of multiple codebooks. For a given
    number of codebooks `n_q`, the pattern provider can generate a specified pattern
    corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern
    can be used to construct a new sequence from the original codes respecting the specified
    pattern. The pattern is defined as a list of list of code coordinates, code coordinate
    being a tuple with the original timestep and codebook to build the new sequence.
    Note that all patterns must start with an empty list that is then used to insert a first
    sequence step of special tokens in the newly generated sequence.

    Args:
        n_q (int): number of codebooks.
        cached (bool): if True, patterns for a given length are cached. In general
            that should be true for efficiency reason to avoid synchronization points.
    """
    def __init__(self, n_q: int, cached: bool = True):
        assert n_q > 0
        self.n_q = n_q
        if cached:
            # Wrap the bound method so the cache (up to 100 distinct ``timesteps`` values)
            # is held by this instance. Previously the ``cached`` flag was ignored and
            # caching was always enabled.
            self.get_pattern = lru_cache(100)(self.get_pattern)  # type: ignore

    @abstractmethod
    def get_pattern(self, timesteps: int) -> "Pattern":
        """Builds pattern with specific interleaving between codebooks.

        Args:
            timesteps (int): Total number of timesteps.
        """
        raise NotImplementedError()
300
+
301
+
302
class DelayedPatternProvider(CodebooksPatternProvider):
    """Pattern provider that shifts each codebook by its own delay.

    At sequence step ``t`` the layout holds, for every codebook ``q``, the
    coordinate ``(t - delays[q], q)`` whenever that shifted timestep is at
    least ``flatten_first``. Sequence steps therefore mix codebooks coming
    from different original timesteps.

    Example:
        Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence:
        [[1, 2, 3, 4],
        [1, 2, 3, 4],
        [1, 2, 3, 4]]
        The first steps of the sequence obtained from the returned pattern are:
        [[S, 1, 2, 3, 4],
        [S, S, 1, 2, 3],
        [S, S, S, 1, 2]]
        (with S being a special token)

    Args:
        n_q (int): Number of codebooks.
        delays (Optional[List[int]]): Delay for each of the codebooks; must be sorted
            in ascending order and have one entry per codebook.
            If delays not defined, each codebook is delayed by 1 compared to the previous one.
        flatten_first (int): Flatten the first N timesteps.
        empty_initial (int): Prepend with N empty list of coordinates.
    """
    def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None,
                 flatten_first: int = 0, empty_initial: int = 0):
        super().__init__(n_q)
        # default: codebook q is delayed by q steps
        self.delays = list(range(n_q)) if delays is None else delays
        self.flatten_first = flatten_first
        self.empty_initial = empty_initial
        assert len(self.delays) == self.n_q
        assert sorted(self.delays) == self.delays

    def get_pattern(self, timesteps: int) -> Pattern:
        """Build the delayed layout for ``timesteps`` original timesteps.

        Args:
            timesteps (int): Total number of timesteps.
        Returns:
            Pattern: Pattern wrapping the generated layout.
        """
        # the layout always opens with an empty step reserved for the special token
        layout: PatternLayout = [[]]
        largest_delay = max(self.delays)
        if self.empty_initial:
            layout.extend([] for _ in range(self.empty_initial))
        if self.flatten_first:
            # the first timesteps are emitted one codebook per sequence step
            flat_timesteps = min(timesteps, self.flatten_first)
            layout.extend(
                [LayoutCoord(t, q)]
                for t in range(flat_timesteps)
                for q in range(self.n_q)
            )
        for step in range(self.flatten_first, timesteps + largest_delay):
            layout.append([
                LayoutCoord(step - delay, q)
                for q, delay in enumerate(self.delays)
                if step - delay >= self.flatten_first
            ])
        return Pattern(layout, n_q=self.n_q, timesteps=timesteps)
353
+
354
+
355
class ParallelPatternProvider(DelayedPatternProvider):
    """Pattern provider that keeps all codebooks aligned on the same step.

    This is the delayed pattern with a delay of zero for every codebook,
    i.e. ``delays = [0] * n_q``.

    Args:
        n_q (int): Number of codebooks.
    """
    def __init__(self, n_q: int):
        zero_delays = [0] * n_q
        super().__init__(n_q, zero_delays)
365
+
366
+
367
class UnrolledPatternProvider(CodebooksPatternProvider):
    """Provider for unrolling codebooks pattern.
    This pattern provider enables to represent the codebook flattened completely or only to some extent
    while also specifying a given delay between the flattened codebooks representation, allowing to
    unroll the codebooks in the sequence.

    NOTE(review): the illustrations below are carried over from the upstream docstring and are
    schematic; the exact ordering of steps is whatever ``get_pattern`` produces (entries are
    sorted by their delayed timestep) - confirm against actual output before relying on them.

    Example:
        1. Flattening of the codebooks.
        By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q),
        taking n_q = 3 and timesteps = 4:
        [[1, 2, 3, 4],
         [1, 2, 3, 4],
         [1, 2, 3, 4]]
        will result into:
        [[S, S, 1, S, S, 2, S, S, 3, S, S, 4],
         [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
         [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
        2. Partial flattening of the codebooks. The ``flattening`` parameter allows to specify the inner step
        for each of the codebook, allowing to define which codebook to flatten (or keep in parallel), for example
        taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]:
        [[1, 2, 3, 4],
         [1, 2, 3, 4],
         [1, 2, 3, 4]]
        will result into:
        [[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
         [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
         [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
        3. Flattening with delay. The ``delay`` parameter allows to further unroll the sequence of codebooks
        allowing to specify the delay per codebook. Note that the delay between codebooks flattened to the
        same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1]
        and delays = [0, 3, 3]:
        [[1, 2, 3, 4],
         [1, 2, 3, 4],
         [1, 2, 3, 4]]
        will result into:
        [[S, S, S, 1, S, 2, S, 3, S, 4],
         [S, S, S, 1, S, 2, S, 3, S, 4],
         [1, 2, 3, S, 4, S, 5, S, 6, S]]

    Args:
        n_q (int): Number of codebooks.
        flattening (Optional[List[int]]): Flattening schema over the codebooks. If not defined,
            the codebooks will be flattened to 1 codebook per step, meaning that the sequence will
            have n_q extra steps for each timestep.
        delays (Optional[List[int]]): Delay for each of the codebooks. If not defined,
            no delay is added and therefore will default to [0] * ``n_q``.
            Note that two codebooks that will be flattened to the same inner step
            should have the same delay, otherwise the pattern is considered as invalid.
    """
    FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay'])

    def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None,
                 delays: tp.Optional[tp.List[int]] = None):
        super().__init__(n_q)
        if flattening is None:
            flattening = list(range(n_q))
        if delays is None:
            delays = [0] * n_q
        assert len(flattening) == n_q
        assert len(delays) == n_q
        # both schemas must be given as ascending lists
        assert sorted(flattening) == flattening
        assert sorted(delays) == delays
        self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening)
        self.max_delay = max(delays)

    def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]):
        """Build a flattened codebooks representation as a dictionary of inner step
        and the actual codebook indices corresponding to the flattened codebook. For convenience, we
        also store the delay associated to the flattened codebook to avoid maintaining an extra mapping.

        Raises:
            AssertionError: If two codebooks share an inner step but have different delays.
        """
        flattened_codebooks: dict = {}
        for q, (inner_step, delay) in enumerate(zip(flattening, delays)):
            if inner_step not in flattened_codebooks:
                flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay)
            else:
                flat_codebook = flattened_codebooks[inner_step]
                # implicit string concatenation: the message is a single string, not a tuple
                assert flat_codebook.delay == delay, (
                    "Delay and flattening between codebooks is inconsistent: "
                    "two codebooks flattened to the same position should have the same delay."
                )
                flat_codebook.codebooks.append(q)
            flattened_codebooks[inner_step] = flat_codebook
        return flattened_codebooks

    @property
    def _num_inner_steps(self):
        """Number of inner steps to unroll between timesteps in order to flatten the codebooks.
        """
        return max(self._flattened_codebooks) + 1

    def num_virtual_steps(self, timesteps: int) -> int:
        """Return ``timesteps * _num_inner_steps + 1``."""
        return timesteps * self._num_inner_steps + 1

    def get_pattern(self, timesteps: int) -> Pattern:
        """Builds pattern for delay across codebooks.

        Args:
            timesteps (int): Total number of timesteps.
        """
        # the PatternLayout is built as a tuple of sequence position and list of coordinates
        # so that it can be reordered properly given the required delay between codebooks of given timesteps
        indexed_out: list = [(-1, [])]
        max_timesteps = timesteps + self.max_delay
        for t in range(max_timesteps):
            # for each timestep, we unroll the flattened codebooks,
            # emitting the sequence step with the corresponding delay
            for step in range(self._num_inner_steps):
                if step in self._flattened_codebooks:
                    # we have codebooks at this virtual step to emit
                    step_codebooks = self._flattened_codebooks[step]
                    t_for_q = t + step_codebooks.delay
                    coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks]
                    # entries whose delayed position falls beyond the horizon are dropped
                    if t_for_q < max_timesteps:
                        indexed_out.append((t_for_q, coords))
                else:
                    # there is no codebook in this virtual step so we emit an empty list
                    indexed_out.append((t, []))
        out = [coords for _, coords in sorted(indexed_out)]
        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
486
+
487
+
488
class VALLEPattern(CodebooksPatternProvider):
    """Almost VALL-E style pattern. We further allow some delays for the
    codebooks other than the first one.

    The layout first lists every timestep of codebook 0 on its own sequence step,
    then lists the remaining codebooks together, each shifted by its delay.

    Args:
        n_q (int): Number of codebooks.
        delays (Optional[List[int]]): Delay for each of the codebooks other than the first one
            (``n_q - 1`` entries, in ascending order).
            If delays not defined, no delay is applied (``[0] * (n_q - 1)``).
    """
    def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None):
        super().__init__(n_q)
        if delays is None:
            delays = [0] * (n_q - 1)
        self.delays = delays
        assert len(self.delays) == self.n_q - 1
        assert sorted(self.delays) == self.delays

    def get_pattern(self, timesteps: int) -> Pattern:
        """Build the layout for ``timesteps`` original timesteps.

        Args:
            timesteps (int): Total number of timesteps.
        Returns:
            Pattern: Pattern wrapping the generated layout.
        """
        out: PatternLayout = [[]]
        # first segment: codebook 0 alone, one timestep per sequence step
        for t in range(timesteps):
            out.append([LayoutCoord(t, 0)])
        # with a single codebook there are no delays; default=0 avoids max() failing on an empty list
        max_delay = max(self.delays, default=0)
        # second segment: remaining codebooks, each shifted by its delay
        for t in range(timesteps + max_delay):
            v = []
            for q, delay in enumerate(self.delays):
                t_for_q = t - delay
                if t_for_q >= 0:
                    v.append(LayoutCoord(t_for_q, q + 1))
            out.append(v)
        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
518
+
519
+
520
class MusicLMPattern(CodebooksPatternProvider):
    """Almost MusicLM style pattern. This is equivalent to full flattening
    but in a different order.

    Codebooks are processed in consecutive groups of ``group_by``; within a group,
    each timestep emits one sequence step per codebook of the group.

    Args:
        n_q (int): Number of codebooks.
        group_by (int): Number of codebooks to group together.
    """
    def __init__(self, n_q: int, group_by: int = 2):
        super().__init__(n_q)
        self.group_by = group_by

    def get_pattern(self, timesteps: int) -> Pattern:
        """Build the layout for ``timesteps`` original timesteps.

        Args:
            timesteps (int): Total number of timesteps.
        Returns:
            Pattern: Pattern wrapping the generated layout.
        """
        out: PatternLayout = [[]]
        for offset in range(0, self.n_q, self.group_by):
            # clamp the last group so that no codebook index >= n_q is emitted
            # when n_q is not a multiple of group_by
            group_end = min(offset + self.group_by, self.n_q)
            for t in range(timesteps):
                for q in range(offset, group_end):
                    out.append([LayoutCoord(t, q)])
        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
src/model/modules/embedding.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py
2
+ # Copyright 2023 (authors: Feiteng Li)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+
22
class TokenEmbedding(nn.Module):
    """Embedding lookup table whose output goes through dropout.

    Args:
        dim_model: Size of each embedding vector.
        vocab_size: Number of rows in the embedding table.
        dropout: Dropout probability applied to the looked-up vectors.
    """

    def __init__(self, dim_model: int, vocab_size: int, dropout: float = 0.0):
        super().__init__()

        self.vocab_size, self.dim_model = vocab_size, dim_model

        self.dropout = nn.Dropout(p=dropout)
        self.word_embeddings = nn.Embedding(vocab_size, dim_model)

    @property
    def weight(self) -> torch.Tensor:
        """The full embedding table."""
        table = self.word_embeddings.weight
        return table

    def embedding(self, index: int) -> torch.Tensor:
        """Return row ``index`` of the table, keeping the leading dimension."""
        stop = index + 1
        return self.word_embeddings.weight[index:stop]

    def forward(self, x: torch.Tensor):
        """Look up the embeddings of the ids in ``x`` and apply dropout."""
        return self.dropout(self.word_embeddings(x))
49
+
50
+
51
class SinePositionalEmbedding(nn.Module):
    """Sinusoidal positional encoding added to the input.

    Args:
        dim_model: Size of the last dimension of the encoding. Both even and
            odd sizes are supported.
        dropout: Dropout probability applied to the output.
        scale: If True, the input is multiplied by ``sqrt(dim_model)`` before
            the encoding is added.
        alpha: If True, the scalar weight of the positional encoding is
            trainable; otherwise it stays fixed at 1.
    """

    def __init__(
        self,
        dim_model: int,
        dropout: float = 0.0,
        scale: bool = False,
        alpha: bool = False,
    ):
        super().__init__()
        self.dim_model = dim_model
        self.x_scale = math.sqrt(dim_model) if scale else 1.0
        self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
        self.dropout = torch.nn.Dropout(p=dropout)

        self.reverse = False
        # Plain attribute (not a registered buffer): it is rebuilt or moved
        # lazily by `extend_pe`.
        self.pe = None
        # Pre-compute the table for 4000 positions.
        self.extend_pe(torch.tensor(0.0).expand(1, 4000))

    def extend_pe(self, x):
        """Make sure the cached table covers ``x.size(1)`` positions.

        If the cached table is already long enough it is only moved to the
        dtype/device of ``x`` when they differ; otherwise it is rebuilt.
        """
        if self.pe is not None:
            if self.pe.size(1) >= x.size(1):
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        pe = torch.zeros(x.size(1), self.dim_model)
        if self.reverse:
            position = torch.arange(
                x.size(1) - 1, -1, -1.0, dtype=torch.float32
            ).unsqueeze(1)
        else:
            position = torch.arange(
                0, x.size(1), dtype=torch.float32
            ).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.dim_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.dim_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # For an odd `dim_model` there is one fewer odd column than even
        # ones, so only the matching frequencies are used for the cosines.
        pe[:, 1::2] = torch.cos(position * div_term[: self.dim_model // 2])
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=x.device, dtype=x.dtype).detach()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the positional encoding to ``x`` and apply dropout.

        A 2-D input gets a trailing singleton dimension first, so it
        broadcasts against the encoding. Positions are taken along dim 1.
        """
        self.extend_pe(x)
        output = x.unsqueeze(-1) if x.ndim == 2 else x
        output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
        return self.dropout(output)
src/model/modules/gemma.py ADDED
@@ -0,0 +1,423 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from typing import Optional, Tuple
4
+ import math
5
+ from src.model.modules.kv_cache import KVCache
6
+
7
+
8
class GemmaConfig:
    """Plain container for the hyper-parameters of the Gemma language model.

    Every named argument is stored unchanged as an attribute of the same
    name. Additional keyword arguments are accepted and ignored.
    """

    def __init__(
        self,
        vocab_size,
        hidden_size,
        intermediate_size,
        num_hidden_layers,
        num_attention_heads,
        num_key_value_heads,
        head_dim=256,
        max_position_embeddings=8192,
        rms_norm_eps=1e-6,
        rope_theta=10000.0,
        attention_bias=False,
        attention_dropout=0.0,
        pad_token_id=None,
        **kwargs,
    ):
        super().__init__()
        # Vocabulary
        self.vocab_size = vocab_size
        self.pad_token_id = pad_token_id
        # Layer sizes
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        # Attention
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        # Positions and normalization
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.rms_norm_eps = rms_norm_eps
41
+
42
+
43
class GemmaRMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    The learned ``weight`` starts at zero and is applied as ``1 + weight``,
    so a freshly created layer only normalizes.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        # Divide by the RMS of the last dimension; eps keeps rsqrt finite.
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        # Everything is computed in float32 and cast back only at the end:
        # Gemma does (x * w).to(dtype), whereas Llama does x.to(dtype) * w.
        # See https://github.com/huggingface/transformers/pull/29402
        normed = self._norm(x.float())
        gain = 1.0 + self.weight.float()
        scaled = normed * gain
        return scaled.type_as(x)
58
+
59
+
60
class GemmaRotaryEmbedding(nn.Module):
    """Computes the cos/sin tables used by rotary position embeddings.

    Args:
        dim: Size of the rotated dimension (the attention head size).
        max_position_embeddings: Stored on the module; not used by `forward`.
        base: Base of the geometric progression of the frequencies.
        device: Accepted for signature compatibility; not used.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim  # it is set to the head_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # Inverse frequencies: base^(-2i / dim) for i = 0, 1, ..., dim // 2 - 1
        inv_freq = 1.0 / (
            self.base
            ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)
        )
        # Non-persistent: recomputed at construction, never saved in state_dict.
        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids, seq_len=None):
        """Return ``(cos, sin)`` for the given positions.

        Args:
            x: Tensor whose device and dtype the result should match,
                e.g. [Batch_Size, Num_Heads, Seq_Len, Head_Dim].
            position_ids: [Batch_Size, Seq_Len] positions.
            seq_len: Unused.

        Returns:
            Two tensors of shape [Batch_Size, Seq_Len, Head_Dim] in the
            dtype of ``x``.
        """
        # `Tensor.to` is not in-place: keep the moved tensor and use it below,
        # without touching the registered buffer.
        inv_freq = self.inv_freq.to(x.device)
        # Copy the inv_freq tensor for each item of the batch
        # inv_freq_expanded: [Batch_Size, Head_Dim // 2, 1]
        inv_freq_expanded = (
            inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        )
        # position_ids_expanded: [Batch_Size, 1, Seq_Len]
        position_ids_expanded = position_ids[:, None, :].float()
        device_type = x.device.type
        device_type = (
            device_type
            if isinstance(device_type, str) and device_type != "mps"
            else "cpu"
        )
        # Autocast is disabled so the angles are computed in float32.
        with torch.autocast(device_type=device_type, enabled=False):
            # Multiply each frequency by the position (the argument of sin and cos)
            # freqs: [Batch_Size, Head_Dim // 2, 1] @ [Batch_Size, 1, Seq_Len] --> [Batch_Size, Seq_Len, Head_Dim // 2]
            freqs = (
                inv_freq_expanded.float() @ position_ids_expanded.float()
            ).transpose(1, 2)
            # emb: [Batch_Size, Seq_Len, Head_Dim]
            emb = torch.cat((freqs, freqs), dim=-1)
            # cos, sin: [Batch_Size, Seq_Len, Head_Dim]
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
104
+
105
+
106
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    For a last dimension ``[a, b]`` (first half ``a``, second half ``b``)
    the result is ``[-b, a]``. This is the term multiplied by ``sin`` when
    applying rotary position embeddings.
    """
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((-back, front), dim=-1)
111
+
112
+
113
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Rotate the query and key tensors by the given cos/sin tables.

    Args:
        q: Query tensor.
        k: Key tensor.
        cos: Cosine table; a dimension is inserted at ``unsqueeze_dim`` so it
            broadcasts over the head dimension of ``q`` and ``k``.
        sin: Sine table, handled like ``cos``.
        unsqueeze_dim: Where the broadcast (head) dimension is inserted.

    Returns:
        The rotated ``(q, k)`` pair.
    """
    # Add the head dimension
    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        # Formula (34) of the Rotary Positional Encoding paper.
        return (states * cos) + (rotate_half(states) * sin)

    return _rotate(q), _rotate(k)
120
+
121
+
122
class GemmaMLP(nn.Module):
    """Gated feed-forward block: ``down(gelu_tanh(gate(x)) * up(x))``.

    ``config`` must provide ``hidden_size`` and ``intermediate_size``. None
    of the three projections has a bias.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        hidden, inner = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)

    def forward(self, x):
        # [Batch_Size, Seq_Len, Hidden_Size] -> [Batch_Size, Seq_Len, Intermediate_Size]
        gate = nn.functional.gelu(self.gate_proj(x), approximate="tanh")
        # Gate the up-projection element-wise: [Batch_Size, Seq_Len, Intermediate_Size]
        gated = gate * self.up_proj(x)
        # [Batch_Size, Seq_Len, Intermediate_Size] -> [Batch_Size, Seq_Len, Hidden_Size]
        return self.down_proj(gated)
142
+
143
+
144
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along the head dimension.

    [Batch_Size, Num_Heads_KV, Seq_Len, Head_Dim] ->
    [Batch_Size, Num_Heads_KV * n_rep, Seq_Len, Head_Dim]. The copies of a
    head are adjacent in the result. With ``n_rep == 1`` the input tensor is
    returned as is.
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis after the head axis, broadcast it, then fold it
    # into the head axis.
    repeated = hidden_states.unsqueeze(2).expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    merged_shape = (batch, num_key_value_heads * n_rep, slen, head_dim)
    return repeated.reshape(merged_shape)
152
+
153
+
154
class GemmaAttention(nn.Module):
    """Self-attention with rotary position embeddings and grouped KV heads.

    Queries use ``num_attention_heads`` heads, while keys and values use
    ``num_key_value_heads`` heads that are repeated to match the queries.

    Args:
        config: Model hyper-parameters.
        layer_idx: Index of the layer this module belongs to; passed to the
            KV cache so it can keep the entries of each layer apart.
    """

    def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        # How many query heads share one key/value head.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        assert self.hidden_size % self.num_heads == 0

        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.v_proj = nn.Linear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias
        )
        self.rotary_emb = GemmaRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        kv_cache: Optional[KVCache] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention over ``hidden_states``.

        Args:
            hidden_states: [Batch_Size, Seq_Len, Hidden_Size].
            attention_mask: Additive mask added to the attention scores
                before the softmax. Required, despite the ``None`` default.
            position_ids: Positions used for the rotary embeddings.
            kv_cache: Optional cache; when given, the new keys/values are
                handed to ``kv_cache.update`` and its return value is used
                for the attention (presumably the keys/values accumulated so
                far for this layer -- confirm against KVCache).
            **kwargs: Ignored.

        Returns:
            ``(attn_output, attn_weights)``: the projected attention output
            of shape [Batch_Size, Seq_Len, Hidden_Size] and the attention
            probabilities (after dropout) of shape
            [Batch_Size, Num_Heads_Q, Seq_Len_Q, Seq_Len_KV].

        Raises:
            ValueError: If the attention output does not have the expected
                [Batch_Size, Num_Heads_Q, Seq_Len_Q, Head_Dim] shape.
        """
        bsz, q_len, _ = hidden_states.size()  # [Batch_Size, Seq_Len, Hidden_Size]
        # [Batch_Size, Seq_Len, Num_Heads_Q * Head_Dim]
        query_states = self.q_proj(hidden_states)
        # [Batch_Size, Seq_Len, Num_Heads_KV * Head_Dim]
        key_states = self.k_proj(hidden_states)
        # [Batch_Size, Seq_Len, Num_Heads_KV * Head_Dim]
        value_states = self.v_proj(hidden_states)
        # [Batch_Size, Num_Heads_Q, Seq_Len, Head_Dim]
        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        # [Batch_Size, Num_Heads_KV, Seq_Len, Head_Dim]
        key_states = key_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        # [Batch_Size, Num_Heads_KV, Seq_Len, Head_Dim]
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        # value_states is only passed for its device and dtype.
        # [Batch_Size, Seq_Len, Head_Dim], [Batch_Size, Seq_Len, Head_Dim]
        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
        # [Batch_Size, Num_Heads_Q, Seq_Len, Head_Dim], [Batch_Size, Num_Heads_KV, Seq_Len, Head_Dim]
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

        # Keys are rotated before they are handed to the cache.
        if kv_cache is not None:
            key_states, value_states = kv_cache.update(
                key_states, value_states, self.layer_idx
            )

        # Repeat the key and values to match the number of heads of the query
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        # Perform the calculation as usual, Q * K^T / sqrt(head_dim). Shape: [Batch_Size, Num_Heads_Q, Seq_Len_Q, Seq_Len_KV]
        attn_weights = torch.matmul(
            query_states, key_states.transpose(2, 3)
        ) / math.sqrt(self.head_dim)

        assert attention_mask is not None
        attn_weights = attn_weights + attention_mask

        # Apply the softmax in float32, then go back to the working dtype
        # [Batch_Size, Num_Heads_Q, Seq_Len_Q, Seq_Len_KV]
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        # Apply the dropout
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.attention_dropout, training=self.training
        )
        # Multiply by the values. [Batch_Size, Num_Heads_Q, Seq_Len_Q, Seq_Len_KV] x [Batch_Size, Num_Heads_KV, Seq_Len_KV, Head_Dim] -> [Batch_Size, Num_Heads_Q, Seq_Len_Q, Head_Dim]
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )
        # Make sure the sequence length is the second dimension. # [Batch_Size, Num_Heads_Q, Seq_Len_Q, Head_Dim] -> [Batch_Size, Seq_Len_Q, Num_Heads_Q, Head_Dim]
        attn_output = attn_output.transpose(1, 2).contiguous()
        # Concatenate all the heads together. [Batch_Size, Seq_Len_Q, Num_Heads_Q, Head_Dim] -> [Batch_Size, Seq_Len_Q, Num_Heads_Q * Head_Dim]
        attn_output = attn_output.view(bsz, q_len, -1)
        # Multiply by W_o. [Batch_Size, Seq_Len_Q, Hidden_Size]
        attn_output = self.o_proj(attn_output)

        return attn_output, attn_weights
271
+
272
+
273
class GemmaDecoderLayer(nn.Module):
    """One pre-norm transformer block: self-attention followed by an MLP.

    Each sub-layer is preceded by an RMS norm and wrapped in a residual
    connection.

    Args:
        config: Model hyper-parameters.
        layer_idx: Index of this layer, forwarded to the attention module.
    """

    def __init__(self, config: GemmaConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx)

        self.mlp = GemmaMLP(config)
        self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = GemmaRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        kv_cache: Optional[KVCache] = None,
    ) -> torch.FloatTensor:
        """Apply the block to ``hidden_states``.

        Args:
            hidden_states: [Batch_Size, Seq_Len, Hidden_Size].
            attention_mask: Additive attention mask, forwarded unchanged.
            position_ids: Positions for the rotary embeddings, forwarded
                unchanged.
            kv_cache: Optional key/value cache, forwarded unchanged.

        Returns:
            The transformed hidden states, [Batch_Size, Seq_Len, Hidden_Size].
            The attention weights are discarded.
        """
        # --- Attention sub-layer ---
        residual = hidden_states
        # [Batch_Size, Seq_Len, Hidden_Size]
        hidden_states = self.input_layernorm(hidden_states)

        # [Batch_Size, Seq_Len, Hidden_Size]
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            kv_cache=kv_cache,
        )
        # [Batch_Size, Seq_Len, Hidden_Size]
        hidden_states = residual + hidden_states

        # --- Feed-forward sub-layer ---
        # [Batch_Size, Seq_Len, Hidden_Size]
        residual = hidden_states
        # [Batch_Size, Seq_Len, Hidden_Size]
        hidden_states = self.post_attention_layernorm(hidden_states)
        # [Batch_Size, Seq_Len, Hidden_Size]
        hidden_states = self.mlp(hidden_states)
        # [Batch_Size, Seq_Len, Hidden_Size]
        hidden_states = residual + hidden_states

        return hidden_states
323
+
324
+
325
class GemmaModel(nn.Module):
    """Stack of Gemma decoder layers operating on input embeddings.

    The module owns the token embedding table, but `forward` takes
    already-embedded inputs; callers fetch the table through
    `get_input_embeddings`.
    """

    def __init__(self, config: GemmaConfig):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=self.padding_idx
        )
        decoder_layers = [
            GemmaDecoderLayer(config, layer_idx)
            for layer_idx in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    # Ignore copy
    def forward(
        self,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        kv_cache: Optional[KVCache] = None,
    ) -> torch.FloatTensor:
        """Run every decoder layer and the final norm.

        ``inputs_embeds`` is [Batch_Size, Seq_Len, Hidden_Size] and is
        required despite its ``None`` default. The result has the same shape.
        """
        # The embeddings are scaled by sqrt(hidden_size); the factor is
        # created in the embeddings' dtype before the multiplication.
        scale = torch.tensor(self.config.hidden_size**0.5, dtype=inputs_embeds.dtype)
        # [Batch_Size, Seq_Len, Hidden_Size]
        states = inputs_embeds * scale

        for layer in self.layers:
            # [Batch_Size, Seq_Len, Hidden_Size]
            states = layer(
                states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                kv_cache=kv_cache,
            )

        # [Batch_Size, Seq_Len, Hidden_Size]
        return self.norm(states)
377
+
378
+
379
class GemmaForCausalLM(nn.Module):
    """Gemma decoder with a linear head producing vocabulary logits."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = GemmaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def get_input_embeddings(self):
        """Return the token embedding module of the underlying model."""
        return self.model.embed_tokens

    def tie_weights(self):
        """Make the output head share its weight with the token embeddings."""
        self.lm_head.weight = self.model.embed_tokens.weight

    def forward(
        self,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        kv_cache: Optional[KVCache] = None,
    ) -> dict:
        """Compute the logits for the given input embeddings.

        Args:
            attention_mask: Additive attention mask, forwarded to the model.
            position_ids: Positions for the rotary embeddings.
            inputs_embeds: [Batch_Size, Seq_Len, Hidden_Size].
            kv_cache: Optional key/value cache, forwarded to the model.

        Returns:
            A dict with key ``"logits"`` (float32 logits over the vocabulary)
            and, only when ``kv_cache`` was given, key ``"kv_cache"`` holding
            that same cache object.
        """
        # input_embeds: [Batch_Size, Seq_Len, Hidden_Size]
        # hidden_states: [Batch_Size, Seq_Len, Hidden_Size]
        hidden_states = self.model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            kv_cache=kv_cache,
        )

        # Logits are always returned in float32, whatever the model dtype.
        logits = self.lm_head(hidden_states).float()

        return_data = {
            "logits": logits,
        }

        if kv_cache is not None:
            # Return the updated cache
            return_data["kv_cache"] = kv_cache

        return return_data
src/model/modules/imagecraft.py ADDED
@@ -0,0 +1,490 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import Namespace
2
+ import glob
3
+ import logging
4
+ from pathlib import Path
5
+ import os
6
+ import time
7
+ from typing import Optional, Tuple
8
+ from PIL import Image
9
+ from safetensors import safe_open
10
+ import torch
11
+ from torch import nn
12
+ import torchaudio
13
+ from src.model.modules import voicecraft
14
+ from src.model.modules.gemma import GemmaForCausalLM, KVCache
15
+ from src.model.modules.imagecraftconfig import ImageCraftConfig
16
+ from src.model.modules.imagecraftprocessor import (
17
+ ImageCraftProcessor,
18
+ )
19
+ from src.model.modules.siglip import SiglipVisionModel
20
+
21
+ from transformers import AutoTokenizer
22
+
23
+ from src.model.modules.tokenizer import (
24
+ AudioTokenizer,
25
+ TextTokenizer,
26
+ tokenize_audio,
27
+ tokenize_text,
28
+ )
29
+
30
+
31
+ from src.utils import tools
32
+ from src.utils.image_utils import is_valid_image
33
+ from src.utils.model_utils import get_config, get_model_inputs
34
+ from src.utils.util import (
35
+ replace_numbers_with_words,
36
+ sample_top_p,
37
+ save_to_buffer,
38
+ save_to_file,
39
+ split_line_to_sentences,
40
+ )
41
+
42
+ from huggingface_hub import HfApi
43
+
44
+ logger = logging.getLogger(__name__)
45
+
46
+
47
class ImageCraftMultiModalProjector(nn.Module):
    """Linear projection applied to the vision tower's image features.

    Maps the last dimension from ``vision_config.hidden_size`` to
    ``vision_config.projection_dim``.
    """

    def __init__(self, config: ImageCraftConfig):
        super().__init__()
        vision_config = config.vision_config
        in_features = vision_config.hidden_size
        out_features = vision_config.projection_dim
        self.linear = nn.Linear(in_features, out_features, bias=True)

    def forward(self, image_features):
        """Project ``image_features`` along their last dimension."""
        return self.linear(image_features)
59
+
60
+
61
+ class ImageCraft(nn.Module):
62
+ config_class = ImageCraftConfig
63
+
64
+ def __init__(self, config: ImageCraftConfig):
65
+ super(ImageCraft, self).__init__()
66
+ self.config = config
67
+ self.vision_tower = SiglipVisionModel(config.vision_config)
68
+ self.multi_modal_projector = ImageCraftMultiModalProjector(config)
69
+ self.vocab_size = config.text_config.vocab_size
70
+
71
+ self.language_model = GemmaForCausalLM(config.text_config)
72
+
73
+ self.pad_token_id = (
74
+ self.config.pad_token_id if self.config.pad_token_id is not None else -1
75
+ )
76
+
77
+ tokenizer = AutoTokenizer.from_pretrained(
78
+ "google/paligemma-3b-pt-224", padding_side="right"
79
+ )
80
+ assert tokenizer.padding_side == "right"
81
+
82
+ num_image_tokens = config.vision_config.num_image_tokens
83
+ image_size = config.vision_config.image_size
84
+ self.processor = ImageCraftProcessor(tokenizer, num_image_tokens, image_size)
85
+
86
+ self.text_tokenizer = None
87
+
88
+ self.voicecraft_model = None
89
+ self.audio_tokenizer = None
90
+
91
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
92
+
93
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.tie_weights with Llava->PaliGemma
94
+ def tie_weights(self):
95
+ return self.language_model.tie_weights()
96
+
97
+ def forward(
98
+ self,
99
+ input_ids: torch.LongTensor = None,
100
+ pixel_values: torch.FloatTensor = None,
101
+ attention_mask: Optional[torch.Tensor] = None,
102
+ labels: Optional[torch.LongTensor] = None,
103
+ kv_cache: Optional[KVCache] = None,
104
+ ) -> Tuple:
105
+ # Make sure the input is right-padded
106
+ assert torch.all(attention_mask == 1), "The input cannot be padded"
107
+
108
+ # 1. Extra the input embeddings
109
+ # shape: (Batch_Size, Seq_Len, Hidden_Size)
110
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
111
+
112
+ # 2. Merge text and images
113
+ # [Batch_Size, Channels, Height, Width] -> [Batch_Size, Num_Patches, Embed_Dim]
114
+ selected_image_feature = self.vision_tower(pixel_values.to(inputs_embeds.dtype))
115
+ # [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Hidden_Size]
116
+ image_features = self.multi_modal_projector(selected_image_feature)
117
+
118
+ # Merge the embeddings of the text tokens and the image tokens
119
+ inputs_embeds, attention_mask, position_ids = (
120
+ self._merge_input_ids_with_image_features(
121
+ image_features, inputs_embeds, input_ids, attention_mask, kv_cache
122
+ )
123
+ )
124
+
125
+ outputs = self.language_model(
126
+ attention_mask=attention_mask,
127
+ position_ids=position_ids,
128
+ inputs_embeds=inputs_embeds,
129
+ kv_cache=kv_cache,
130
+ )
131
+
132
+ return outputs
133
+
134
+ def _merge_input_ids_with_image_features(
135
+ self,
136
+ image_features: torch.Tensor,
137
+ inputs_embeds: torch.Tensor,
138
+ input_ids: torch.Tensor,
139
+ attention_mask: torch.Tensor,
140
+ kv_cache: Optional[KVCache] = None,
141
+ ):
142
+ _, _, embed_dim = image_features.shape
143
+ batch_size, sequence_length = input_ids.shape
144
+ dtype, device = inputs_embeds.dtype, inputs_embeds.device
145
+ # Shape: [Batch_Size, Seq_Len, Hidden_Size]
146
+ scaled_image_features = image_features / (self.config.hidden_size**0.5)
147
+
148
+ # Combine the embeddings of the image tokens, the text tokens and mask out all the padding tokens.
149
+ final_embedding = torch.zeros(
150
+ batch_size,
151
+ sequence_length,
152
+ embed_dim,
153
+ dtype=inputs_embeds.dtype,
154
+ device=inputs_embeds.device,
155
+ )
156
+ # Shape: [Batch_Size, Seq_Len]. True for text tokens
157
+ text_mask = (input_ids != self.config.image_token_index) & (
158
+ input_ids != self.pad_token_id
159
+ )
160
+ # Shape: [Batch_Size, Seq_Len]. True for image tokens
161
+ image_mask = input_ids == self.config.image_token_index
162
+ # Shape: [Batch_Size, Seq_Len]. True for padding tokens
163
+ pad_mask = input_ids == self.pad_token_id
164
+
165
+ # We need to expand the masks to the embedding dimension otherwise we can't use them in torch.where
166
+ text_mask_expanded = text_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
167
+ pad_mask_expanded = pad_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
168
+ image_mask_expanded = image_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
169
+
170
+ # Add the text embeddings
171
+ final_embedding = torch.where(
172
+ text_mask_expanded, inputs_embeds, final_embedding
173
+ )
174
+ # Insert image embeddings. We can't use torch.where because the sequence length of scaled_image_features is not equal to the sequence length of the final embedding
175
+ final_embedding = final_embedding.masked_scatter(
176
+ image_mask_expanded, scaled_image_features
177
+ )
178
+ # Zero out padding tokens
179
+ final_embedding = torch.where(
180
+ pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding
181
+ )
182
+
183
+ #### CREATE THE ATTENTION MASK ####
184
+
185
+ dtype, device = inputs_embeds.dtype, inputs_embeds.device
186
+ min_dtype = torch.finfo(dtype).min
187
+ q_len = inputs_embeds.shape[1]
188
+
189
+ if kv_cache is None or kv_cache.num_items() == 0:
190
+ # Do not mask any token, because we're in the prefill phase
191
+ # This only works when we have no padding
192
+ causal_mask = torch.full(
193
+ (batch_size, q_len, q_len), fill_value=0, dtype=dtype, device=device
194
+ )
195
+ else:
196
+ # Since we are generating tokens, the query must be one single token
197
+ assert q_len == 1
198
+ kv_len = kv_cache.num_items() + q_len
199
+ # Also in this case we don't need to mask anything, since each query should be able to attend all previous tokens.
200
+ # This only works when we have no padding
201
+ causal_mask = torch.full(
202
+ (batch_size, q_len, kv_len), fill_value=0, dtype=dtype, device=device
203
+ )
204
+
205
+ # Add the head dimension
206
+ # [Batch_Size, Q_Len, KV_Len] -> [Batch_Size, Num_Heads_Q, Q_Len, KV_Len]
207
+ causal_mask = causal_mask.unsqueeze(1)
208
+
209
+ if kv_cache is not None and kv_cache.num_items() > 0:
210
+ # The position of the query is just the last position
211
+ position_ids = attention_mask.cumsum(-1)[:, -1]
212
+ if position_ids.dim() == 1:
213
+ position_ids = position_ids.unsqueeze(0)
214
+ else:
215
+ # Create a position_ids based on the size of the attention_mask
216
+ # For masked tokens, use the number 1 as position.
217
+ position_ids = (
218
+ (attention_mask.cumsum(-1))
219
+ .masked_fill_((attention_mask == 0), 1)
220
+ .to(device)
221
+ )
222
+
223
+ return final_embedding, causal_mask, position_ids
224
+
225
+ def _generate_caption(self, image, max_tokens=100, do_sample=False):
226
+ prompt = "caption en"
227
+ image = (
228
+ image.convert("RGB")
229
+ if is_valid_image(image)
230
+ else Image.open(image).convert("RGB")
231
+ )
232
+
233
+ inputs = get_model_inputs(
234
+ processor=self.processor, prompt=prompt, image=image, device=self.device
235
+ )
236
+
237
+ image.close()
238
+
239
+ input_ids = inputs["input_ids"]
240
+ attention_mask = inputs["attention_mask"]
241
+ pixel_values = inputs["pixel_values"]
242
+
243
+ kv_cache = KVCache()
244
+
245
+ stop_token = self.processor.tokenizer.eos_token_id
246
+ generated_tokens = []
247
+
248
+ for _ in range(max_tokens):
249
+ outputs = self(
250
+ input_ids=input_ids,
251
+ pixel_values=pixel_values,
252
+ attention_mask=attention_mask,
253
+ kv_cache=kv_cache,
254
+ )
255
+ kv_cache = outputs["kv_cache"]
256
+ next_token_logits = outputs["logits"][:, -1, :]
257
+ if do_sample:
258
+ next_token_logits = torch.softmax(
259
+ next_token_logits / self.config.temperature, dim=-1
260
+ )
261
+ next_token = sample_top_p(next_token_logits, self.config.top_p)
262
+ else:
263
+ next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
264
+ assert next_token.size() == (1, 1)
265
+ next_token = next_token.squeeze(0)
266
+ generated_tokens.append(next_token)
267
+ if next_token.item() == stop_token:
268
+ break
269
+ input_ids = next_token.unsqueeze(-1)
270
+ attention_mask = torch.cat(
271
+ [attention_mask, torch.ones((1, 1), device=input_ids.device)], dim=-1
272
+ )
273
+
274
+ generated_tokens = torch.cat(generated_tokens, dim=-1)
275
+ decoded_text = self.processor.tokenizer.decode(
276
+ generated_tokens, skip_special_tokens=True
277
+ )
278
+ decoded_text = (
279
+ parts[1] if len(parts := decoded_text.split("\n", 1)) > 1 else decoded_text
280
+ )
281
+
282
+ return decoded_text.rstrip(" .").strip().capitalize() + "."
283
+
284
def _generate_speech(self, text: str, output_type="file"):
    """Synthesize speech for ``text`` with the VoiceCraft model.

    The text is split into sentences. Each sentence is appended to a running
    transcript (seeded with the transcript of the reference voice recording),
    phonemized, and passed to VoiceCraft together with the encoded voice
    prompt. One audio tensor is collected per sentence.

    Args:
        text: Text to be spoken.
        output_type: ``"file"`` routes the audio through ``save_to_file``;
            any other value routes it through ``save_to_buffer``.

    Returns:
        The value returned by ``save_to_file`` / ``save_to_buffer``.
    """
    vc_config = self.config.voicecraft_config
    sentences = split_line_to_sentences(text)

    voice_audio = f"media/voicecraft/voices/{vc_config.voice_audio_path}"
    cut_off_sec = vc_config.cut_off_sec
    codec_audio_sr = vc_config.codec_audio_sr
    sample_batch_size = vc_config.sample_batch_size

    silence_tokens = vc_config.silence_tokens
    if isinstance(silence_tokens, str):
        # NOTE(review): eval() on a configuration string executes arbitrary
        # code. Kept to preserve behavior; consider ast.literal_eval if the
        # value is always a plain list literal.
        silence_tokens = eval(silence_tokens)

    # Sampling arguments shared by the single and the batched inference call.
    tts_kwargs = {
        "top_k": vc_config.top_k,
        "top_p": vc_config.top_p,
        "temperature": vc_config.temperature,
        "stop_repetition": vc_config.stop_repetition,
        "kvcache": vc_config.kvcache,
        "silence_tokens": silence_tokens,
    }
    if sample_batch_size <= 1:
        run_tts = self.voicecraft_model.inference_tts
    else:
        run_tts = self.voicecraft_model.inference_tts_batch
        tts_kwargs["batch_size"] = sample_batch_size

    # Only the first ``cut_off_sec`` seconds of the reference voice are used.
    info = torchaudio.info(voice_audio)
    audio_dur = info.num_frames / info.sample_rate
    prompt_end_frame = int(min(audio_dur, cut_off_sec) * info.sample_rate)

    # The voice prompt does not depend on the sentence, so encode it once.
    encoded_frames = tokenize_audio(
        self.audio_tokenizer,
        voice_audio,
        offset=0,
        num_frames=prompt_end_frame,
    )
    original_audio = encoded_frames[0][0].transpose(2, 1)  # [1,T,K]
    n_codebooks = self.voicecraft_model.args.n_codebooks
    assert (
        original_audio.ndim == 3
        and original_audio.shape[0] == 1
        and original_audio.shape[2] == n_codebooks
    ), original_audio.shape
    prompt_codes = original_audio[..., :n_codebooks].to(self.device)

    phn2num = self.voicecraft_model.args.phn2num
    audio_tensors = []
    transcript = vc_config.voice_audio_transcript

    for sentence in sentences:
        # NOTE(review): the transcript keeps growing across iterations, so
        # later sentences are synthesized together with the earlier ones.
        # Confirm this is intended.
        transcript += sentence + "\n"
        # Collapse double spaces (the previous single-space -> single-space
        # replacement had no effect).
        transcript = replace_numbers_with_words(transcript).replace("  ", " ")

        # phonemize; phonemes missing from the model vocabulary are skipped
        text_tokens = [
            phn2num[phn]
            for phn in tokenize_text(self.text_tokenizer, text=transcript.strip())
            if phn in phn2num
        ]
        text_tokens = torch.LongTensor(text_tokens).unsqueeze(0)
        text_tokens_lens = torch.LongTensor([text_tokens.shape[-1]])

        # forward
        _, gen_frames = run_tts(
            text_tokens.to(self.device),
            text_tokens_lens.to(self.device),
            prompt_codes,
            **tts_kwargs,
        )  # output is [1,K,T]

        gen_sample = self.audio_tokenizer.decode([(gen_frames, None)])
        audio_tensors.append(gen_sample[0].cpu())

    if output_type == "file":
        output = save_to_file(audio_tensors, codec_audio_sr)
    else:
        output = save_to_buffer(audio_tensors, codec_audio_sr)

    # Empty cuda cache between runs
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return output
400
+
401
@torch.inference_mode()
def generate(
    self,
    image,
    max_tokens=30,
    do_sample=False,
    output_type="file",
    return_output="speech",
):
    """Caption ``image`` and optionally turn the caption into speech.

    Args:
        image: Image input forwarded to ``_generate_caption``.
        max_tokens: Forwarded to ``_generate_caption``.
        do_sample: Forwarded to ``_generate_caption``.
        output_type: Forwarded to ``_generate_speech``.
        return_output: ``"speech"`` or ``None`` also generates speech;
            any other value returns the caption only.

    Returns:
        ``(transcript, speech)`` when speech is requested, otherwise the
        transcript alone.
    """
    caption = self._generate_caption(image, max_tokens, do_sample)

    wants_speech = return_output == "speech" or return_output is None
    if not wants_speech:
        return caption

    audio = self._generate_speech(caption, output_type)
    return caption, audio
417
+
418
@classmethod
def from_pretrained(
    cls,
    model_path=None,
):
    """Build a ready-to-use model from a checkpoint file or a Hub repository.

    Args:
        model_path: Path to a checkpoint file containing a ``"state_dict"``
            entry, or a Hugging Face model repo id whose ``*.safetensors``
            files hold the weights.

    Returns:
        The model in eval mode with ``voicecraft_model``, ``audio_tokenizer``
        and ``text_tokenizer`` attached.

    Raises:
        TypeError: If ``model_path`` is ``None``.
    """
    # Local imports: only needed for the one-off encodec download below.
    import shutil
    import urllib.request

    if model_path is None:
        raise TypeError(
            "model_path must be a checkpoint file path or a Hugging Face repo id"
        )

    api = HfApi()

    device = "cuda" if torch.cuda.is_available() else "cpu"

    env_config = tools.load_config()
    pretrained_dir = env_config["pretrained_dir"]
    imagecraft_cache_dir = f"{pretrained_dir}/imagecraft"
    voicecraft_cache_dir = f"{pretrained_dir}/voicecraft"

    state_dict = {}

    if Path(model_path).is_file():
        # Load onto the CPU first so a checkpoint saved on a GPU can be
        # opened on a CPU-only machine; load_state_dict copies the values
        # into the model's own parameters afterwards.
        checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
        state_dict = checkpoint["state_dict"]
    else:
        model_path = api.snapshot_download(
            repo_id=model_path,
            repo_type="model",
            cache_dir=imagecraft_cache_dir,
            local_files_only=False,
        )

        for safetensors_file in glob.glob(os.path.join(model_path, "*.safetensors")):
            with safe_open(safetensors_file, framework="pt", device="cpu") as f:
                for key in f.keys():
                    state_dict[key] = f.get_tensor(key)

    imagecraft_config = get_config()

    model = cls(imagecraft_config).to(device)

    # Load the state dict of the model
    model.load_state_dict(state_dict, strict=False)

    # Tie weights
    model.tie_weights()

    model = model.eval()

    # Load voicecraft module
    voicecraft_config = model.config.voicecraft_config
    model.voicecraft_model = voicecraft.VoiceCraft.from_pretrained(
        f"pyp1/VoiceCraft_{voicecraft_config.model_name.replace('.pth', '')}",
        cache_dir=voicecraft_cache_dir,
    )

    encodec_fn = f"{voicecraft_cache_dir}/{voicecraft_config.encodec}"

    if not os.path.exists(encodec_fn):
        # Download with the standard library instead of shelling out to
        # wget/mv: failures raise instead of being ignored, and no external
        # binaries are required.
        os.makedirs(os.path.dirname(encodec_fn) or ".", exist_ok=True)
        encodec_url = (
            "https://huggingface.co/pyp1/VoiceCraft/resolve/main/"
            f"{voicecraft_config.encodec}"
        )
        partial_fn = f"{encodec_fn}.part"
        with urllib.request.urlopen(encodec_url) as response, open(
            partial_fn, "wb"
        ) as target:
            shutil.copyfileobj(response, target)
        # Move into place only after the download completed.
        os.replace(partial_fn, encodec_fn)

    model.audio_tokenizer = AudioTokenizer(
        signature=encodec_fn,
        device=device,
    )

    model.text_tokenizer = TextTokenizer(backend="espeak")

    model.voicecraft_model.to(device)

    return model
src/model/modules/imagecraftconfig.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from src.model.modules.gemma import GemmaConfig
2
+ # from src.model.modules.siglip import SiglipVisionConfig
3
+ from src.model.modules.voicecraftconfig import VoiceCraftConfig
4
+
5
+ from transformers import SiglipVisionConfig, GemmaConfig, PretrainedConfig
6
+
7
+
8
class ImageCraftConfig(PretrainedConfig):
    """Configuration bundling the vision, text and VoiceCraft sub-configs.

    Args:
        vision_config: Keyword arguments for ``SiglipVisionConfig``
            (``None`` means library defaults).
        text_config: Keyword arguments for ``GemmaConfig``
            (``None`` means library defaults).
        voicecraft_config: Keyword arguments for ``VoiceCraftConfig``
            (``None`` means no arguments).
        ignore_index: Stored as ``ignore_index``.
        image_token_index: Stored as ``image_token_index``.
        vocab_size: Accepted for interface compatibility; the stored
            ``vocab_size`` is taken from the text config.
        projection_dim: Stored here and on the vision config.
        hidden_size: Stored as ``hidden_size``.
        pad_token_id: Padding token id; ``-1`` is stored when ``None``.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    model_type = "imagecraft"

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        voicecraft_config=None,
        ignore_index=-100,
        image_token_index=256000,
        vocab_size=257152,
        projection_dim=2048,
        hidden_size=2048,
        pad_token_id=None,
        **kwargs
    ):
        super().__init__()
        self.ignore_index = ignore_index
        self.image_token_index = image_token_index
        self.vocab_size = vocab_size
        self.projection_dim = projection_dim
        self.hidden_size = hidden_size
        self.is_encoder_decoder = False

        self.pad_token_id = pad_token_id if pad_token_id is not None else -1

        # Unpacking ``None`` raises TypeError, so fall back to empty mappings.
        self.vision_config = SiglipVisionConfig(**(vision_config or {}))

        # Merge into one mapping so a ``pad_token_id`` key already present in
        # ``text_config`` does not collide with the explicit argument.
        text_kwargs = {**(text_config or {}), "pad_token_id": pad_token_id}
        self.text_config = GemmaConfig(**text_kwargs)
        self.vocab_size = self.text_config.vocab_size

        # One image token per patch of the (square) input image.
        self.text_config.num_image_tokens = (
            self.vision_config.image_size // self.vision_config.patch_size
        ) ** 2
        self.vision_config.projection_dim = projection_dim

        self.voicecraft_config = VoiceCraftConfig(**(voicecraft_config or {}))

        # PretrainedConfig.__init__ assigns ``pad_token_id`` from its keyword
        # arguments; forward the resolved value so the fallback chosen above
        # is not reset by this call.
        super().__init__(pad_token_id=self.pad_token_id, **kwargs)
src/model/modules/imagecraftprocessor.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+ from PIL import Image
3
+ import numpy as np
4
+ import torch
5
+
6
+
7
+ from src.utils.util import (
8
+ IMAGENET_STANDARD_MEAN,
9
+ IMAGENET_STANDARD_STD,
10
+ add_image_tokens_to_prompt,
11
+ process_images,
12
+ )
13
+
14
+ from transformers import SiglipImageProcessor
15
+
16
+
17
class ImageCraftProcessor:
    """Prepare one image/prompt pair as model inputs.

    The constructor extends the tokenizer with the image placeholder token and
    the extra location/segmentation tokens; calling the instance returns the
    pixel values together with the tokenized prompt.
    """

    IMAGE_TOKEN = "<image>"

    def __init__(self, tokenizer, num_image_tokens: int, image_size: int):
        super().__init__()

        self.image_seq_length = num_image_tokens
        self.image_size = image_size

        # Tokenizer described here: https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md#tokenizer
        tokenizer.add_special_tokens(
            {"additional_special_tokens": [self.IMAGE_TOKEN]}
        )
        # Tokens used for object detection (bounding boxes) ...
        location_tokens = [f"<loc{i:04d}>" for i in range(1024)]
        # ... and for object segmentation.
        segmentation_tokens = [f"<seg{i:03d}>" for i in range(128)]
        tokenizer.add_tokens(location_tokens + segmentation_tokens)
        self.image_token_id = tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)

        # BOS and EOS tokens are added explicitly when the prompt is built.
        tokenizer.add_bos_token = False
        tokenizer.add_eos_token = False

        self.tokenizer = tokenizer

    def __call__(
        self,
        text: List[str],
        images: List[Image.Image],
        padding: str = "longest",
        truncation: bool = True,
    ) -> dict:
        """Return ``pixel_values`` plus the tokenizer outputs for one pair.

        Exactly one image and one prompt are accepted.
        """
        assert (
            len(images) == 1 and len(text) == 1
        ), f"Received {len(images)} images for {len(text)} prompts."

        processed = process_images(
            images,
            size=(self.image_size, self.image_size),
            resample=Image.Resampling.BICUBIC,
            rescale_factor=1 / 255.0,
            image_mean=IMAGENET_STANDARD_MEAN,
            image_std=IMAGENET_STANDARD_STD,
        )
        # Stack the per-image arrays along a new leading axis, then convert
        # the result to a PyTorch tensor.
        pixel_values = torch.tensor(np.stack(processed, axis=0))

        prompts = []
        for prompt in text:
            prompts.append(
                add_image_tokens_to_prompt(
                    prefix_prompt=prompt,
                    bos_token=self.tokenizer.bos_token,
                    image_seq_length=self.image_seq_length,
                    image_token=self.IMAGE_TOKEN,
                )
            )

        tokenized = self.tokenizer(
            prompts,
            return_tensors="pt",
            padding=padding,
            max_length=512,
            truncation=truncation,
        )

        return {"pixel_values": pixel_values, **tokenized}
src/model/modules/kv_cache.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple
2
+ import torch
3
+
4
class KVCache:
    """Per-layer store of attention keys and values that grows on every update.

    Tensors are concatenated along their second-to-last dimension; the original
    author documents the layout as [Batch_Size, Num_Heads_KV, Seq_Len, Head_Dim].
    """

    def __init__(self) -> None:
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []

    def num_items(self) -> int:
        """Return the cached length (size of dim -2 of layer 0), or 0 if empty."""
        if not self.key_cache:
            return 0
        return self.key_cache[0].shape[-2]

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Add new keys/values for ``layer_idx`` and return the full cache for it."""
        if layer_idx < len(self.key_cache):
            # The layer already has entries: extend them along dim -2.
            self.key_cache[layer_idx] = torch.cat(
                (self.key_cache[layer_idx], key_states), dim=-2
            )
            self.value_cache[layer_idx] = torch.cat(
                (self.value_cache[layer_idx], value_states), dim=-2
            )
        else:
            # First time this layer is seen: start its cache.
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]
src/model/modules/sampling.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # cp from https://github.com/jasonppy/VoiceCraft/blob/master/models/modules/sampling.py
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
def top_k_top_p_filtering(
    logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
):
    """Mask logits with top-k and/or nucleus (top-p) filtering.

    ``logits`` is modified in place and also returned.

    Args:
        logits: Logits of shape (batch size, vocabulary size).
        top_k: If > 0, keep only the ``top_k`` highest logits.
        top_p: If < 1.0, keep the smallest set of highest-probability tokens
            whose cumulative probability exceeds ``top_p`` (nucleus filtering,
            Holtzman et al., http://arxiv.org/abs/1904.09751).
        filter_value: Value written into removed positions.
        min_tokens_to_keep: Lower bound on the number of tokens kept per row.

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        # Safety check: stay within [min_tokens_to_keep, vocabulary size].
        keep_count = min(max(top_k, min_tokens_to_keep), logits.size(-1))
        # Everything below the k-th largest logit of a row is removed.
        kth_largest = torch.topk(logits, keep_count)[0][..., -1, None]
        logits.masked_fill_(logits < kth_largest, filter_value)

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1)

        # Candidates for removal: cumulative probability above the threshold.
        drop_sorted = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            drop_sorted[..., :min_tokens_to_keep] = 0
        # Shift right so the first token crossing the threshold is kept too,
        # and never drop the most likely token.
        drop_sorted[..., 1:] = drop_sorted[..., :-1].clone()
        drop_sorted[..., 0] = 0

        # Map the mask from sorted order back to vocabulary order.
        drop_mask = drop_sorted.scatter(1, sorted_indices, drop_sorted)
        logits.masked_fill_(drop_mask, filter_value)
    return logits
49
+
50
def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
    """Sample one token id per row after temperature and top-k/top-p filtering.

    Args:
        logits: Logits passed on to ``top_k_top_p_filtering``.
        top_k: Number of highest-probability tokens kept by top-k filtering.
        top_p: Cumulative-probability threshold for nucleus filtering.
        temperature: Divisor applied to the logits when different from 1.0;
            higher values make low-probability tokens more likely.

    Returns:
        The result of ``torch.multinomial`` with ``num_samples=1``.
    """
    # Dividing creates a new tensor; with temperature 1.0 the caller's tensor
    # is handed to the filter unchanged.
    scaled = logits if temperature == 1.0 else logits / temperature
    filtered = top_k_top_p_filtering(scaled, top_k=top_k, top_p=top_p)
    probabilities = F.softmax(filtered, dim=-1)
    return torch.multinomial(probabilities, num_samples=1)
src/model/modules/scaling.py ADDED
@@ -0,0 +1,1391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/scaling.py
2
+ # Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+
19
+ import collections
20
+ import logging
21
+ import random
22
+ import math
23
+ from functools import reduce
24
+ from itertools import repeat
25
+ from typing import Optional, Tuple, Union
26
+
27
+ import torch
28
+ import torch.nn as nn
29
+ import torch.nn.functional as F
30
+ from torch import Tensor
31
+ from torch.nn import Embedding as ScaledEmbedding
32
+
33
+ # from valle.utils import Transpose
34
+
35
class Transpose(nn.Identity):
    """Module that swaps dimensions 1 and 2: (N, T, D) -> (N, D, T)."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return torch.transpose(input, 1, 2)
40
+
41
class ActivationBalancerFunction(torch.autograd.Function):
    """Identity in the forward pass; adjusts the gradient in the backward pass.

    The gradient is changed by ``|grad| * factor`` where ``factor`` is built
    from the per-channel ``scale_factor`` (applied with opposite sign to
    positive and non-positive inputs) and the optional ``sign_factor``.
    """

    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        scale_factor: Tensor,
        sign_factor: Optional[Tensor],
        channel_dim: int,
    ) -> Tensor:
        # Store the channel dimension as a non-negative index.
        ctx.channel_dim = channel_dim + x.ndim if channel_dim < 0 else channel_dim
        is_positive = x > 0
        if sign_factor is not None:
            ctx.save_for_backward(is_positive, scale_factor, sign_factor)
        else:
            ctx.save_for_backward(is_positive, scale_factor)
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
        saved = ctx.saved_tensors
        if len(saved) == 3:
            is_positive, scale_factor, sign_factor = saved
        else:
            is_positive, scale_factor = saved
            sign_factor = None

        # Append trailing singleton dims so the per-channel factors broadcast
        # against the dimensions that follow the channel dimension.
        for _ in range(ctx.channel_dim, x_grad.ndim - 1):
            scale_factor = scale_factor.unsqueeze(-1)
            if sign_factor is not None:
                sign_factor = sign_factor.unsqueeze(-1)

        factor = scale_factor * (is_positive.to(x_grad.dtype) - 0.5)
        if sign_factor is not None:
            factor = sign_factor + factor

        return x_grad - x_grad.abs() * factor, None, None, None
80
+
81
+
82
+ def _compute_scale_factor(
83
+ x: Tensor,
84
+ channel_dim: int,
85
+ min_abs: float,
86
+ max_abs: float,
87
+ gain_factor: float,
88
+ max_factor: float,
89
+ ) -> Tensor:
90
+ if channel_dim < 0:
91
+ channel_dim += x.ndim
92
+ sum_dims = [d for d in range(x.ndim) if d != channel_dim]
93
+ x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
94
+
95
+ if min_abs == 0.0:
96
+ below_threshold = 0.0
97
+ else:
98
+ # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
99
+ # x_abs)_mean , min_abs.
100
+ below_threshold = (
101
+ (min_abs - x_abs_mean) * (gain_factor / min_abs)
102
+ ).clamp(min=0, max=max_factor)
103
+
104
+ above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
105
+ min=0, max=max_factor
106
+ )
107
+
108
+ return below_threshold - above_threshold
109
+
110
+
111
+ def _compute_sign_factor(
112
+ x: Tensor,
113
+ channel_dim: int,
114
+ min_positive: float,
115
+ max_positive: float,
116
+ gain_factor: float,
117
+ max_factor: float,
118
+ ) -> Tensor:
119
+ if channel_dim < 0:
120
+ channel_dim += x.ndim
121
+ sum_dims = [d for d in range(x.ndim) if d != channel_dim]
122
+ proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims)
123
+ if min_positive == 0.0:
124
+ factor1 = 0.0
125
+ else:
126
+ # 0 if proportion_positive >= min_positive, else can be
127
+ # as large as max_factor.
128
+ factor1 = (
129
+ (min_positive - proportion_positive) * (gain_factor / min_positive)
130
+ ).clamp_(min=0, max=max_factor)
131
+
132
+ if max_positive == 1.0:
133
+ factor2 = 0.0
134
+ else:
135
+ # 0 if self.proportion_positive <= max_positive, else can be
136
+ # as large as -max_factor.
137
+ factor2 = (
138
+ (proportion_positive - max_positive)
139
+ * (gain_factor / (1.0 - max_positive))
140
+ ).clamp_(min=0, max=max_factor)
141
+ sign_factor = factor1 - factor2
142
+ # require min_positive != 0 or max_positive != 1:
143
+ assert not isinstance(sign_factor, float)
144
+ return sign_factor
145
+
146
+
147
class ActivationScaleBalancerFunction(torch.autograd.Function):
    """
    Identity in the forward pass; in the backward pass the gradient is
    changed by ``|grad| * (sign_factor + scale_factor * (x > 0) - 0.5 * scale_factor)``.

    According to the original note, this object is used in class
    ActivationBalancer when the user specified min_positive=0, max_positive=1,
    so only the absolute value of the activations is constrained.
    """

    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        sign_factor: Tensor,
        scale_factor: Tensor,
        channel_dim: int,
    ) -> Tensor:
        # Store the channel dimension as a non-negative index.
        ctx.channel_dim = channel_dim + x.ndim if channel_dim < 0 else channel_dim
        ctx.save_for_backward(x > 0, sign_factor, scale_factor)
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
        is_positive, sign_factor, scale_factor = ctx.saved_tensors

        # Append trailing singleton dims so the per-channel factors broadcast
        # against the dimensions that follow the channel dimension.
        for _ in range(ctx.channel_dim, x_grad.ndim - 1):
            sign_factor = sign_factor.unsqueeze(-1)
            scale_factor = scale_factor.unsqueeze(-1)

        centered_sign = is_positive.to(x_grad.dtype) - 0.5
        factor = sign_factor + scale_factor * centered_sign
        adjusted_grad = x_grad - x_grad.abs() * factor
        return adjusted_grad, None, None, None
184
+
185
+
186
class RandomClampFunction(torch.autograd.Function):
    """Clamp each element to [min, max] with probability ``prob``.

    With ``reflect != 0`` the output is ``ans * (1 + reflect) - x * reflect``,
    where ``ans`` is the randomly clamped tensor. In the backward pass the
    gradient is passed only where the clamped result equals the input, with
    the same ``reflect`` combination applied to the gradient.
    """

    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        min: Optional[float],
        max: Optional[float],
        prob: float,
        reflect: float,
    ) -> Tensor:
        clamped = torch.clamp(x, min=min, max=max)
        use_clamped = torch.rand_like(x) < prob
        result = torch.where(use_clamped, clamped, x)
        if x.requires_grad:
            # Remember which elements were left untouched.
            ctx.save_for_backward(result == x)
            ctx.reflect = reflect
        if reflect == 0.0:
            return result
        return result * (1.0 + reflect) - (x * reflect)

    @staticmethod
    def backward(
        ctx, ans_grad: Tensor
    ) -> Tuple[Tensor, None, None, None, None]:
        (unchanged,) = ctx.saved_tensors
        grad = ans_grad * unchanged.to(ans_grad.dtype)
        if ctx.reflect != 0.0:
            grad = grad * (1.0 + ctx.reflect) - (ans_grad * ctx.reflect)
        return grad, None, None, None, None
216
+
217
+
218
def random_clamp(
    x: Tensor,
    min: Optional[float] = None,
    max: Optional[float] = None,
    prob: float = 0.5,
    reflect: float = 0.0,
):
    """Functional entry point for ``RandomClampFunction``.

    All arguments are passed through positionally in the order
    ``x, min, max, prob, reflect``.
    """
    clamp_args = (x, min, max, prob, reflect)
    return RandomClampFunction.apply(*clamp_args)
226
+
227
+
228
def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
    """
    Cast ``x`` to float16, rounding very small values randomly.

    Elements with magnitude below ``min_abs`` become ``+-min_abs`` with
    probability ``|x| / min_abs`` and 0.0 otherwise, which preserves their
    expectation. A float16 input is returned unchanged.
    """
    if x.dtype == torch.float16:
        return x
    magnitude = x.abs()
    # True with probability |x| / min_abs for the small elements.
    round_up = torch.rand_like(x) * min_abs < magnitude
    small_value = min_abs * x.sign() * round_up
    mixed = torch.where(magnitude < min_abs, small_value, x)
    return mixed.to(torch.float16)
241
+
242
+
243
class RandomGradFunction(torch.autograd.Function):
    """
    Identity in the forward pass. In the backward pass, float16 gradients are
    converted to float32 and passed through ``random_cast_to_half`` with the
    stored ``min_abs``; gradients of any other dtype are returned unchanged.
    """

    @staticmethod
    def forward(ctx, x: Tensor, min_abs: float) -> Tensor:
        ctx.min_abs = min_abs
        return x

    @staticmethod
    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]:
        if ans_grad.dtype != torch.float16:
            return ans_grad, None
        rounded = random_cast_to_half(
            ans_grad.to(torch.float32), min_abs=ctx.min_abs
        )
        return rounded, None
265
+
266
+
267
class RandomGrad(torch.nn.Module):
    """
    Module wrapper around ``RandomGradFunction``.

    The input is returned untouched when scripting, tracing, or not in
    training mode; otherwise it is passed through ``RandomGradFunction`` with
    ``min_abs``. The original note states the purpose as improving training
    accuracy under amp (automatic mixed precision).
    """

    def __init__(self, min_abs: float = 5.0e-06):
        super().__init__()
        self.min_abs = min_abs

    def forward(self, x: Tensor):
        if (
            torch.jit.is_scripting()
            or not self.training
            or torch.jit.is_tracing()
        ):
            return x
        return RandomGradFunction.apply(x, self.min_abs)
286
+
287
+
288
class SoftmaxFunction(torch.autograd.Function):
    """
    Softmax with a hand-written backward pass computed in float32.

    The original note states the intent: handle half-precision derivatives
    in a way that should be more accurate for training than the default
    behavior.
    """

    @staticmethod
    def forward(ctx, x: Tensor, dim: int):
        ans = x.softmax(dim=dim)
        # if x dtype is float16, x.softmax() returns a float32 because
        # (presumably) that op does not support float16, and autocast
        # is enabled.
        if torch.is_autocast_enabled():
            ans = ans.to(torch.float16)
        ctx.save_for_backward(ans)
        ctx.x_dtype = x.dtype
        ctx.dim = dim
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        (ans,) = ctx.saved_tensors
        # torch.cuda.amp.autocast is deprecated; torch.autocast with
        # device_type="cuda" is the equivalent supported spelling.
        with torch.autocast(device_type="cuda", enabled=False):
            ans_grad = ans_grad.to(torch.float32)
            ans = ans.to(torch.float32)
            # Softmax gradient: y * (g - sum(y * g)) along ``dim``.
            x_grad = ans_grad * ans
            x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
            return x_grad, None
316
+
317
+
318
def softmax(x: Tensor, dim: int):
    """Softmax over ``dim``.

    Uses the built-in softmax while scripting or tracing, and
    ``SoftmaxFunction`` otherwise.
    """
    if torch.jit.is_scripting() or torch.jit.is_tracing():
        return x.softmax(dim)
    else:
        return SoftmaxFunction.apply(x, dim)
323
+
324
+
325
class MaxEigLimiterFunction(torch.autograd.Function):
    """Identity in the forward pass; adds an extra gradient in the backward pass.

    The extra term is the gradient of the proportion of variance explained by
    ``coeffs * direction``, rescaled to ``grad_scale`` times the norm of the
    incoming gradient.
    """

    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        coeffs: Tensor,
        direction: Tensor,
        channel_dim: int,
        grad_scale: float,
    ) -> Tensor:
        ctx.channel_dim = channel_dim
        ctx.grad_scale = grad_scale
        ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach())
        return x

    @staticmethod
    def backward(ctx, x_grad, *args):
        # Autograd is re-enabled to differentiate the auxiliary objective.
        with torch.enable_grad():
            x_orig, coeffs, new_direction = ctx.saved_tensors
            x_orig.requires_grad = True
            num_channels = x_orig.shape[ctx.channel_dim]
            # Flatten to (-1, num_channels) with channels last.
            flat = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
            new_direction.requires_grad = False
            flat = flat - flat.mean(dim=0)
            total_var = (flat ** 2).mean()
            residual = flat - coeffs * new_direction
            residual_var = (residual ** 2).mean()
            # `variance_proportion` is the proportion of the variance accounted for
            # by the top eigen-direction. This is to be minimized.
            variance_proportion = (total_var - residual_var) / (total_var + 1.0e-20)
            variance_proportion.backward()
        aux_grad = x_orig.grad
        # Rescale the auxiliary gradient relative to the incoming gradient norm.
        extra_grad = (
            aux_grad
            * ctx.grad_scale
            * x_grad.norm()
            / (aux_grad.norm() + 1.0e-20)
        )
        return x_grad + extra_grad.detach(), None, None, None, None
364
+
365
+
366
class BasicNorm(torch.nn.Module):
    """
    A simpler, and hopefully cheaper, replacement for LayerNorm.

    Per the original rationale: Transformer-type networks, especially with
    pre-norm, sometimes set one feature dimension to a large constant value
    (e.g. 50), which "defeats" LayerNorm because the output magnitude then
    barely depends on the other (useful) features. This module makes that
    constant an explicit parameter taking the role of LayerNorm's "eps".

    The output is ``x * (mean(x**2, channel_dim) + exp(eps)) ** -0.5``; the
    value is stored as its logarithm.

    Args:
      num_channels: the number of channels, e.g. 512.
      channel_dim: the axis/dimension corresponding to the channel,
        interpreted as an offset from the input's ndim if negative.
        This is NOT the num_channels; it should typically be one of
        {-2, -1, 0, 1, 2, 3}.
      eps: the initial "epsilon" added as ballast. It is large compared to
        a conventional LayerNorm epsilon; the name is kept to show the
        connection.
      learn_eps: if true, epsilon is a learnable parameter; if false, it is
        a buffer kept at its initial value.
      eps_min: lower clamp (log scale) occasionally applied during training.
      eps_max: upper clamp (log scale) occasionally applied during training.
    """

    def __init__(
        self,
        num_channels: int,
        channel_dim: int = -1,  # CAUTION: see documentation.
        eps: float = 0.25,
        learn_eps: bool = True,
        eps_min: float = -3.0,
        eps_max: float = 3.0,
    ) -> None:
        super().__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        log_eps = torch.tensor(eps).log().detach()
        if learn_eps:
            self.eps = nn.Parameter(log_eps)
        else:
            self.register_buffer("eps", log_eps)
        self.eps_min = eps_min
        self.eps_max = eps_max

    def forward(self, x: Tensor) -> Tensor:
        assert x.shape[self.channel_dim] == self.num_channels
        log_eps = self.eps
        if self.training and random.random() < 0.25:
            # With probability 0.25, in training mode, clamp eps between the
            # min and max; this encourages it to stay within the allowed
            # range by making parameters outside that range noisy.
            log_eps = log_eps.clamp(min=self.eps_min, max=self.eps_max)
        mean_square = torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
        scales = (mean_square + log_eps.exp()) ** -0.5
        return x * scales
428
+
429
+
430
def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
    """
    Build an ``nn.Linear`` whose freshly initialised parameters are rescaled.

    Args:
        *args, **kwargs: forwarded unchanged to ``nn.Linear``
           (e.g. in_features, out_features, bias=False).
        initial_scale: factor applied to the default-initialised weight.
           If a bias exists it is re-drawn uniformly from
           [-0.1 * initial_scale, 0.1 * initial_scale].
           Another option, if you want to do something like this, is
           to re-initialize the parameters yourself.

    Returns:
        The constructed ``nn.Linear`` module.
    """
    layer = nn.Linear(*args, **kwargs)
    bias_bound = 0.1 * initial_scale
    with torch.no_grad():
        layer.weight.mul_(initial_scale)
        if layer.bias is not None:
            torch.nn.init.uniform_(layer.bias, -bias_bound, bias_bound)
    return layer
451
+
452
+
453
def ScaledConv1d(
    *args,
    initial_scale: float = 1.0,
    kernel_size: int = 3,
    padding: str = "same",
    **kwargs,
) -> nn.Conv1d:
    """
    Build an ``nn.Conv1d`` whose freshly initialised parameters are rescaled.

    Args:
        *args, **kwargs: forwarded unchanged to ``nn.Conv1d``
           (e.g. in_channels, out_channels, bias=False).
        initial_scale: factor applied to the default-initialised weight.
           If a bias exists it is re-drawn uniformly from
           [-0.1 * initial_scale, 0.1 * initial_scale].
        kernel_size: forwarded to ``nn.Conv1d``; defaults to 3.
        padding: forwarded to ``nn.Conv1d``; defaults to "same".

    Returns:
        The constructed ``nn.Conv1d`` module.
    """
    conv = nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs)
    bias_bound = 0.1 * initial_scale
    with torch.no_grad():
        conv.weight.mul_(initial_scale)
        if conv.bias is not None:
            torch.nn.init.uniform_(conv.bias, -bias_bound, bias_bound)
    return conv
480
+
481
+
482
def TransposeScaledConv1d(
    *args,
    initial_scale: float = 1.0,
    kernel_size: int = 3,
    padding: str = "same",
    **kwargs,
) -> nn.Sequential:
    """
    Transpose -> ScaledConv1d

    All arguments are forwarded to ``ScaledConv1d``; the returned
    ``nn.Sequential`` applies ``Transpose()`` first and the convolution second.
    """
    # Build the modules in the same order in which they are applied.
    transpose = Transpose()
    conv = ScaledConv1d(
        *args,
        initial_scale=initial_scale,
        kernel_size=kernel_size,
        padding=padding,
        **kwargs,
    )
    return nn.Sequential(transpose, conv)
502
+
503
+
504
def ScaledConv1dTranspose(
    *args,
    initial_scale: float = 1.0,
    kernel_size: int = 3,
    padding: str = "same",
    **kwargs,
) -> nn.Sequential:
    """
    ScaledConv1d -> Transpose

    All arguments are forwarded to ``ScaledConv1d``; the returned
    ``nn.Sequential`` applies the convolution first and ``Transpose()`` second.
    """
    # Build the modules in the same order in which they are applied.
    conv = ScaledConv1d(
        *args,
        initial_scale=initial_scale,
        kernel_size=kernel_size,
        padding=padding,
        **kwargs,
    )
    transpose = Transpose()
    return nn.Sequential(conv, transpose)
524
+
525
+
526
def TransposeConv1d(
    *args, kernel_size: int = 3, padding: str = "same", **kwargs
) -> nn.Sequential:
    """
    Transpose -> Conv1d

    All arguments are forwarded to ``nn.Conv1d``; the returned
    ``nn.Sequential`` applies ``Transpose()`` first and the convolution second.
    """
    # Build the modules in the same order in which they are applied.
    transpose = Transpose()
    conv = nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs)
    return nn.Sequential(transpose, conv)
536
+
537
+
538
def Conv1dTranspose(
    *args, kernel_size: int = 3, padding: str = "same", **kwargs
) -> nn.Sequential:
    """
    Conv1d -> Transpose

    All arguments are forwarded to ``nn.Conv1d``; the returned
    ``nn.Sequential`` applies the convolution first and ``Transpose()`` second.
    """
    # Build the modules in the same order in which they are applied.
    conv = nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs)
    transpose = Transpose()
    return nn.Sequential(conv, transpose)
548
+
549
+
550
class SRLinear(nn.Linear):
    """https://arxiv.org/abs/2303.06296
    Stabilizing Transformer Training by Preventing Attention Entropy Collapse

    A linear layer whose effective weight is ``(sigma / s) * weight``, where
    ``s`` is an estimate of the largest singular value of ``weight`` obtained
    by one power-iteration step per call, and ``sigma`` is a learnable scalar
    initialised to one.
    """

    def __init__(self, in_features, out_features, bias=True, **kwargs):
        super().__init__(in_features, out_features, bias=bias, **kwargs)
        # Running estimate of the dominant right singular vector.
        initial_u = nn.functional.normalize(torch.randn(in_features), dim=0)
        self.register_buffer("u", initial_u)
        with torch.no_grad():
            initial_sigma = self.get_sigma()
        self.register_buffer("spectral_norm", initial_sigma)
        self.sigma = nn.Parameter(torch.ones(1))

    def get_sigma(self):
        """Run one power-iteration step, update ``self.u`` and return v^T W u."""
        with torch.no_grad():
            left = nn.functional.normalize(self.weight.mv(self.u), dim=0)
            right = nn.functional.normalize(self.weight.T.mv(left), dim=0)
            self.u.data.copy_(right)
        return torch.einsum("c,cd,d->", left, self.weight, right)

    def get_weight(self):
        """Return the weight rescaled by ``self.sigma / <spectral norm estimate>``."""
        estimate = self.get_sigma()
        if self.training:
            # Keep the buffer in sync so the latest estimate is checkpointed.
            self.spectral_norm.data.copy_(estimate)
        return (self.sigma / estimate) * self.weight

    def forward(self, x):
        return nn.functional.linear(x, self.get_weight(), self.bias)
584
+
585
+
586
class SRConv1d(SRLinear):
    """
    1-D convolution built on top of ``SRLinear``.

    The kernel is stored as a 2-D ``SRLinear`` weight of shape
    (out_features, in_features * kernel_size) and viewed as
    (out_features, in_features, kernel_size) when the convolution is applied.
    """

    def __init__(
        self,
        in_features,
        out_features,
        kernel_size,
        stride: int = 1,
        padding: str = "same",
        bias: bool = True,
        **kwargs,
    ):
        # The parent class sees the flattened (in_features * kernel_size) width.
        flattened_in = in_features * kernel_size
        super().__init__(flattened_in, out_features, bias=bias, **kwargs)
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

    def forward(self, x):
        channels_in = self.in_features // self.kernel_size
        kernel = self.get_weight().view(
            self.out_features, channels_in, self.kernel_size
        )
        return nn.functional.conv1d(
            x,
            kernel,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )
612
+
613
+
614
def TransposeSRConv1d(
    *args, kernel_size: int = 3, padding: str = "same", **kwargs
) -> nn.Sequential:
    """
    Transpose -> SRConv1d

    All arguments are forwarded to ``SRConv1d``; the returned
    ``nn.Sequential`` applies ``Transpose()`` first and the convolution second.
    """
    # Build the modules in the same order in which they are applied.
    transpose = Transpose()
    conv = SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs)
    return nn.Sequential(transpose, conv)
624
+
625
+
626
def SRConv1dTranspose(
    *args, kernel_size: int = 3, padding: str = "same", **kwargs
) -> nn.Sequential:
    """
    SRConv1d -> Transpose

    All arguments are forwarded to ``SRConv1d``; the returned
    ``nn.Sequential`` applies the convolution first and ``Transpose()`` second.
    """
    # Build the modules in the same order in which they are applied.
    conv = SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs)
    transpose = Transpose()
    return nn.Sequential(conv, transpose)
636
+
637
+
638
class ActivationBalancer(torch.nn.Module):
    """
    Modifies the backpropped derivatives of a function to try to encourage, for
    each channel, that it is positive at least a proportion `threshold` of the
    time.  It does this by multiplying negative derivative values by up to
    (1+max_factor), and positive derivative values by up to (1-max_factor),
    interpolated from 1 at the threshold to those extremal values when none
    of the inputs are positive.

    Args:
       num_channels: the number of channels
       channel_dim: the dimension/axis corresponding to the channel, e.g.
           -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
       min_positive: the minimum, per channel, of the proportion of the time
           that (x > 0), below which we start to modify the derivatives.
       max_positive: the maximum, per channel, of the proportion of the time
           that (x > 0), above which we start to modify the derivatives.
       max_factor: the maximum factor by which we modify the derivatives for
          either the sign constraint or the magnitude constraint;
          e.g. with max_factor=0.02, the derivatives would be multiplied by
          values in the range [0.98..1.02].
       sign_gain_factor: determines the 'gain' with which we increase the
          change in gradient once the constraints on min_positive and max_positive
          are violated.
       scale_gain_factor: determines the 'gain' with which we increase the
          change in gradient once the constraints on min_abs and max_abs
          are violated.
       min_abs:  the minimum average-absolute-value difference from the mean
           value per channel, which we allow, before we start to modify
           the derivatives to prevent this.
       max_abs:  the maximum average-absolute-value difference from the mean
           value per channel, which we allow, before we start to modify
           the derivatives to prevent this.
       min_prob: determines the minimum probability with which we modify the
           gradients for the {min,max}_positive and {min,max}_abs constraints,
           on each forward().  This is done randomly to prevent all layers
           from doing it at the same time.  Early in training we may use
           higher probabilities than this; it will decay to this value.
    """

    def __init__(
        self,
        num_channels: int,
        channel_dim: int,
        min_positive: float = 0.05,
        max_positive: float = 0.95,
        max_factor: float = 0.04,
        sign_gain_factor: float = 0.01,
        scale_gain_factor: float = 0.02,
        min_abs: float = 0.2,
        max_abs: float = 100.0,
        min_prob: float = 0.1,
    ):
        super(ActivationBalancer, self).__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        self.min_positive = min_positive
        self.max_positive = max_positive
        self.max_factor = max_factor
        self.min_abs = min_abs
        self.max_abs = max_abs
        self.min_prob = min_prob
        self.sign_gain_factor = sign_gain_factor
        self.scale_gain_factor = scale_gain_factor

        # count measures how many times the forward() function has been called.
        # We occasionally sync this to a tensor called `count`, that exists to
        # make sure it is synced to disk when we load and save the model.
        self.cpu_count = 0
        self.register_buffer("count", torch.tensor(0, dtype=torch.int64))

    def forward(self, x: Tensor) -> Tensor:
        """
        Return `x` unchanged in value; with some probability, route it through
        ActivationBalancerFunction so that its gradient is modified in backprop.
        """
        if (
            torch.jit.is_scripting()
            or not x.requires_grad
            or torch.jit.is_tracing()
        ):
            return _no_op(x)

        count = self.cpu_count
        self.cpu_count += 1

        if random.random() < 0.01:
            # Occasionally sync self.cpu_count with self.count.
            # count affects the decay of 'prob'.  don't do this on every iter,
            # because syncing with the GPU is slow.
            self.cpu_count = max(self.cpu_count, self.count.item())
            self.count.fill_(self.cpu_count)

        # the prob of doing some work exponentially decreases from 0.5 till it hits
        # a floor at min_prob (==0.1, by default)
        prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))

        if random.random() < prob:
            # The gains are divided by `prob` so that the less often the
            # balancer runs, the more strongly each application acts.
            if self.min_positive != 0.0 or self.max_positive != 1.0:
                sign_factor = _compute_sign_factor(
                    x,
                    self.channel_dim,
                    self.min_positive,
                    self.max_positive,
                    gain_factor=self.sign_gain_factor / prob,
                    max_factor=self.max_factor,
                )
            else:
                # The sign constraint is disabled by these bounds.
                sign_factor = None

            scale_factor = _compute_scale_factor(
                x.detach(),
                self.channel_dim,
                min_abs=self.min_abs,
                max_abs=self.max_abs,
                gain_factor=self.scale_gain_factor / prob,
                max_factor=self.max_factor,
            )
            return ActivationBalancerFunction.apply(
                x,
                scale_factor,
                sign_factor,
                self.channel_dim,
            )
        else:
            return _no_op(x)
760
+
761
+
762
def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
    """
    Returns x unmodified, but in backprop will put a penalty for the excess of
    the absolute values of elements of x over the limit "limit".  E.g. if
    limit == 10.0, then if x has any values over 10 it will get a penalty.

    Caution: the value of this penalty will be affected by grad scaling used
    in automatic mixed precision training.  For the purposes this is used for
    it shouldn't really matter, or may even be helpful; it is only meant
    to disallow really implausible values of scores to be given to softmax.
    """
    exceeds_limit = (x.abs() - limit) > 0
    # The following is a memory efficient way to penalize the absolute values of
    # x that are over the limit.  (The memory efficiency comes when you think
    # about which items torch needs to cache for the autograd, and which ones it
    # can throw away).  The numerical value of aux_loss as computed here will
    # actually be larger than it should be, by limit * exceeds_limit.sum(), but
    # it has the same derivative as the real aux_loss which is
    # penalty * (x.abs() - limit).relu().
    signed_mask = (x.sign() * exceeds_limit).to(torch.int8)
    aux_loss = penalty * (signed_mask * x)
    # note: we don't do sum() here on aux_loss, but it's as if we had done
    # sum() due to how with_loss() works.
    # You must use the returned tensor for something, or this will be ineffective.
    return with_loss(x, aux_loss)
787
+
788
+
789
+ def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
790
+ if x.ndim == 2:
791
+ return x.diag()
792
+ else:
793
+ (batch, dim, dim) = x.shape
794
+ x = x.reshape(batch, dim * dim)
795
+ x = x[:, :: dim + 1]
796
+ assert x.shape == (batch, dim)
797
+ return x
798
+
799
+
800
def _whitening_metric(x: Tensor, num_groups: int):
    """
    Computes the "whitening metric", a value which will be 1.0 if all the
    eigenvalues of the centered feature covariance are the same within each
    group's covariance matrix and also between groups.

    Args:
        x: a Tensor of shape (*, num_channels)
     num_groups:  the number of groups of channels, a number >=1 that divides num_channels
    Returns:
        Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and
    greater than 1.0 otherwise.
    """
    assert x.dtype != torch.float16
    flat = x.reshape(-1, x.shape[-1])
    num_frames, num_channels = flat.shape
    assert num_channels % num_groups == 0
    channels_per_group = num_channels // num_groups
    # grouped: (num_groups, num_frames, channels_per_group)
    grouped = flat.reshape(num_frames, num_groups, channels_per_group).transpose(
        0, 1
    )
    # subtract the mean so we use the centered, not uncentered, covariance.
    # My experience has been that when we "mess with the gradients" like this,
    # it's better not do anything that tries to move the mean around, because
    # that can easily cause instability.
    centered = grouped - grouped.mean(dim=1, keepdim=True)
    # covar: (num_groups, channels_per_group, channels_per_group)
    covar = torch.matmul(centered.transpose(1, 2), centered)
    covar_mean_diag = _diag(covar).mean()
    # the following expression is what we'd get if we took the matrix product
    # of each covariance and measured the mean of its trace, i.e.
    # the same as _diag(torch.matmul(covar, covar)).mean().
    covarsq_mean_diag = (covar ** 2).sum() / (num_groups * channels_per_group)
    # this metric will be >= 1.0; the larger it is, the less 'white' the data was.
    return covarsq_mean_diag / (covar_mean_diag ** 2 + 1.0e-20)
836
+
837
+
838
class WhiteningPenaltyFunction(torch.autograd.Function):
    """
    Identity in the forward pass.  In the backward pass, the gradient of
    relu(whitening_metric - whitening_limit) with respect to the input is
    rescaled to `grad_scale` times the norm of the incoming gradient and
    added to that incoming gradient.
    """

    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        num_groups: int,
        whitening_limit: float,
        grad_scale: float,
    ) -> Tensor:
        ctx.save_for_backward(x)
        ctx.num_groups = num_groups
        ctx.whitening_limit = whitening_limit
        ctx.grad_scale = grad_scale
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor):
        (x_orig,) = ctx.saved_tensors
        with torch.enable_grad(), torch.cuda.amp.autocast(enabled=False):
            # Work on a float32 leaf copy so that the penalty's own backward
            # pass stays separate from the outer graph.
            x_detached = x_orig.to(torch.float32).detach()
            x_detached.requires_grad = True

            metric = _whitening_metric(x_detached, ctx.num_groups)

            if random.random() < 0.005 or __name__ == "__main__":
                logging.info(
                    f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
                    f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}"
                )

            excess = (metric - ctx.whitening_limit).relu()
            excess.backward()
            raw_penalty_grad = x_detached.grad
            # Make the penalty gradient's norm `grad_scale` times that of x_grad.
            incoming_norm = x_grad.to(torch.float32).norm()
            scale = ctx.grad_scale * (
                incoming_norm / (raw_penalty_grad.norm() + 1.0e-20)
            )
            penalty_grad = raw_penalty_grad * scale
        return x_grad + penalty_grad.to(x_grad.dtype), None, None, None
877
+
878
+
879
class Whiten(nn.Module):
    """
    Identity in the forward pass; with probability `prob` the input is routed
    through WhiteningPenaltyFunction so that its gradient is modified in
    backprop.
    """

    def __init__(
        self,
        num_groups: int,
        whitening_limit: float,
        prob: Union[float, Tuple[float, float]],
        grad_scale: float,
    ):
        """
        Args:
          num_groups: the number of groups to divide the channel dim into before
            whitening.  We will attempt to make the feature covariance
            within each group, after mean subtraction, as "white" as possible,
            while having the same trace across all groups.
         whitening_limit: a value greater than 1.0, that dictates how much
           freedom we have to violate the constraints.  1.0 would mean perfectly
           white, with exactly the same trace across groups; larger values
           give more freedom.  E.g. 2.0.
         prob: the probability with which we apply the gradient modification
           (also affects the grad scale).  May be supplied as a single number
           (float or int, e.g. 1), or as a pair (min_prob, max_prob)
          grad_scale: determines the scale on the gradient term from this object,
            relative to the rest of the gradient on the attention weights.
            E.g. 0.02 (you may want to use smaller values than this if prob is large)
        """
        super(Whiten, self).__init__()
        assert num_groups >= 1
        assert whitening_limit >= 1
        assert grad_scale >= 0
        self.num_groups = num_groups
        self.whitening_limit = whitening_limit
        # Accept any real scalar, so that e.g. `prob=1` is not mistaken for a
        # (min_prob, max_prob) pair and rejected with an unpacking TypeError.
        if isinstance(prob, (int, float)):
            assert 0 < prob <= 1
            self.prob = prob
        else:
            (self.min_prob, self.max_prob) = prob
            assert 0 < self.min_prob < self.max_prob <= 1
            self.prob = self.max_prob

        self.grad_scale = grad_scale

    def forward(self, x: Tensor) -> Tensor:
        """
        In the forward pass, this function just returns the input unmodified.
        In the backward pass, it will modify the gradients to ensure that the
        distribution in each group has close to (lambda times I) as the covariance
        after mean subtraction, with the same lambda across groups.
        For whitening_limit > 1, there will be more freedom to violate this
        constraint.

        Args:
           x: the input of shape (*, num_channels)

        Returns:
            x, unmodified.   You should make sure
        you use the returned value, or the graph will be freed
        and nothing will happen in backprop.
        """
        if (
            not x.requires_grad
            or random.random() > self.prob
            or self.grad_scale == 0
        ):
            return _no_op(x)
        else:
            # `min_prob` only exists when `prob` was given as a pair.
            if hasattr(self, "min_prob") and random.random() < 0.25:
                # occasionally switch between min_prob and max_prob, based on whether
                # we are above or below the threshold.
                if (
                    _whitening_metric(x.to(torch.float32), self.num_groups)
                    > self.whitening_limit
                ):
                    # there would be a change to the grad.
                    self.prob = self.max_prob
                else:
                    self.prob = self.min_prob

            return WhiteningPenaltyFunction.apply(
                x, self.num_groups, self.whitening_limit, self.grad_scale
            )
957
+
958
+
959
class WithLoss(torch.autograd.Function):
    """
    Returns `x` unchanged in the forward pass.  In the backward pass `x`
    receives the incoming gradient untouched and `y` receives a gradient of
    all ones, which is the gradient `y` would get if y.sum() were added to
    the loss.
    """

    @staticmethod
    def forward(ctx, x: Tensor, y: Tensor):
        # Only the shape of y is needed to build its gradient later.
        ctx.y_shape = y.shape
        return x

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        y_grad = torch.ones(
            ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device
        )
        return ans_grad, y_grad
970
+
971
+
972
def with_loss(x, y):
    """Return x; outside of TorchScript scripting/tracing, also add y.sum() to the loss."""
    in_jit = torch.jit.is_scripting() or torch.jit.is_tracing()
    if not in_jit:
        # returns x but adds y.sum() to the loss function.
        return WithLoss.apply(x, y)
    return x
977
+
978
+
979
+ def _no_op(x: Tensor) -> Tensor:
980
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
981
+ return x
982
+ else:
983
+ # a no-op function that will have a node in the autograd graph,
984
+ # to avoid certain bugs relating to backward hooks
985
+ return x.chunk(1, dim=-1)[0]
986
+
987
+
988
class Identity(torch.nn.Module):
    """Module that passes its input through `_no_op` and returns the result."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return _no_op(x)
994
+
995
+
996
class MaxEig(torch.nn.Module):
    """
    Modifies the backpropped derivatives of a function to try to discourage
    that any given direction in activation space accounts for more than
    a specified proportion of the covariance (e.g. 0.2).

    Args:
       num_channels: the number of channels
       channel_dim: the dimension/axis corresponding to the channel, e.g.
           -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
       max_var_per_eig:  the maximum proportion of the variance of the
           features/channels, after mean subtraction, that can come from
           any given eigenvalue.
       min_prob: the minimum probability with which we apply this during any invocation
           of forward(), assuming last time we applied the constraint it was
           not active; supplied for speed.
       scale: determines the scale with which we modify the gradients, relative
           to the existing / unmodified gradients
    """

    def __init__(
        self,
        num_channels: int,
        channel_dim: int,
        max_var_per_eig: float = 0.2,
        min_prob: float = 0.01,
        scale: float = 0.01,
    ):
        super(MaxEig, self).__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        self.scale = scale
        assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels
        self.max_var_per_eig = max_var_per_eig

        # we figure out the dominant direction using the power method: starting with
        # a random vector, keep multiplying by the covariance and renormalizing.
        with torch.no_grad():
            # arbitrary.. would use randn() but want to leave the rest of the model's
            # random parameters unchanged for comparison
            direction = torch.arange(num_channels).to(torch.float)
            direction = direction / direction.norm()
            self.register_buffer("max_eig_direction", direction)

        self.min_prob = min_prob
        # cur_prob is the current probability with which we apply the MaxEig
        # constraint.  We regress it towards min_prob each time we try to
        # apply the constraint and it is not active.
        self.cur_prob = 1.0

    def forward(self, x: Tensor) -> Tensor:
        """
        Return `x` unchanged in value; when the constraint is active, route it
        through MaxEigLimiterFunction so that its gradient is modified.
        """
        if (
            torch.jit.is_scripting()
            or self.max_var_per_eig <= 0
            or random.random() > self.cur_prob
            or torch.jit.is_tracing()
        ):
            return _no_op(x)

        with torch.cuda.amp.autocast(enabled=False):
            orig_x = x
            x = x.to(torch.float32)
            with torch.no_grad():
                # Flatten to (num_frames, num_channels) and centre per channel.
                x = x.transpose(self.channel_dim, -1).reshape(
                    -1, self.num_channels
                )
                x = x - x.mean(dim=0)
                new_direction, coeffs = self._find_direction_coeffs(
                    x, self.max_eig_direction
                )
                x_var = (x ** 2).mean()
                x_residual = x - coeffs * new_direction
                x_residual_var = (x_residual ** 2).mean()

                # `variance_proportion` is the proportion of the variance accounted for
                # by the top eigen-direction.
                variance_proportion = (x_var - x_residual_var) / (
                    x_var + 1.0e-20
                )

                # ensure new direction is nonzero even if x == 0, by including `direction`.
                self._set_direction(
                    0.1 * self.max_eig_direction + new_direction
                )

            if random.random() < 0.01 or __name__ == "__main__":
                logging.info(
                    f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}"
                )

            if variance_proportion >= self.max_var_per_eig:
                # The constraint is active.  Note, we should quite rarely
                # reach here, only near the beginning of training if we are
                # starting to diverge, should this constraint be active.
                # next time, do the update with probability 1.0.
                self.cur_prob = 1.0
                return MaxEigLimiterFunction.apply(
                    orig_x, coeffs, new_direction, self.channel_dim, self.scale
                )
            else:
                # let self.cur_prob exponentially approach self.min_prob, as
                # long as the constraint is inactive.
                self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob
                return orig_x

    def _set_direction(self, direction: Tensor):
        """
        Sets self.max_eig_direction to a normalized version of `direction`.

        If the normalized direction contains inf/nan, the stored direction is
        left untouched and a warning is logged instead.
        """
        direction = direction.detach()
        direction = direction / direction.norm()
        direction_sum = direction.sum().item()
        if direction_sum - direction_sum == 0:  # no inf/nan
            self.max_eig_direction[:] = direction
        else:
            # Both pieces must be f-strings for the placeholders to be filled.
            logging.info(
                f"Warning: sum of direction in MaxEig is {direction_sum}, "
                f"num_channels={self.num_channels}, channel_dim={self.channel_dim}"
            )

    def _find_direction_coeffs(
        self, x: Tensor, prev_direction: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """
        Figure out (an approximation to) the proportion of the variance of a set of
        feature vectors that can be attributed to the top eigen-direction.

        Args:
         x: a Tensor of shape (num_frames, num_channels), with num_frames > 1.
         prev_direction:  a Tensor of shape (num_channels,), that is our previous estimate
               of the top eigen-direction, or a random direction if this is the first
               iteration.  Does not have to be normalized, but should be nonzero.

        Returns: (cur_direction, coeffs), where:
             cur_direction: a Tensor of shape (num_channels,) that is the current
                estimate of the top eigen-direction.
             coeffs: a Tensor of shape (num_frames, 1) that minimizes, or
                approximately minimizes, (x - coeffs * cur_direction).norm()
        """
        (num_frames, num_channels) = x.shape
        assert num_channels > 1 and num_frames > 1
        assert prev_direction.shape == (num_channels,)
        # `coeffs` are the coefficients of `prev_direction` in x.
        # actually represent the coeffs up to a constant positive factor.
        coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
        cur_direction = (x * coeffs).sum(dim=0) / (
            (coeffs ** 2).sum() + 1.0e-20
        )
        return cur_direction, coeffs
1145
+
1146
+
1147
class DoubleSwishFunction(torch.autograd.Function):
    """
      double_swish(x) = x * torch.sigmoid(x-1)
    This is a definition, originally motivated by its close numerical
    similarity to swish(swish(x)), where swish(x) =  x * sigmoid(x).

    Memory-efficient derivative computation:
     double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
     double_swish'(x) = d/dx double_swish(x) =  x * s'(x) + x' * s(x) = x * s'(x) + s(x).
     Now, s'(x) = s(x) * (1-s(x)).
     double_swish'(x) =  x * s'(x) + s(x).
                      =  x * s(x) * (1-s(x)) + s(x).
                     = double_swish(x) * (1-s(x)) + s(x)
     ... so we just need to remember s(x) but not x itself.

    The derivative is stored quantized to uint8 (with random rounding), so the
    backward pass is only accurate to about (ceil - floor) / 255.
    """

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        requires_grad = x.requires_grad
        # Remember the caller's dtype: `x` itself is upcast just below, so
        # `x.dtype` can no longer be used to decide the output dtype.
        x_dtype = x.dtype
        if x.dtype == torch.float16:
            x = x.to(torch.float32)

        s = torch.sigmoid(x - 1.0)
        y = x * s

        if requires_grad:
            deriv = y * (1 - s) + s
            # notes on derivative of x * sigmoid(x - 1):
            # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
            # min \simeq -0.043638.  Take floor as -0.043637 so it's a lower bound
            # max \simeq 1.1990.   Take ceil to be 1.2 so it's an upper bound.
            # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
            # floors), should be expectation-preserving.
            floor = -0.043637
            ceil = 1.2
            d_scaled = (deriv - floor) * (
                255.0 / (ceil - floor)
            ) + torch.rand_like(deriv)
            if __name__ == "__main__":
                # for self-testing only.
                assert d_scaled.min() >= 0.0
                assert d_scaled.max() < 256.0
            d_int = d_scaled.to(torch.uint8)
            ctx.save_for_backward(d_int)
        # Hand float16 back to float16 callers (and under autocast).
        if x_dtype == torch.float16 or torch.is_autocast_enabled():
            y = y.to(torch.float16)
        return y

    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tensor:
        (d,) = ctx.saved_tensors
        # the same constants as used in forward pass.
        floor = -0.043637
        ceil = 1.2
        # Undo the uint8 quantization applied in forward().
        d = d * ((ceil - floor) / 255.0) + floor
        return y_grad * d
1203
+
1204
+
1205
class DoubleSwish(torch.nn.Module):
    """Module wrapper around the double-swish activation x * sigmoid(x - 1)."""

    def forward(self, x: Tensor) -> Tensor:
        """Return double-swish activation function which is an approximation to Swish(Swish(x)),
        that we approximate closely with x * sigmoid(x-1).
        """
        in_jit = torch.jit.is_scripting() or torch.jit.is_tracing()
        if not in_jit:
            # Custom autograd function with the memory-efficient backward pass.
            return DoubleSwishFunction.apply(x)
        return x * torch.sigmoid(x - 1.0)
1213
+
1214
+
1215
def BalancedDoubleSwish(
    d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
) -> nn.Sequential:
    """
    ActivationBalancer -> DoubleSwish

    `d_model` is passed to ``ActivationBalancer`` as its number of channels;
    the remaining arguments are forwarded to it by keyword.
    """
    return nn.Sequential(
        ActivationBalancer(
            d_model,
            channel_dim=channel_dim,
            max_abs=max_abs,
            min_prob=min_prob,
        ),
        DoubleSwish(),
    )
1228
+
1229
+
1230
def _test_max_eig():
    """Self-test: MaxEig leaves gradients almost untouched for near-white
    inputs and changes them when one direction dominates."""
    num_channels = 128
    for proportion in [0.1, 0.5, 10.0]:
        logging.info(f"proportion = {proportion}")
        x = torch.randn(100, 128)
        direction = torch.randn(128)
        coeffs = torch.randn(100, 1)
        # Inject a rank-one component whose strength is set by `proportion`.
        x += proportion * direction * coeffs
        x.requires_grad = True

        # positional args: num_channels, channel_dim=1, max_var_per_eig=0.5
        m = MaxEig(num_channels, 1, 0.5, scale=0.1)

        for _ in range(4):
            y = m(x)

        y_grad = torch.randn_like(x)
        y.backward(gradient=y_grad)

        if proportion < 0.2:
            assert torch.allclose(x.grad, y_grad, atol=1.0e-02)
        elif proportion > 1.0:
            assert not torch.allclose(x.grad, y_grad)
1255
+
1256
+
1257
def _test_whiten():
    """Self-test: Whiten leaves gradients untouched for near-white inputs and
    changes them when one direction dominates."""
    for proportion in [0.1, 0.5, 10.0]:
        logging.info(f"_test_whiten(): proportion = {proportion}")
        x = torch.randn(100, 128)
        direction = torch.randn(128)
        coeffs = torch.randn(100, 1)
        # Inject a rank-one component whose strength is set by `proportion`.
        x += proportion * direction * coeffs
        x.requires_grad = True

        num_channels = 128
        # positional args: num_groups=1, whitening_limit=5.0
        m = Whiten(1, 5.0, prob=1.0, grad_scale=0.1)

        for _ in range(4):
            y = m(x)

        y_grad = torch.randn_like(x)
        y.backward(gradient=y_grad)

        if proportion < 0.2:
            assert torch.allclose(x.grad, y_grad)
        elif proportion > 1.0:
            assert not torch.allclose(x.grad, y_grad)
1282
+
1283
+
1284
def _test_activation_balancer_sign():
    """Self-test (prints only): run ActivationBalancer on +/-1 inputs where
    the per-channel probability of +1 ranges from 0 to 0.99."""
    probs = torch.arange(0, 1, 0.01)
    N = 1000
    # Row i is +1 with probability probs[i], otherwise -1.
    is_positive = torch.rand(probs.numel(), N) < probs.unsqueeze(-1)
    x = 1.0 * ((2.0 * is_positive) - 1.0)
    x = x.detach()
    x.requires_grad = True
    m = ActivationBalancer(
        probs.numel(),
        channel_dim=0,
        min_positive=0.05,
        max_positive=0.95,
        max_factor=0.2,
        min_abs=0.0,
    )

    y_grad = torch.sign(torch.randn(probs.numel(), N))

    y = m(x)
    y.backward(gradient=y_grad)
    print("_test_activation_balancer_sign: x = ", x)
    print("_test_activation_balancer_sign: y grad = ", y_grad)
    print("_test_activation_balancer_sign: x grad = ", x.grad)
1308
+
1309
+
1310
def _test_activation_balancer_magnitude():
    """Self-test (prints only): run ActivationBalancer on inputs whose
    per-channel magnitude ranges from 0 to 0.99 with random signs."""
    magnitudes = torch.arange(0, 1, 0.01)
    N = 1000
    # Row i holds values +/- magnitudes[i].
    random_signs = torch.sign(torch.randn(magnitudes.numel(), N))
    x = random_signs * magnitudes.unsqueeze(-1)
    x = x.detach()
    x.requires_grad = True
    m = ActivationBalancer(
        magnitudes.numel(),
        channel_dim=0,
        min_positive=0.0,
        max_positive=1.0,
        max_factor=0.2,
        min_abs=0.2,
        max_abs=0.8,
        min_prob=1.0,
    )

    y_grad = torch.sign(torch.randn(magnitudes.numel(), N))

    y = m(x)
    y.backward(gradient=y_grad)
    print("_test_activation_balancer_magnitude: x = ", x)
    print("_test_activation_balancer_magnitude: y grad = ", y_grad)
    print("_test_activation_balancer_magnitude: x grad = ", x.grad)
1336
+
1337
+
1338
def _test_basic_norm():
    """Self-test: BasicNorm keeps the shape and shrinks the RMS of a random
    input, but to no less than half of its original value."""
    num_channels = 128
    m = BasicNorm(num_channels=num_channels, channel_dim=1)

    x = torch.randn(500, num_channels)
    y = m(x)
    assert y.shape == x.shape

    def _rms(t):
        # root-mean-square over all elements
        return (t ** 2).mean().sqrt()

    x_rms = _rms(x)
    y_rms = _rms(y)
    print("x rms = ", x_rms)
    print("y rms = ", y_rms)
    assert y_rms < x_rms
    assert y_rms > 0.5 * x_rms
1353
+
1354
+
1355
def _test_double_swish_deriv():
    """Self-test: gradcheck DoubleSwish with a tolerance of one quantization
    step, then run one large forward pass."""
    m = DoubleSwish()

    x = torch.randn(10, 12, dtype=torch.double) * 3.0
    x.requires_grad = True
    # One step of the uint8 derivative quantization: (ceil - floor) / 255.
    tol = (1.2 - (-0.043637)) / 255.0
    torch.autograd.gradcheck(m, x, atol=tol)

    # for self-test: a large input exercises the range asserts that
    # DoubleSwishFunction.forward enables when run as a script.
    x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
    x.requires_grad = True
    m(x)
1367
+
1368
+
1369
def _test_softmax():
    """Self-test: the module-level `softmax` gives the same input gradient as
    Tensor.softmax."""
    a = torch.randn(2, 10, dtype=torch.float64)
    b = a.clone()
    for leaf in (a, b):
        leaf.requires_grad = True

    # Reference gradient from the built-in softmax.
    a.softmax(dim=1)[:, 0].sum().backward()
    print("a grad = ", a.grad)

    # Gradient from the `softmax` function under test.
    softmax(b, dim=1)[:, 0].sum().backward()
    print("b grad = ", b.grad)

    assert torch.allclose(a.grad, b.grad)
1379
+
1380
+
1381
if __name__ == "__main__":
    # Script entry point: show INFO logs, pin torch to a single thread, then
    # run every self-test in a fixed order.
    logging.getLogger().setLevel(logging.INFO)
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    for _self_test in (
        _test_softmax,
        _test_whiten,
        _test_max_eig,
        _test_activation_balancer_sign,
        _test_activation_balancer_magnitude,
        _test_basic_norm,
        _test_double_swish_deriv,
    ):
        _self_test()
src/model/modules/siglip.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
class SiglipVisionConfig:
    """Plain container for the hyper-parameters of the SigLIP vision tower.

    Args:
        hidden_size: Width of the patch embeddings and encoder layers.
        intermediate_size: Width of the hidden layer inside each MLP block.
        num_hidden_layers: Number of encoder layers to stack.
        num_attention_heads: Number of attention heads per encoder layer.
        num_channels: Number of input image channels.
        image_size: Side length used to derive the number of patches.
        patch_size: Side length of one patch (convolution kernel and stride).
        layer_norm_eps: Epsilon passed to every ``nn.LayerNorm``.
        attention_dropout: Dropout probability applied to attention weights.
        num_image_tokens: Stored as given; not read by the modules in this
            file. May be ``None``.
        **kwargs: Accepted so that larger config dicts can be splatted in;
            unknown keys are ignored and are NOT stored on the instance.
    """

    def __init__(
        self,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_channels=3,
        image_size=224,
        patch_size=16,
        layer_norm_eps=1e-6,
        attention_dropout=0.0,
        num_image_tokens: Optional[int] = None,
        **kwargs,
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        self.num_image_tokens = num_image_tokens
34
+
35
+
36
class SiglipVisionEmbeddings(nn.Module):
    """Convert an image batch into a sequence of position-aware patch embeddings.

    A strided convolution (kernel size == stride == ``patch_size``, no
    padding) cuts the image into non-overlapping patches and projects each
    one to ``hidden_size``; a learned position embedding is then added.
    """

    def __init__(self, config: "SiglipVisionConfig"):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",  # This indicates no padding is added
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        # Non-persistent buffer: moves with the module across devices but is
        # not written to the state dict.
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions).expand((1, -1)),
            persistent=False,
        )

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Embed ``pixel_values`` of shape [Batch_Size, Channels, Height, Width].

        Returns:
            Tensor of shape [Batch_Size, Num_Patches, Embed_Dim].

        Raises:
            ValueError: If the input is not 4-D, or if it does not produce
                exactly ``num_positions`` patches.
        """
        if pixel_values.dim() != 4:
            raise ValueError(
                "pixel_values must have shape [Batch_Size, Channels, Height, Width], "
                f"but got a tensor with {pixel_values.dim()} dimension(s)"
            )
        _, _, height, width = pixel_values.shape

        # [Batch_Size, Embed_Dim, Num_Patches_H, Num_Patches_W]
        patch_embeds = self.patch_embedding(pixel_values)
        # -> [Batch_Size, Embed_Dim, Num_Patches] -> [Batch_Size, Num_Patches, Embed_Dim]
        embeddings = patch_embeds.flatten(2).transpose(1, 2)

        # Fail early with a readable message instead of an opaque broadcast
        # error (or a silent broadcast) when adding the position table.
        if embeddings.shape[1] != self.num_positions:
            raise ValueError(
                f"Input image of size {height}x{width} produces {embeddings.shape[1]} patches, "
                f"but the model has {self.num_positions} position embeddings "
                f"(image_size={self.image_size}, patch_size={self.patch_size})"
            )

        # Each positional encoding is a vector of size [Embed_Dim].
        return embeddings + self.position_embedding(self.position_ids)
78
+
79
+
80
class SiglipAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config):
        """Build the q/k/v/out projections from ``config``.

        Reads ``hidden_size``, ``num_attention_heads`` and
        ``attention_dropout`` from ``config``.

        Raises:
            ValueError: If ``hidden_size`` is not divisible by
                ``num_attention_heads``.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        # Without this check a bad config is only detected inside forward(),
        # as an unhelpful view() shape error.
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and"
                f" `num_heads`: {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5  # Equivalent to 1 / sqrt(self.head_dim)
        self.dropout = config.attention_dropout

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply self-attention to ``hidden_states``.

        Args:
            hidden_states: [Batch_Size, Num_Patches, Embed_Dim].

        Returns:
            ``(attn_output, attn_weights)`` with shapes
            [Batch_Size, Num_Patches, Embed_Dim] and
            [Batch_Size, Num_Heads, Num_Patches, Num_Patches].
        """
        batch_size, seq_len, _ = hidden_states.size()

        def split_heads(projected: torch.Tensor) -> torch.Tensor:
            # [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Heads, Num_Patches, Head_Dim]
            return projected.view(
                batch_size, seq_len, self.num_heads, self.head_dim
            ).transpose(1, 2)

        query_states = split_heads(self.q_proj(hidden_states))
        key_states = split_heads(self.k_proj(hidden_states))
        value_states = split_heads(self.v_proj(hidden_states))

        # Q * K^T / sqrt(d_k): [Batch_Size, Num_Heads, Num_Patches, Num_Patches]
        attn_weights = (
            torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
        )

        if attn_weights.size() != (batch_size, self.num_heads, seq_len, seq_len):
            raise ValueError(
                f"Attention weights should be of size {(batch_size, self.num_heads, seq_len, seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        # Row-wise softmax computed in float32, then cast back to the
        # dtype of the projected queries.
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        # Apply dropout only during training
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )
        # [Batch_Size, Num_Heads, Num_Patches, Head_Dim]
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (batch_size, self.num_heads, seq_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(batch_size, self.num_heads, seq_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # Merge heads: -> [Batch_Size, Num_Patches, Num_Heads, Head_Dim] -> [Batch_Size, Num_Patches, Embed_Dim]
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim)
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
157
+
158
+
159
class SiglipMLP(nn.Module):
    """Feed-forward block: Linear -> tanh-approximated GELU -> Linear."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Expand to the intermediate width, then project back.
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Map [Batch_Size, Num_Patches, Embed_Dim] to a tensor of the same shape."""
        # [Batch_Size, Num_Patches, Intermediate_Size]
        expanded = self.fc1(hidden_states)
        activated = nn.functional.gelu(expanded, approximate="tanh")
        # Back to [Batch_Size, Num_Patches, Embed_Dim]
        return self.fc2(activated)
175
+
176
+
177
class SiglipEncoderLayer(nn.Module):
    """One pre-norm encoder block: self-attention and MLP, each wrapped in a residual add."""

    def __init__(self, config: "SiglipVisionConfig"):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = SiglipAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    # Ignore copy
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Transform [Batch_Size, Num_Patches, Embed_Dim] into a tensor of the same shape."""
        # Attention sub-layer; the attention weights are discarded.
        attn_out, _ = self.self_attn(hidden_states=self.layer_norm1(hidden_states))
        after_attn = hidden_states + attn_out

        # Feed-forward sub-layer.
        mlp_out = self.mlp(self.layer_norm2(after_attn))
        return after_attn + mlp_out
206
+
207
+
208
class SiglipEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` encoder layers applied in order."""

    def __init__(self, config: "SiglipVisionConfig"):
        super().__init__()
        self.config = config
        blocks = [SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(blocks)

    # Ignore copy
    def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        """Run ``inputs_embeds`` ([Batch_Size, Num_Patches, Embed_Dim]) through every layer."""
        states = inputs_embeds
        for block in self.layers:
            # Shape is preserved by each layer.
            states = block(states)
        return states
226
+
227
+
228
class SiglipVisionTransformer(nn.Module):
    """Patch embeddings -> encoder stack -> final layer norm."""

    def __init__(self, config: "SiglipVisionConfig"):
        super().__init__()
        self.config = config

        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(config)
        self.post_layernorm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode images.

        [Batch_Size, Channels, Height, Width] -> [Batch_Size, Num_Patches, Embed_Dim]
        """
        patch_states = self.embeddings(pixel_values)
        encoded = self.encoder(inputs_embeds=patch_states)
        return self.post_layernorm(encoded)
247
+
248
+
249
class SiglipVisionModel(nn.Module):
    """Thin wrapper that exposes the vision transformer as ``vision_model``."""

    def __init__(self, config: "SiglipVisionConfig"):
        super().__init__()
        self.config = config
        self.vision_model = SiglipVisionTransformer(config)

    def forward(self, pixel_values) -> torch.Tensor:
        """Return whatever ``vision_model`` produces for ``pixel_values``.

        [Batch_Size, Channels, Height, Width] -> [Batch_Size, Num_Patches, Embed_Dim]

        The result is a single tensor, not a tuple.
        """
        return self.vision_model(pixel_values=pixel_values)
src/model/modules/tokenizer.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # cp from https://github.com/lifeiteng/vall-e/blob/main/valle/data/tokenizer.py
2
+ # Copyright 2023 (authors: Feiteng Li)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import re
17
+ from dataclasses import asdict, dataclass
18
+ from typing import Any, Dict, List, Optional, Pattern, Union
19
+
20
+ import numpy as np
21
+ import torch
22
+ import torchaudio
23
+ # from lhotse.features import FeatureExtractor
24
+ # from lhotse.utils import Seconds, compute_num_frames
25
+ from phonemizer.backend import EspeakBackend
26
+ from phonemizer.backend.espeak.language_switch import LanguageSwitch
27
+ from phonemizer.backend.espeak.words_mismatch import WordMismatch
28
+ from phonemizer.punctuation import Punctuation
29
+ from phonemizer.separator import Separator
30
+
31
+
32
+
33
class TextTokenizer:
    """Phonemize Text."""

    def __init__(
        self,
        language="en-us",
        backend="espeak",
        separator=Separator(word="_", syllable="-", phone="|"),
        preserve_punctuation=True,
        punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
        with_stress: bool = False,
        tie: Union[bool, str] = False,
        language_switch: LanguageSwitch = "keep-flags",
        words_mismatch: WordMismatch = "ignore",
    ) -> None:
        """Create the espeak phonemizer backend.

        NOTE: ``backend`` is accepted but not read; an ``EspeakBackend`` is
        always constructed. The remaining options are forwarded to it.
        """
        self.backend = EspeakBackend(
            language,
            punctuation_marks=punctuation_marks,
            preserve_punctuation=preserve_punctuation,
            with_stress=with_stress,
            tie=tie,
            language_switch=language_switch,
            words_mismatch=words_mismatch,
        )
        self.separator = separator

    def to_list(self, phonemized: str) -> List[str]:
        """Split one phonemized string into a flat list of symbols.

        Word separators are kept as their own entries between words; phone
        separators are dropped.
        """
        word_sep = self.separator.word
        phone_sep = self.separator.phone

        fields = []
        for word in phonemized.split(word_sep):
            # "ɐ    m|iː|n?"    ɹ|ɪ|z|ɜː|v; h|ɪ|z.
            pieces = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
            fields.extend(piece for piece in pieces if piece != phone_sep)
            fields.append(word_sep)

        # Drop the separator appended after the last word.
        symbols = fields[:-1]
        # Sanity check: only phone separators may have been removed.
        assert len("".join(symbols)) == len(phonemized) - phonemized.count(
            phone_sep
        )
        return symbols

    def __call__(self, text, strip=True) -> List[List[str]]:
        """Phonemize a string or a list of strings; returns one symbol list per input."""
        batch = [text] if isinstance(text, str) else text
        phonemized = self.backend.phonemize(
            batch, separator=self.separator, strip=strip, njobs=1
        )
        return [self.to_list(item) for item in phonemized]
83
+
84
+
85
def tokenize_text(tokenizer: "TextTokenizer", text: str) -> List[str]:
    """Phonemize a single whitespace-stripped string and return its symbol list."""
    batch = [text.strip()]
    # The tokenizer works on batches; take the only result.  (k2symbols)
    return tokenizer(batch)[0]
88
+
89
def convert_audio(wav: torch.Tensor, sr: int, target_sr: int, target_channels: int):
    """Adapt the channel count of ``wav`` and resample it from ``sr`` to ``target_sr``.

    The first dimension of ``wav`` must be 1 or 2 (asserted below).
    # assumes wav is laid out as (channels, samples) -- TODO confirm with callers
    """
    assert wav.shape[0] in [1, 2], "Audio must be mono or stereo."

    if target_channels == 1:
        # Down-mix by averaging over the first dimension.
        wav = wav.mean(0, keepdim=True)
    elif target_channels == 2:
        *leading, _, n_samples = wav.shape
        wav = wav.expand(*leading, target_channels, n_samples)
    elif wav.shape[0] == 1:
        wav = wav.expand(target_channels, -1)

    resampler = torchaudio.transforms.Resample(sr, target_sr)
    return resampler(wav)
100
+
101
class AudioTokenizer:
    """EnCodec audio."""

    def __init__(
        self,
        device: Any = None,
        signature=None,
    ) -> None:
        """Load the compression model named by ``signature`` and move it to ``device``.

        When ``device`` is falsy, ``cuda:0`` is used if CUDA is available and
        the CPU otherwise. ``sample_rate`` and ``channels`` are copied from
        the loaded model.
        """
        # Imported lazily so the module can be imported without audiocraft.
        from audiocraft.solvers import CompressionSolver

        model = CompressionSolver.model_from_checkpoint(signature)
        self.sample_rate = model.sample_rate
        self.channels = model.channels

        if not device:
            device = torch.device("cpu")
            if torch.cuda.is_available():
                device = torch.device("cuda:0")

        self._device = device

        self.codec = model.to(device)

    @property
    def device(self):
        """Device the codec was moved to at construction time."""
        return self._device

    def encode(self, wav: torch.Tensor) -> List[tuple]:
        """Encode ``wav`` and return ``[(codes, None)]``.

        ``codes`` is the first element of whatever ``codec.encode`` returns;
        the input is moved to ``self.device`` first.
        """
        codes = self.codec.encode(wav.to(self.device))
        return [(codes[0], None)]

    def decode(self, frames: List[tuple]) -> torch.Tensor:
        """Decode ``frames`` given in the ``[(codes, None)]`` layout produced by ``encode``."""
        frames = frames[0][0]  # [1,4,T]
        return self.codec.decode(frames)
134
+
135
+
136
+
137
def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str, offset=-1, num_frames=-1):
    """Load an audio file and encode it with ``tokenizer``.

    Args:
        tokenizer: Supplies ``sample_rate``, ``channels`` and ``encode``.
        audio_path: File passed to ``torchaudio.load``.
        offset: First frame to read; ``-1`` means start of file.
        num_frames: Number of frames to read; ``-1`` means read to the end.

    Returns:
        Whatever ``tokenizer.encode`` returns for the loaded waveform.

    ``offset`` and ``num_frames`` are honoured independently. Previously
    either one was silently ignored unless both were given.
    """
    # -1 is this function's "unset" marker; torchaudio's own default start is 0
    # and its "read everything" marker for num_frames is -1.
    frame_offset = 0 if offset == -1 else offset
    wav, sr = torchaudio.load(
        audio_path, frame_offset=frame_offset, num_frames=num_frames
    )
    wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)
    # Add a leading batch dimension.
    wav = wav.unsqueeze(0)

    # Extract discrete codes from EnCodec
    with torch.no_grad():
        encoded_frames = tokenizer.encode(wav)
    return encoded_frames
src/model/modules/transformer.py ADDED
@@ -0,0 +1,690 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py, modified by Puyuan Peng 2024
2
+ import copy
3
+ import numbers
4
+ from functools import partial
5
+ from typing import Any, Callable, List, Optional, Tuple, Union
6
+
7
+ import torch
8
+ from torch import Tensor, nn
9
+ from torch.nn import functional as F
10
+
11
+ from .activation import MultiheadAttention
12
+ from .scaling import ActivationBalancer, BalancedDoubleSwish
13
+ from .scaling import BasicNorm as _BasicNorm
14
+
15
+ _shape_t = Union[int, List[int], torch.Size]
16
+
17
+
18
class LayerNorm(nn.Module):
    """Layer normalisation that also accepts an ``(input, embedding)`` tuple.

    For tuple input only the first element is normalised; the second is
    handed back unchanged alongside it.
    """

    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(
        self,
        normalized_shape: "_shape_t",
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        # A bare integer means a one-dimensional normalised shape.
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if not self.elementwise_affine:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        else:
            factory_kwargs = {"device": device, "dtype": dtype}
            self.weight = nn.Parameter(
                torch.empty(self.normalized_shape, **factory_kwargs)
            )
            self.bias = nn.Parameter(
                torch.empty(self.normalized_shape, **factory_kwargs)
            )

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Set the affine weight to ones and the bias to zeros, if present."""
        if not self.elementwise_affine:
            return
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)

    def _normalize(self, tensor: Tensor) -> Tensor:
        # Shared by both input forms of forward().
        return F.layer_norm(
            tensor, self.normalized_shape, self.weight, self.bias, self.eps
        )

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        """Normalise ``input``; for a tuple, return ``(normalised, embedding)``."""
        if not isinstance(input, tuple):
            assert embedding is None
            return self._normalize(input)

        tensor, embedding = input
        return (self._normalize(tensor), embedding)

    def extra_repr(self) -> str:
        return (
            f"{self.normalized_shape}, eps={self.eps}, "
            f"elementwise_affine={self.elementwise_affine}"
        )
82
+
83
+
84
class AdaptiveLayerNorm(nn.Module):
    r"""Adaptive Layer Normalization"""

    def __init__(self, d_model, norm) -> None:
        super().__init__()
        # One linear layer yields both the scale and the shift (2 * d_model).
        self.project_layer = nn.Linear(d_model, 2 * d_model)
        self.norm = norm
        self.d_model = d_model
        self.eps = self.norm.eps

    def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
        """Return ``weight * norm(input) + bias`` with weight/bias projected from ``embedding``.

        If ``input`` is an ``(input, embedding)`` tuple, the embedding is
        taken from it and the result is returned as ``(output, embedding)``.
        """
        packed = isinstance(input, tuple)
        if packed:
            input, embedding = input

        # Split the projection into two d_model-wide halves along the last dim.
        weight, bias = torch.split(
            self.project_layer(embedding),
            split_size_or_sections=self.d_model,
            dim=-1,
        )
        modulated = weight * self.norm(input) + bias
        if packed:
            return (modulated, embedding)
        return modulated
110
+
111
+
112
class BasicNorm(_BasicNorm):
    """``_BasicNorm`` variant that also accepts an ``(input, embedding)`` tuple."""

    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ):
        # ``device`` and ``dtype`` are accepted but not forwarded to the base class.
        super().__init__(d_model, eps=eps)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        """Normalise ``input``; for a tuple, return ``(normalised, embedding)``."""
        if not isinstance(input, tuple):
            assert embedding is None
            return super().forward(input)

        tensor, passthrough = input
        return (super().forward(tensor), passthrough)
132
+
133
+
134
class BalancedBasicNorm(nn.Module):
    """An ``ActivationBalancer`` followed by a ``BasicNorm``."""

    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ):
        super().__init__()
        balancer_options = dict(
            channel_dim=-1,
            min_positive=0.45,
            max_positive=0.55,
            max_abs=6.0,
        )
        self.balancer = ActivationBalancer(d_model, **balancer_options)
        self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        """Balance then normalise ``input``.

        For an ``(input, embedding)`` tuple only the tensor is balanced; the
        pair is then passed to the norm as a tuple.
        """
        if not isinstance(input, tuple):
            assert embedding is None
            return self.norm(self.balancer(input))

        tensor, embedding = input
        return self.norm((self.balancer(tensor), embedding))
159
+
160
+
161
class IdentityNorm(nn.Module):
    """No-op stand-in with the same constructor and call signature as the other norms."""

    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ) -> None:
        # All arguments are accepted for signature compatibility and ignored.
        super().__init__()

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        """Return ``input`` unchanged (tuple or tensor)."""
        # A separate embedding is only allowed when packed inside a tuple.
        if not isinstance(input, tuple):
            assert embedding is None
        return input
177
+
178
+
179
class TransformerEncoderLayer(nn.Module):
    """Self-attention + feed-forward encoder layer with optional cache pass-through.

    Differences from a plain encoder layer that are visible in this class:

    * ``src`` may be an ``(x, stage_embedding)`` tuple; the embedding is
      handed to both norms and returned alongside the output.
    * ``past`` is forwarded to the attention module, whose second return
      value (``present``, the kv-cache of the current step) is returned
      together with the output when it is not ``None``.
    * The linear and norm classes are injectable through the constructor.
    """

    __constants__ = ["batch_first", "norm_first"]

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
        batch_first: bool = False,
        norm_first: bool = False,
        device=None,
        dtype=None,
        linear1_self_attention_cls: nn.Module = nn.Linear,
        linear2_self_attention_cls: nn.Module = nn.Linear,
        linear1_feedforward_cls: nn.Module = nn.Linear,
        linear2_feedforward_cls: nn.Module = nn.Linear,
        layer_norm_cls: nn.Module = LayerNorm,
        layer_norm_eps: float = 1e-5,
        adaptive_layer_norm=False,
    ) -> None:
        """Build the attention, feed-forward, dropout and norm sub-modules.

        Args:
            d_model: Model width.
            nhead: Number of attention heads.
            dim_feedforward: Hidden width of the feed-forward block.
            dropout: Probability shared by all dropout layers.
            activation: A callable, an activation name (resolved through
                ``_get_activation_fn``), a ``functools.partial`` (called with
                ``d_model``), or the ``BalancedDoubleSwish`` class itself
                (instantiated with ``d_model``).
            batch_first: Forwarded to the attention module.
            norm_first: If True use pre-norm, otherwise post-norm.
            device, dtype: Factory kwargs for the created sub-modules.
            linear1_self_attention_cls, linear2_self_attention_cls:
                Linear classes forwarded to the attention module.
            linear1_feedforward_cls, linear2_feedforward_cls:
                Linear classes for the feed-forward block.
            layer_norm_cls: Norm class. With ``IdentityNorm`` the second norm
                becomes a ``BalancedBasicNorm``.
            layer_norm_eps: Epsilon for the norms.
            adaptive_layer_norm: Wrap both norms in ``AdaptiveLayerNorm``.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            linear1_cls=linear1_self_attention_cls,
            linear2_cls=linear2_self_attention_cls,
            **factory_kwargs,
        )

        # Implementation of Feedforward model
        self.linear1 = linear1_feedforward_cls(
            d_model, dim_feedforward, **factory_kwargs
        )
        self.dropout = nn.Dropout(dropout)
        self.linear2 = linear2_feedforward_cls(
            dim_feedforward, d_model, **factory_kwargs
        )

        self.norm_first = norm_first
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            activation = _get_activation_fn(activation)
        elif isinstance(activation, partial):
            activation = activation(d_model)
        elif activation == BalancedDoubleSwish:
            activation = BalancedDoubleSwish(d_model)
        self.activation = activation

        norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
        if layer_norm_cls == IdentityNorm:
            norm2 = BalancedBasicNorm(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )
        else:
            norm2 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )

        if adaptive_layer_norm:
            self.norm1 = AdaptiveLayerNorm(d_model, norm1)
            self.norm2 = AdaptiveLayerNorm(d_model, norm2)
        else:
            self.norm1 = norm1
            self.norm2 = norm2

    def __setstate__(self, state):
        super(TransformerEncoderLayer, self).__setstate__(state)
        # Restored state without an activation falls back to ReLU.
        if not hasattr(self, "activation"):
            self.activation = F.relu

    def forward(
        self,
        src: Tensor,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        need_weights: Optional[bool] = False,
        past: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required); either a
                tensor or an ``(x, stage_embedding)`` tuple.
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            need_weights: also return the attention weights. Tuple ``src``
                is not supported in this mode.
            past: cache forwarded to the attention module (optional).

        Returns:
            * ``need_weights=True``: ``(x, attn)``.
            * otherwise ``x``, wrapped as ``(x, stage_embedding)`` when
              ``src`` was a tuple, and further wrapped as ``[x, present]``
              when the attention module returned a non-``None`` ``present``.

        Raises:
            AssertionError: if ``src_key_padding_mask`` is neither bool nor
                floating point, or if ``need_weights`` is combined with a
                tuple ``src``.
        """
        x, stage_embedding = src, None
        is_src_tuple = False
        if isinstance(src, tuple):
            x, stage_embedding = src
            is_src_tuple = True

        if src_key_padding_mask is not None:
            _skpm_dtype = src_key_padding_mask.dtype
            if _skpm_dtype != torch.bool and not torch.is_floating_point(
                src_key_padding_mask
            ):
                raise AssertionError(
                    "only bool and floating types of key_padding_mask are supported"
                )

        if need_weights:
            if self.norm_first:
                out, attn = self._sa_block_attn(
                    self.norm1(x, stage_embedding),
                    src_mask,
                    src_key_padding_mask,
                    past,
                )
                out, present = out  # present is the kvcache of the present timestep
                x = x + out
                x = x + self._ff_block(self.norm2(x, stage_embedding))
            else:
                out, attn = self._sa_block_attn(
                    x, src_mask, src_key_padding_mask, past
                )
                out, present = out  # present is the kvcache of the present timestep
                x = self.norm1(
                    x + out,
                    stage_embedding,
                )
                x = self.norm2(x + self._ff_block(x), stage_embedding)
            assert not is_src_tuple
            return (x, attn)

        if self.norm_first:
            out, present = self._sa_block(
                self.norm1(x, stage_embedding),
                src_mask,
                src_key_padding_mask,
                past,
            )
            x = x + out
            x = x + self._ff_block(self.norm2(x, stage_embedding))
        else:
            # `past` belongs to the attention call. It used to be passed to
            # norm1 instead, which takes only (input, embedding) and so
            # raised a TypeError while the attention never saw the cache.
            out, present = self._sa_block(
                x, src_mask, src_key_padding_mask, past
            )
            x = self.norm1(x + out, stage_embedding)
            x = self.norm2(x + self._ff_block(x), stage_embedding)

        if is_src_tuple:
            x = (x, stage_embedding)
        if present is not None:
            x = [x, present]
        return x

    # self-attention block
    def _sa_block(
        self,
        x: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
        past: Optional[Tensor] = None,
    ) -> Tensor:
        """Run self-attention without weights; returns ``(dropout(x), present)``."""
        x = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
            past=past,
        )
        x, present = x
        return self.dropout1(x), present

    # self-attention block, also return attention weights
    def _sa_block_attn(
        self,
        x: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
        past: Optional[Tensor] = None,
    ) -> Tensor:
        """Run self-attention with weights; returns ``((dropout(x), present), attn)``."""
        x, attn = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=True,
            past=past,
        )
        x, present = x
        return (self.dropout1(x), present), attn

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        """linear1 -> activation -> dropout -> linear2 -> dropout2."""
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)
387
+
388
+
389
class TransformerEncoder(nn.Module):
    r"""A stack of N encoder layers with an optional final norm.

    Args:
        encoder_layer: the layer instance to clone (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component (optional).

    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = transformer_encoder(src)
    """
    __constants__ = ["norm"]

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def _run_collecting(self, src, mask, src_key_padding_mask, past, pick, **layer_kwargs):
        # Shared loop of the two "collecting" modes: feed each layer's raw
        # return value into the next layer, remember element ``pick`` of it,
        # and apply the final norm (if any) to the last raw value.
        collected = []
        output = src
        for layer in self.layers:
            output = layer(
                output,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                past=past,
                **layer_kwargs,
            )
            collected.append(output[pick])

        if self.norm is not None:
            output = self.norm(output)

        return collected, output

    def forward(
        self,
        src: Tensor,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        return_layer_states: bool = False,
        need_weights: Optional[bool] = False,
        past: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder (required).
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            return_layer_states: return ``(layer_states, output)`` where each
                state is element 0 of a layer's return value (optional).
            need_weights: return ``(layer_attn, output)`` where each entry is
                element 1 of a layer's return value (optional).
            past: cache. In the two collecting modes it is passed whole to
                every layer; otherwise layer ``i`` receives ``past[i]``.

        Returns:
            In the default mode the final output, or ``[output, all_present]``
            when at least one layer returned a list; ``all_present`` is the
            per-layer ``present`` values stacked on a new leading dimension.
        """
        if return_layer_states:
            assert not need_weights
            return self._run_collecting(
                src, mask, src_key_padding_mask, past, 0
            )

        if need_weights:
            assert not return_layer_states
            return self._run_collecting(
                src, mask, src_key_padding_mask, past, 1, need_weights=True
            )

        output = src
        presents = []
        for index, layer in enumerate(self.layers):
            layer_past = past[index] if past is not None else None
            output = layer(
                output,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                past=layer_past,
            )
            # A list return value signals ``[output, present]``.
            if isinstance(output, list):
                output, present = output
                presents.append(present)

        if self.norm is not None:
            output = self.norm(output)

        if presents:
            # (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
            stacked = torch.stack(presents, dim=0)
            output = [output, stacked]
        return output
483
+
484
+
485
class TransformerDecoderLayer(nn.Module):
    """Decoder layer: self-attention, attention over ``memory``, feed-forward.

    Each of the three sub-blocks has its own dropout and norm. With
    ``norm_first=True`` the norm is applied to the sub-block input (pre-norm);
    otherwise it is applied to the residual sum (post-norm).

    The norms are always called as ``norm(x, stage_embedding)``.
    NOTE(review): this presumes ``layer_norm_cls`` instances accept a second
    positional argument - confirm for any custom ``layer_norm_cls``.
    """

    __constants__ = ["batch_first", "norm_first"]

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
        linear1_self_attention_cls: nn.Module = nn.Linear,
        linear2_self_attention_cls: nn.Module = nn.Linear,
        linear1_feedforward_cls: nn.Module = nn.Linear,
        linear2_feedforward_cls: nn.Module = nn.Linear,
        batch_first: bool = False,
        norm_first: bool = False,
        device=None,
        dtype=None,
        layer_norm_cls: nn.Module = LayerNorm,
        layer_norm_eps: float = 1e-5,
        adaptive_layer_norm=False,
    ) -> None:
        """Create the attention, feed-forward, dropout and norm sub-modules.

        Args:
            d_model: model width passed to every sub-module.
            nhead: number of attention heads.
            dim_feedforward: hidden width of the feed-forward block.
            dropout: dropout probability used by all dropout layers and both
                attention modules.
            activation: ``"relu"``/``"gelu"``, a ``functools.partial`` that is
                called with ``d_model``, the ``BalancedDoubleSwish`` class, or
                any callable used as-is.
            linear1_self_attention_cls / linear2_self_attention_cls: classes
                forwarded to both attention modules as ``linear1_cls`` /
                ``linear2_cls``.
            linear1_feedforward_cls / linear2_feedforward_cls: classes used
                for the two feed-forward projections.
            batch_first: forwarded to both attention modules.
            norm_first: select pre-norm (``True``) or post-norm (``False``).
            device, dtype: factory kwargs forwarded to created sub-modules.
            layer_norm_cls: class used to build the norms.
            layer_norm_eps: ``eps`` for the norms.
            adaptive_layer_norm: wrap each norm in ``AdaptiveLayerNorm``.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        def build_attention():
            # Both attention modules share the exact same configuration.
            return MultiheadAttention(
                d_model,
                nhead,
                dropout=dropout,
                batch_first=batch_first,
                linear1_cls=linear1_self_attention_cls,
                linear2_cls=linear2_self_attention_cls,
                **factory_kwargs,
            )

        def build_norm():
            return layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)

        self.self_attn = build_attention()
        self.multihead_attn = build_attention()

        # Implementation of Feedforward model
        self.linear1 = linear1_feedforward_cls(
            d_model, dim_feedforward, **factory_kwargs
        )
        self.dropout = nn.Dropout(dropout)
        self.linear2 = linear2_feedforward_cls(
            dim_feedforward, d_model, **factory_kwargs
        )

        self.norm_first = norm_first
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            activation = _get_activation_fn(activation)
        elif isinstance(activation, partial):
            activation = activation(d_model)
        elif activation == BalancedDoubleSwish:
            activation = BalancedDoubleSwish(d_model)
        self.activation = activation

        if adaptive_layer_norm:
            # Build the three inner norms first, then wrap them, keeping the
            # same construction order as before.
            inner_first = build_norm()
            inner_second = build_norm()
            inner_third = build_norm()

            self.norm1 = AdaptiveLayerNorm(d_model, inner_first)
            self.norm2 = AdaptiveLayerNorm(d_model, inner_second)
            self.norm3 = AdaptiveLayerNorm(d_model, inner_third)
        else:
            self.norm1 = build_norm()
            self.norm2 = build_norm()
            if layer_norm_cls == IdentityNorm:
                self.norm3 = BalancedBasicNorm(
                    d_model, eps=layer_norm_eps, **factory_kwargs
                )
            else:
                self.norm3 = build_norm()

    def forward(
        self,
        tgt: Tensor,
        memory: Tensor,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Pass the inputs (and masks) through the decoder layer.

        Args:
            tgt: the sequence to the decoder layer, or a tuple
                ``(sequence, stage_embedding)``.
            memory: the sequence attended to by the second attention block.
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch
                (optional).

        Returns:
            The transformed sequence; when ``tgt`` was a tuple, the tuple
            ``(sequence, stage_embedding)`` is returned instead.
        """
        tgt_is_tuple = isinstance(tgt, tuple)
        if tgt_is_tuple:
            x, stage_embedding = tgt
        else:
            x, stage_embedding = tgt, None

        if self.norm_first:
            # Pre-norm: normalise the input of each sub-block.
            normed = self.norm1(x, stage_embedding)
            x = x + self._sa_block(normed, tgt_mask, tgt_key_padding_mask)

            normed = self.norm2(x, stage_embedding)
            x = x + self._mha_block(
                normed, memory, memory_mask, memory_key_padding_mask
            )

            normed = self.norm3(x, stage_embedding)
            x = x + self._ff_block(normed)
        else:
            # Post-norm: normalise the residual sum of each sub-block.
            residual = x + self._sa_block(x, tgt_mask, tgt_key_padding_mask)
            x = self.norm1(residual, stage_embedding)

            residual = x + self._mha_block(
                x, memory, memory_mask, memory_key_padding_mask
            )
            x = self.norm2(residual, stage_embedding)

            residual = x + self._ff_block(x)
            x = self.norm3(residual, stage_embedding)

        return (x, stage_embedding) if tgt_is_tuple else x

    # self-attention block
    def _sa_block(
        self,
        x: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
    ) -> Tensor:
        """Self-attention over ``x`` followed by ``dropout1``."""
        attn_result = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        return self.dropout1(attn_result[0])

    # multihead attention block
    def _mha_block(
        self,
        x: Tensor,
        mem: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
    ) -> Tensor:
        """Attention with ``x`` as query and ``mem`` as key/value, then ``dropout2``."""
        attn_result = self.multihead_attn(
            x,
            mem,
            mem,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        return self.dropout2(attn_result[0])

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        """linear1 -> activation -> dropout -> linear2 -> dropout3."""
        hidden = self.activation(self.linear1(x))
        projected = self.linear2(self.dropout(hidden))
        return self.dropout3(projected)
676
+
677
+
678
def _get_clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    clones = nn.ModuleList()
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return clones
680
+
681
+
682
def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
    """Map the legacy activation names ``"relu"`` / ``"gelu"`` to their functions.

    Raises:
        RuntimeError: for any other value.
    """
    supported = (("relu", F.relu), ("gelu", F.gelu))
    for name, fn in supported:
        if activation == name:
            return fn

    raise RuntimeError(
        "activation should be relu/gelu, not {}".format(activation)
    )
src/model/modules/voicecraft.py ADDED
@@ -0,0 +1,1999 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # cp from https://github.com/jasonppy/VoiceCraft/blob/master/models/voicecraft.py
2
+
3
+ import random
4
+
5
+ import numpy as np
6
+ import logging
7
+ import argparse, copy
8
+ from typing import Dict, Optional
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from torchmetrics.classification import MulticlassAccuracy
13
+
14
+ from .codebooks_patterns import DelayedPatternProvider
15
+
16
+ from ...utils.util import make_pad_mask
17
+
18
+ from .embedding import SinePositionalEmbedding, TokenEmbedding
19
+ from .transformer import (
20
+ LayerNorm,
21
+ TransformerEncoder,
22
+ TransformerEncoderLayer,
23
+ )
24
+
25
+ from argparse import Namespace
26
+ from huggingface_hub import PyTorchModelHubMixin
27
+
28
+
29
def top_k_top_p_filtering(
    logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    The input tensor is never modified; a filtered tensor is returned (the
    very same object is returned when no filtering is requested).

    Args:
        logits: logits with the vocabulary on the last dimension, typically
            of shape (batch size, vocabulary size).
        top_k: if > 0, keep only the ``top_k`` highest-scoring tokens
            (top-k filtering). Clamped to
            ``[min_tokens_to_keep, vocabulary size]``.
        top_p: if < 1.0, keep the smallest set of top tokens whose cumulative
            probability reaches ``top_p`` (nucleus filtering, Holtzman et al.,
            http://arxiv.org/abs/1904.09751).
        filter_value: value written over every removed position.
        min_tokens_to_keep: lower bound on the number of tokens kept per row.

    Returns:
        A tensor shaped like ``logits`` with removed positions set to
        ``filter_value``.

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Remove all tokens with a score below the k-th largest one.
        kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
        # masked_fill is out-of-place, so the caller's tensor stays intact.
        logits = logits.masked_fill(logits < kth_largest, filter_value)

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Never drop the first min_tokens_to_keep sorted tokens.
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift right so the first token crossing the threshold is kept too.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Scatter the sorted mask back to vocabulary order. Using the last
        # dimension (instead of a hard-coded 1) supports inputs of any rank.
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove
        )
        logits = logits.masked_fill(indices_to_remove, filter_value)
    return logits
66
+
67
+
68
def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
    """Sample one token per row after temperature scaling and top-k/top-p filtering.

    Args:
        logits: unnormalised scores with the vocabulary on the last dimension.
        top_k: number of highest-scoring tokens kept by the filter
            (defaults to 10).
        top_p: cumulative-probability threshold for nucleus filtering
            (defaults to 1.0).
        temperature: divisor applied to the logits before filtering; a higher
            temperature makes low-probability tokens more likely. Must be
            non-zero; the division is skipped when it equals 1.0.

    Returns:
        The result of ``torch.multinomial(..., num_samples=1)`` on the
        filtered, soft-maxed logits.
    """
    scaled = logits if temperature == 1.0 else logits / temperature
    filtered = top_k_top_p_filtering(scaled, top_k=top_k, top_p=top_p)
    probabilities = F.softmax(filtered, dim=-1)
    return torch.multinomial(probabilities, num_samples=1)
84
+
85
+
86
+ class VoiceCraft(
87
+ nn.Module,
88
+ PyTorchModelHubMixin,
89
+ library_name="voicecraft",
90
+ repo_url="https://github.com/jasonppy/VoiceCraft",
91
+ tags=["text-to-speech"],
92
+ ):
93
def __new__(
    cls, args: Optional[Namespace] = None, config: Optional[Dict] = None, **kwargs
) -> "VoiceCraft":
    """Normalise ``args``/``config`` before delegating to the parent ``__new__``.

    When the model is created from a ``Namespace``, it is converted with
    ``vars(args)`` into a dict ``config`` that is forwarded alongside
    ``args``, so the hub mixin can serialise it as config.json. This does
    not affect how the instance is initialised.

    Raises:
        ValueError: if both ``args`` and ``config`` are supplied.
    """
    if args is not None and config is not None:
        raise ValueError("Cannot provide both `args` and `config`.")
    if args is not None:
        config = vars(args)
    return super().__new__(cls, args=args, config=config, **kwargs)
103
+
104
def __init__(self, args: Optional[Namespace] = None, config: Optional[Dict] = None):
    """Build every sub-module of the model from ``args`` (or from ``config``).

    Args:
        args: hyper-parameter namespace. A shallow copy is stored on
            ``self.args`` and some fields are filled in / normalised on it.
        config: dict used to build the namespace when ``args`` is ``None``.

    Raises:
        ValueError: if neither ``args`` nor ``config`` is given.
    """
    super().__init__()

    # If loaded from HF Hub => convert config.json to Namespace args before initializing
    if args is None:
        if config is None:
            raise ValueError("Either `args` or `config` must be provided.")
        args = Namespace(**config)

    self.args = copy.copy(args)
    self.pattern = DelayedPatternProvider(n_q=self.args.n_codebooks)
    # Missing *or falsy* values fall back to the defaults below.
    if not getattr(self.args, "special_first", False):
        self.args.special_first = 0
    if not getattr(self.args, "n_special", False):
        self.args.n_special = 3
    self.args.eos = getattr(self.args, "eos", -1)
    # Constant column filled with the eog token id; stored as a frozen parameter.
    self.eog = nn.Parameter(
        torch.full((self.args.n_codebooks, 1), self.args.eog, dtype=torch.long),
        requires_grad=False,
    )  # [K 1]
    # self.eos only exists when a positive eos id is configured.
    if self.args.eos > 0:
        assert (
            self.args.eos != self.args.audio_pad_token
            and self.args.eos != self.args.empty_token
        ), self.args.eos
        self.eos = nn.Parameter(
            torch.full((self.args.n_codebooks, 1), self.args.eos, dtype=torch.long),
            requires_grad=False,
        )  # [K 1]
    if isinstance(self.args.audio_vocab_size, str):
        # NOTE(review): eval() executes arbitrary code found in the config
        # value; the config may originate from a downloaded config.json.
        # Consider a restricted parser if the value can be untrusted.
        self.args.audio_vocab_size = eval(self.args.audio_vocab_size)

    # One extra text id is reserved; the assert below ties it to the pad token.
    self.n_text_tokens = self.args.text_vocab_size + 1
    assert (
        self.args.text_pad_token == self.args.text_vocab_size
    ), f"self.args.text_vocab_size: {self.args.text_vocab_size}, self.args.text_pad_token: {self.args.text_pad_token}"

    self.n_audio_tokens = [
        self.args.audio_vocab_size + self.args.n_special
    ] * self.args.n_codebooks  # special tokens: empty token, EOG token, audio pad token
    # The three special ids must directly follow the regular audio vocabulary:
    # empty_token == V, eog == V + 1, audio_pad_token == V + 2.
    assert (
        self.args.audio_vocab_size == self.args.empty_token
    ), self.args.empty_token
    assert self.args.eog == self.args.audio_vocab_size + 1, self.args.eog
    assert (
        self.args.audio_pad_token == self.args.audio_vocab_size + 2
    ), self.args.audio_pad_token

    self.text_embedding = TokenEmbedding(
        dim_model=self.args.d_model,
        vocab_size=self.n_text_tokens,
        dropout=self.args.text_embedding_dropout,
    )

    # One embedding table per codebook.
    self.audio_embedding = nn.ModuleList(
        [
            TokenEmbedding(
                dim_model=self.args.audio_embedding_dim,
                vocab_size=self.n_audio_tokens[k],
                dropout=self.args.audio_embedding_dropout,
            )
            for k in range(self.args.n_codebooks)
        ]
    )
    # Learnable embedding per mask span slot (max_n_spans rows of width d_model).
    self.mask_embedding = nn.Parameter(
        torch.randn(self.args.max_n_spans, self.args.d_model), requires_grad=True
    )
    self.text_positional_embedding = SinePositionalEmbedding(
        self.args.d_model,
        dropout=self.args.text_positional_embedding_dropout,
        scale=False,
        alpha=True,  # learnable scaler, scale the volume of positional embedding
    )
    self.audio_positional_embedding = SinePositionalEmbedding(
        self.args.d_model,
        dropout=self.args.audio_positional_embedding_dropout,
        scale=False,
        alpha=True,  # learnable scaler, scale the volume of positional embedding
    )

    # The "decoder" is assembled from the encoder layer/stack classes.
    dec_layer = TransformerEncoderLayer(
        self.args.d_model,
        self.args.nhead,
        dim_feedforward=self.args.d_model * 4,
        dropout=self.args.trm_dropout,
        batch_first=True,
        norm_first=True,
        layer_norm_cls=LayerNorm,
    )
    self.decoder = TransformerEncoder(
        dec_layer,
        num_layers=self.args.num_decoder_layers,
        norm=LayerNorm(self.args.d_model),
    )

    # One two-layer prediction head per codebook.
    self.predict_layer = nn.ModuleList(
        [
            nn.Sequential(
                nn.Linear(self.args.d_model, self.args.audio_vocab_size // 2),
                nn.GELU(),
                nn.Linear(self.args.audio_vocab_size // 2, self.n_audio_tokens[k]),
            )
            for k in range(self.args.n_codebooks)
        ]
    )

    # One top-10 accuracy metric per codebook.
    self.accuracy_metrics = nn.ModuleList(
        [
            MulticlassAccuracy(
                self.n_audio_tokens[k],
                top_k=10,
                average="micro",
                multidim_average="global",
                ignore_index=None,
            )
            for k in range(self.args.n_codebooks)
        ]
    )
222
+
223
def prepare_mask_intervals(self, y_lens):
    """Sample masked spans for every sequence length in ``y_lens``.

    For each length, the number of spans is drawn according to
    ``self.args.mask_sample_dist`` (``"uniform"`` or ``"poisson<rate>"``),
    span starts are sampled without replacement, starts closer than
    ``self.args.min_gap`` to their predecessor are dropped, and a span length
    in ``[mask_len_min, mask_len_max]`` is drawn (shrunk when it would reach
    the next span or the end of the sequence).

    Args:
        y_lens: iterable of sequence lengths.

    Returns:
        ``(mask_intervals, non_mask_intervals)``: two lists with one entry per
        length. Each entry is a list of ``(start, end)`` pairs; the non-mask
        list is the complement of the mask list over ``[0, y_len)`` and has
        one more pair.

    Raises:
        ValueError: if ``mask_sample_dist`` is not recognised (previously this
            surfaced as an ``UnboundLocalError``), or - from ``random.sample``
            - if a sequence is too short for the requested number of spans.
    """
    mask_intervals = []
    non_mask_intervals = []
    dist = self.args.mask_sample_dist

    for y_len in y_lens:
        if dist == "uniform":
            n_spans = random.choice(range(1, self.args.max_n_spans + 1))
        elif "poisson" in dist.lower():
            # The rate follows the "poisson" prefix, e.g. "poisson1".
            param = float(dist[len("poisson") :])
            poisson_sample = torch.poisson(torch.tensor([param]))
            n_spans = int(poisson_sample.clamp(1, self.args.max_n_spans).item())
        else:
            raise ValueError(
                f"unsupported mask_sample_dist: {dist!r}, "
                "expected 'uniform' or 'poisson<rate>'"
            )

        starts = sorted(
            random.sample(range(1, y_len - 1 - self.args.mask_len_min), n_spans)
        )

        # Walk backwards so deletions do not disturb the indices still to visit.
        for j in range(len(starts) - 1, 0, -1):
            if starts[j] - starts[j - 1] < self.args.min_gap:
                del starts[j]  # If elements are too close, delete the later one
        assert (
            len(starts) > 0
        ), f"there is no masked span left, y_len: {y_len}, sampled n_spans: {n_spans}"

        # Distance from each start to the next start (or to the sequence end).
        temp_starts = starts + [y_len]
        gaps = [
            temp_starts[j + 1] - temp_starts[j] for j in range(len(temp_starts) - 1)
        ]

        ends = []
        for start, gap in zip(starts, gaps):
            mask_len = random.randint(
                self.args.mask_len_min, self.args.mask_len_max
            )
            if mask_len > gap - 1:
                # make sure the masks are not overlapping with each other
                mask_len = random.randint(1, gap - 1)
            ends.append(start + mask_len)

        mask_intervals.append(list(zip(starts, ends)))
        non_mask_intervals.append(list(zip([0] + ends, starts + [y_len])))

    return mask_intervals, non_mask_intervals
275
+
276
def rearrange(self, y, non_mask_intervals, mask_intervals):
    """Split every sample into its unmasked segments followed by its masked ones.

    ``y[i]`` is sliced along its last dimension with each ``(start, end)``
    pair. Masked segments always get ``self.eog`` appended. For unmasked
    segments the terminator depends on the configuration:

    * ``self.args.eos > 0`` (requires ``reduced_eog``): only the last
      unmasked segment is terminated, with ``self.eos``.
    * ``reduced_eog`` truthy: only the last unmasked segment is terminated,
      with ``self.eog`` (that is where the utterance actually ends).
    * otherwise: every unmasked segment is terminated with ``self.eog``.
      The original author marks this variant as not correct (TODO): the
      model cannot predict eog for non-final segments and this harms TTS
      experiments.

    Returns:
        A list (one entry per sample) of lists of tensors: unmasked segments
        first, masked segments after.
    """
    reduced_eog = getattr(self.args, "reduced_eog", 0)

    def cut(sample, span):
        return y[sample, :, span[0] : span[1]]

    def cut_and_terminate(sample, span, token):
        return torch.cat([cut(sample, span), token], dim=-1)

    rearranged_y = []
    for sample in range(len(y)):
        kept_spans = non_mask_intervals[sample]

        if self.args.eos > 0:
            assert reduced_eog
            final_token = self.eos
        elif reduced_eog:
            final_token = self.eog
        else:
            final_token = None

        if final_token is None:
            # eog is added to each section
            kept = [cut_and_terminate(sample, span, self.eog) for span in kept_spans]
        else:
            # only the last non-mask interval is terminated
            kept = [cut(sample, span) for span in kept_spans[:-1]]
            kept.append(cut_and_terminate(sample, kept_spans[-1], final_token))

        masked = [
            cut_and_terminate(sample, span, self.eog)
            for span in mask_intervals[sample]
        ]
        rearranged_y.append(kept + masked)
    return rearranged_y
341
+
342
def shift(self, rearranged_y):
    """Apply the codebook pattern to every segment of every sample.

    For each segment a pattern sized by the segment's ``shape[1]`` is
    requested from ``self.pattern``, and the segment (with a leading
    dimension added) is passed to ``build_pattern_sequence`` using
    ``self.args.empty_token`` as the special token.

    Returns:
        ``(shifted_y, patterns)``: per sample, the list of pattern-shifted
        segments (first element of each build result, leading dimension
        squeezed away) and the list of patterns used.
    """
    shifted_y = []
    patterns = []
    for segments in rearranged_y:
        # Request all patterns of the sample first, then build the sequences.
        sample_patterns = [
            self.pattern.get_pattern(segment.shape[1]) for segment in segments
        ]
        sample_shifted = []
        for segment_pattern, segment in zip(sample_patterns, segments):
            built = segment_pattern.build_pattern_sequence(
                z=segment.unsqueeze(0).contiguous(),
                special_token=self.args.empty_token,
                keep_only_valid_steps=False,
            )
            # the first item is values, later two are indexes and mask
            sample_shifted.append(built[0].squeeze(0))
        shifted_y.append(sample_shifted)
        patterns.append(sample_patterns)
    return shifted_y, patterns
362
+
363
def insert_mask(self, shifted_y):
    """Interleave a one-step placeholder between consecutive segments.

    ``self.eog`` (shape [K, 1]) is used as the placeholder; the real mask
    embedding is substituted later in ``embed_y``.

    Each sample must contain an odd number of segments (asserted). The mask
    embedding indices are the first ``num_masks`` entries of
    ``range(max_n_spans)`` (shuffled when ``shuffle_mask_embedding`` is set),
    repeated twice.

    Returns:
        ``(inserted_y, mask_position, mask_value)``, each a per-sample list:
        the interleaved tensors, the time offset of every placeholder in the
        future concatenation, and the mask embedding indices.
    """
    inserted_y = []
    mask_position = []
    mask_value = []
    for segments in shifted_y:
        n_segments = len(segments)
        num_masks = (n_segments - 1) // 2
        assert num_masks == (n_segments - 1) / 2, n_segments

        emb_inds = list(range(self.args.max_n_spans))
        if self.args.shuffle_mask_embedding:
            random.shuffle(emb_inds)
        chosen = emb_inds[:num_masks]
        mask_value.append(chosen + chosen)

        interleaved = []
        placeholder_at = []
        offset = 0  # total length (shape[1]) of what has been emitted so far
        for segment in segments[:-1]:
            interleaved.append(segment)
            offset += segment.shape[1]  # each item is of shape [K S]
            placeholder_at.append(offset)
            interleaved.append(self.eog)
            offset += self.eog.shape[1]
        interleaved.append(segments[-1])

        inserted_y.append(interleaved)
        mask_position.append(placeholder_at)
    return inserted_y, mask_position, mask_value
392
+
393
def cat_y(self, inserted_y, mask_position, y_lens):
    """Concatenate each sample's pieces along time and pad them into one batch.

    Args:
        inserted_y: per-sample list of tensors, concatenated along dim 1.
        mask_position: per-sample list of placeholder positions; only its
            length is used, for the length sanity check.
        y_lens: original length of each sample, used in the sanity check.

    Returns:
        ``(cated_y, new_y_lens)``: the padded batch permuted to ``[K, T, B]``
        (padding value ``self.args.audio_pad_token``) and a ``LongTensor``
        with every sample's concatenated length, on ``cated_y``'s device.

    Raises:
        AssertionError: if a concatenated length or the padded batch shape is
            not the expected one.
    """
    reduced_eog = getattr(self.args, "reduced_eog", 0)
    cated_y = []
    new_y_lens = []
    for i in range(len(inserted_y)):
        cur_cated_y = torch.cat(inserted_y[i], dim=1)  # [K S]
        cur_cated_y = cur_cated_y.transpose(1, 0)  # [S K]
        cur_cated_y_len = cur_cated_y.shape[0]

        # Expected length = original length + one step per placeholder
        # + n_codebooks extra steps per segment + the inserted end tokens.
        n_marks = len(mask_position[i])
        pattern_extra = (n_marks + 1) * self.args.n_codebooks
        if reduced_eog:
            eog_extra = n_marks / 2 + 1
            expected_len = y_lens[i] + n_marks + pattern_extra + eog_extra
            assert (
                cur_cated_y_len == expected_len
            ), f"cur_cated_y_len == {cur_cated_y_len}, but it should be y_lens[i] ({y_lens[i]}) + len(mask_position[i]) ({n_marks}) + (len(mask_position[i]) + 1) * self.args.n_codebooks ({pattern_extra}) + (len(mask_position[i])/2 + 1) ({eog_extra})={expected_len}"
        else:
            # the last term represents the eog token inserted per segment
            eog_extra = n_marks + 1
            expected_len = y_lens[i] + n_marks + pattern_extra + eog_extra
            assert (
                cur_cated_y_len == expected_len
            ), f"cur_cated_y_len == {cur_cated_y_len}, but it should be y_lens[i] ({y_lens[i]}) + len(mask_position[i]) ({n_marks}) + (len(mask_position[i]) + 1) * self.args.n_codebooks ({pattern_extra}) + (len(mask_position[i]) + 1) ({eog_extra})"
        new_y_lens.append(cur_cated_y_len)
        cated_y.append(cur_cated_y)

    cated_y = torch.nn.utils.rnn.pad_sequence(
        cated_y, batch_first=False, padding_value=self.args.audio_pad_token
    )
    expected_shape = torch.Size(
        [max(new_y_lens), len(inserted_y), self.args.n_codebooks]
    )
    # The message is built from the precomputed shape: the former message
    # had a misplaced parenthesis and raised TypeError instead of reporting
    # the assertion failure.
    assert (
        cated_y.shape == expected_shape
    ), f"cated_y.shape: {cated_y.shape}, but it should be {expected_shape}"
    cated_y = cated_y.permute(2, 0, 1)  # [T,B,K]->[K,T,B]
    assert cated_y.shape[0] == self.args.n_codebooks, cated_y.shape
    return cated_y, torch.LongTensor(new_y_lens).to(cated_y.device)
425
+
426
def embed_y(self, cated_y, mask_position, mask_value):
    """Embed the codebook tokens and overwrite placeholder steps with mask embeddings.

    Each codebook ``k`` is embedded with ``self.audio_embedding[k]``; the
    per-codebook embeddings are summed and the result is transposed to
    batch-first. For every sample with at least one mask position, the rows
    at ``mask_position[i]`` are replaced by
    ``self.mask_embedding[mask_value[i]]``.

    Args:
        cated_y: token tensor indexed as ``cated_y[k]`` per codebook
            ([K, T, B] per the caller).
        mask_position: per-sample list of time indices to overwrite.
        mask_value: per-sample list of rows of ``self.mask_embedding`` to use.

    Returns:
        The embedded batch, [B, T, D].
    """
    per_codebook = [
        self.audio_embedding[k](cated_y[k]) for k in range(self.args.n_codebooks)
    ]
    embedded_y = torch.stack(per_codebook, dim=0)  # [K, T, B, D]
    assert embedded_y.shape[0] == self.args.n_codebooks, embedded_y.shape
    assert embedded_y.shape[-1] == self.args.d_model, embedded_y.shape

    # [K,T,B,D] -> sum over codebooks -> [T,B,D] -> [B,T,D]
    embedded_y = embedded_y.sum(dim=0).transpose(1, 0)

    for sample in range(len(embedded_y)):
        positions = mask_position[sample]
        if len(positions) > 0:
            embedded_y[sample, positions] = self.mask_embedding[mask_value[sample]]
    return embedded_y
439
+
440
def prepare_input_target(self, y, y_lens):
    """Turn a batch of codes into decoder inputs, targets and masks.

    Pipeline: sample mask spans -> rearrange segments (these are also the
    targets) -> pattern-shift each segment -> insert mask placeholders ->
    concatenate and pad -> embed -> add positional embedding -> build the
    padding mask and the upper-triangular attention mask.

    Args:
        y: codes with the codebook dimension second (``y.shape[1]`` must be
            ``n_codebooks``).
        y_lens: length of each sample, forwarded to the span sampler.

    Returns:
        ``(y_input, new_y_lens, targets, y_padding_mask, y_attention_mask,
        mask_position, patterns)``.
    """
    # rearrange y
    # y is expected as [B K T], K is n_codebooks (checked by the assert below)
    assert y.shape[1] == self.args.n_codebooks, y.shape
    # sample mask_intervals
    mask_intervals, non_mask_intervals = self.prepare_mask_intervals(y_lens)

    # need to have EOG in each section (SOG will be generated by the pattern class)
    # but mask can be inserted later after we have shifted the input
    # y could be rearranged in this way:
    # [
    # [tensor[4, 12], tensor[4, 45], tensor[4, 102], tensor[4, 32], tensor[4, 22]],
    # [tensor[4, 44], tensor[4, 56], tensor[4, 19]],
    # ...
    # ]
    # for the first list of tensors (5 tensors), the first 3 tensors are the non-masked part, the last 2 are the masked part.
    # NOTE #non_masked_part = #masked_part + 1
    # NOTE *these are also the targets*
    # which segments get an eog (or eos) appended is decided inside rearrange()
    rearranged_y = self.rearrange(y, non_mask_intervals, mask_intervals)
    targets = rearranged_y  # each element in each sample is of shape [K T]
    assert targets[0][0].shape[0] == self.args.n_codebooks, targets[0][0].shape

    # next we need to apply pattern shifting to each tensor, after which, we'll replace the starting tokens of each section with a token that's different from the special padding token
    # [[5, 1, 2, 3, 4, 5, 5],
    # [5, 5, 1, 2, 3, 4, 5],
    # [5, 5, 5, 1, 2, 3, 4]]
    shifted_y, patterns = self.shift(rearranged_y)  # each element [K S]
    assert shifted_y[0][0].shape[0] == self.args.n_codebooks, shifted_y[0][0].shape[
        0
    ]

    # then, insert mask token at the intersection of each tensor (we want to decide the arrangement of the mask (shuffle or not)), we better have a separate nn.embedding for it
    # we also need to record the position of the inserted mask
    inserted_y, mask_position, mask_value = self.insert_mask(shifted_y)
    assert inserted_y[0][0].shape[0] == self.args.n_codebooks, inserted_y[0][
        0
    ].shape[0]
    assert inserted_y[0][1].shape == torch.Size(
        (self.args.n_codebooks, 1)
    ), f"this should be a mask, so should have shape {(self.args.n_codebooks, 1)}, but it's {inserted_y[0][1].shape}"

    # then concat tensors that belong to the same sample (in order) then get the length of each sample, and then stack them in batch dimension, pad them with pad_token
    cated_y, new_y_lens = self.cat_y(inserted_y, mask_position, y_lens)  # KTB
    assert cated_y.shape == torch.Size(
        (self.args.n_codebooks, cated_y.shape[1], len(inserted_y))
    )

    # embed; the mask placeholders are embedded separately inside embed_y
    embedded_y = self.embed_y(cated_y, mask_position, mask_value)  # BTD
    assert embedded_y.shape[1:] == torch.Size(
        (max(new_y_lens), self.args.d_model)
    ), embedded_y.shape

    # positional embedding
    y_input = self.audio_positional_embedding(embedded_y)

    # make attention mask and padding mask
    y_padding_mask = make_pad_mask(new_y_lens).to(y.device)
    # True strictly above the diagonal: those (row, column) pairs are the
    # future positions (presumably masked out by the decoder - confirm).
    y_attention_mask = (
        torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1)
        .bool()
        .to(y_padding_mask.device)
    )
    return (
        y_input,
        new_y_lens,
        targets,
        y_padding_mask,
        y_attention_mask,
        mask_position,
        patterns,
    )
513
+
514
def remove_mask(self, logits, mask_position, new_y_lens):
    """Cut each sample's logits into the segments that lie between mask positions.

    Args:
        logits: indexed as ``logits[b, :, start:stop]``; laid out as
            [B K S card] according to the original inline comment.
        mask_position: for every sample, a list of indices along the S axis
            that are occupied by mask tokens.
        new_y_lens: for every sample, the exclusive end index of the last
            segment.

    Returns:
        One list per sample. Each inner list holds the slices
        ``logits[b, :, start:stop]`` covering the positions strictly between
        two consecutive boundaries, where the boundaries are ``-1``, the mask
        positions, and ``new_y_lens[b]``.
    """
    kept_per_sample = []
    for sample_idx in range(len(logits)):
        # Sentinels: -1 in front of the first position, the sample length
        # behind the last one, so every segment has a boundary on both sides.
        boundaries = [-1] + mask_position[sample_idx] + [new_y_lens[sample_idx]]
        kept_segments = []
        for prev_boundary, next_boundary in zip(boundaries, boundaries[1:]):
            start, stop = prev_boundary + 1, next_boundary
            kept_segments.append(logits[sample_idx, :, start:stop])
        kept_per_sample.append(kept_segments)

    return kept_per_sample
527
+
528
def revert_pattern(self, patterns, logits_use):
    """Undo the pattern shift on every logits segment of every sample.

    Args:
        patterns: ``patterns[b]`` is the sequence of pattern objects for
            sample ``b``; each must provide ``revert_pattern_logits``.
        logits_use: ``logits_use[b]`` is the sequence of logits segments for
            sample ``b`` (each [K S card] per the original comments), in the
            same order as ``patterns[b]``.

    Returns:
        ``(logits_final, logit_masks)``. ``logits_final[b]`` lists the
        reverted segments (each [K, T, card] per the original comment) and
        ``logit_masks[b]`` lists the third element returned by
        ``revert_pattern_logits`` for the matching segment.
    """
    logits_final, logit_masks = [], []
    for sample_idx in range(len(logits_use)):
        reverted_segments = []
        segment_masks = []
        # Patterns and segments are paired positionally; if their order does
        # not match, the revert call is expected to raise.
        for pattern, segment in zip(patterns[sample_idx], logits_use[sample_idx]):
            # [K S card] -> [1 K S card] -> [1 card K S]
            card_first = segment.unsqueeze(0).permute(0, 3, 1, 2).contiguous()
            reverted = pattern.revert_pattern_logits(
                card_first, 0, keep_only_valid_steps=False
            )
            # [1 card K T] -> [1 K T card] -> [K T card]
            reverted_segments.append(reverted[0].permute(0, 2, 3, 1).squeeze(0))
            segment_masks.append(reverted[2])
        logits_final.append(reverted_segments)
        logit_masks.append(segment_masks)

    return logits_final, logit_masks
547
+
548
def dec_forward(
    self,
    x_input,
    x_lens,
    x_attention_mask,
    x_padding_mask,
    y_input,
    new_y_lens,
    y_attention_mask,
    y_padding_mask,
    past=None,
    last_3_tokens=False,
):
    """Run the decoder on the concatenation of the text (x) and audio (y) inputs.

    Args:
        x_input: embedded text input; concatenated with ``y_input`` along dim 1.
        x_lens: tensor of text lengths; only its maximum is used here.
        x_attention_mask: square boolean attention mask for the text part.
        x_padding_mask: boolean padding mask for the text part, concatenated
            with ``y_padding_mask`` along dim 1.
        y_input: embedded audio input.
        new_y_lens: tensor of audio lengths; only its maximum is used here.
        y_attention_mask: square boolean attention mask for the audio part.
        y_padding_mask: boolean padding mask for the audio part.
        past: ``None`` to run without a key/value cache. Otherwise a tensor;
            when it has more than 3 dimensions it is treated as a populated
            cache and only the trailing token(s) are fed to the decoder.
        last_3_tokens: with a populated cache, feed the last three positions
            instead of only the last one.

    Returns:
        ``(out, present)``. Without a cache ``present`` is ``None`` and
        ``out`` is the decoder output with the text positions removed. With a
        cache, ``present`` is whatever the decoder returned and the text
        positions are removed only if the output is longer than the text.
    """
    # Hoisted: the original recomputed these reductions at every use.
    x_max_len = x_lens.max()
    y_max_len = new_y_lens.max()

    # x attends to all of x and to none of y (figure 3 of the VALL-E paper).
    x_attn_mask = F.pad(
        x_attention_mask,
        (0, y_max_len),
        value=True,
    )
    # y attends to all of x (padding goes in front) and uses the given
    # triangular mask on itself to stay autoregressive.
    y_attn_mask = F.pad(
        y_attention_mask,
        (x_max_len, 0),
        value=False,
    )
    xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)

    # Merge key padding and attention masks.
    bsz, src_len = x_input.shape[0], x_max_len + y_max_len
    xy_padding_mask = torch.concat([x_padding_mask, y_padding_mask], dim=1)
    _xy_padding_mask = (
        xy_padding_mask.view(bsz, 1, 1, src_len)
        .expand(-1, self.args.nhead, -1, -1)
        .reshape(bsz * self.args.nhead, 1, src_len)
    )
    # The attention mask is 2-D while the padding mask has one row per
    # (sample, head); replicate the attention mask so the two can be combined.
    if xy_attn_mask.shape != _xy_padding_mask.shape:
        assert (
            xy_attn_mask.ndim + 1 == _xy_padding_mask.ndim
        ), f"xy_attn_mask.shape: {xy_attn_mask.shape}, _xy_padding_mask: {_xy_padding_mask.shape}"
        xy_attn_mask = xy_attn_mask.unsqueeze(0).repeat(
            _xy_padding_mask.shape[0], 1, 1
        )
    xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)

    # NOTE(review): xy_attn_mask is the result of logical_or, so zeros_like
    # creates a tensor of the same (boolean) dtype; filling it with -inf
    # presumably stores True rather than building an additive float mask.
    # Behaviour is kept as-is -- confirm this is what the decoder expects.
    new_attn_mask = torch.zeros_like(xy_attn_mask)
    new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
    xy_attn_mask = new_attn_mask

    xy_input = torch.cat([x_input, y_input], dim=1)

    # Identity check: `past` is either None or a tensor, and `==` on a tensor
    # is an element-wise operation, not a None test.
    if past is None:  # do not use kvcache
        out, _ = self.decoder((xy_input, None), mask=xy_attn_mask)
        return out[:, x_max_len:], None

    # Use kvcache. A cache with more than 3 dims is already populated, so only
    # the trailing tokens are passed; per the original note this does not yet
    # work with multi-span speech editing.
    if past.ndim > 3:
        if last_3_tokens:
            xy_input = xy_input[:, -3:]
            xy_attn_mask = xy_attn_mask[:, -3:]
        else:
            xy_input = xy_input[:, -1:]
            xy_attn_mask = xy_attn_mask[:, -1:]

    out, present = self.decoder((xy_input, None), mask=xy_attn_mask, past=past)
    if isinstance(out, tuple):  # get rid of stage_embedding
        out = out[0]

    if out.shape[1] > x_max_len:  # the first pass, cache not populated yet
        return out[:, x_max_len:], present
    return out, present  # cache was used
619
+
620
def forward(self, batch):
    """Compute the training loss and accuracy statistics for one batch.

    Args:
        batch: mapping with the keys below.
            x:
                A 2-D tensor of shape (N, S).
            x_lens:
                A 1-D tensor of shape (N,). It contains the number of tokens
                in `x` before padding.
            y:
                A 3-D tensor of shape (N, K, T), where K is the number of
                codebooks.
            y_lens:
                A 1-D tensor of shape (N,). It contains the number of tokens
                in `y` before padding.

    Returns:
        ``None`` when ``x`` is empty. Otherwise a dict with:
            loss: sum over codebooks of (mean cross entropy * number of
                tokens * codebook weight).
            top10acc: sum over codebooks of (accuracy metric * number of
                tokens).
            top10acc_by_codebook: the per-codebook terms of ``top10acc``.
            effective_ntoken: tensor holding the token count summed over all
                codebooks.
    """
    x, x_lens, y, y_lens = batch["x"], batch["x_lens"], batch["y"], batch["y_lens"]
    if len(x) == 0:
        return None
    # With gradient accumulation the current slice of x/y can be padded
    # beyond the longest sequence it actually contains; trim that off.
    x = x[:, : x_lens.max()]
    y = y[:, :, : y_lens.max()]
    assert x.ndim == 2, x.shape
    assert x_lens.ndim == 1, x_lens.shape
    assert y.ndim == 3 and y.shape[1] == self.args.n_codebooks, y.shape
    assert y_lens.ndim == 1, y_lens.shape

    # Attention mask and padding mask for x.
    x_padding_mask = make_pad_mask(x_lens).to(x.device)
    x_attention_mask = (
        torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1)
        .bool()
        .to(x_padding_mask.device)
    )
    x_input = self.text_embedding(x)
    x_input = self.text_positional_embedding(x_input)
    (
        y_input,
        new_y_lens,
        targets,
        y_padding_mask,
        y_attention_mask,
        mask_position,
        patterns,
    ) = self.prepare_input_target(y, y_lens)
    y_out = self.dec_forward(
        x_input,
        x_lens,
        x_attention_mask,
        x_padding_mask,
        y_input,
        new_y_lens,
        y_attention_mask,
        y_padding_mask,
    )
    y_out = y_out[0]  # no kv-caching during training, drop `present`
    assert (
        y_out.shape == y_input.shape
    ), f"y_out.shape: {y_out.shape}, y_input.shape: {y_input.shape}"  # [B S D]

    logits = torch.stack(
        [self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1
    )  # [B K S card]
    assert (
        logits.shape[1] == self.args.n_codebooks
        and logits.shape[3] == self.n_audio_tokens[0]
    ), logits.shape

    # Take out the mask tokens (using mask_position and new_y_lens) ...
    logits_use = self.remove_mask(logits, mask_position, new_y_lens)

    # ... and revert the pattern shift for each logits section in each sample.
    logits_final, logit_masks = self.revert_pattern(patterns, logits_use)
    assert (
        logits_final[0][0].shape[0] == self.args.n_codebooks
        and logits_final[0][0].shape[2] == self.n_audio_tokens[0]
    ), f"it is: {logits_final[0][0].shape}, but should be [K, T, card]"

    # Sanity check on one sample: logits and targets must line up section by
    # section, ignoring the trailing `card` axis of the logits.
    sample_to_test = 0
    assert len(logits_final[sample_to_test]) == len(
        targets[sample_to_test]
    ), f"{len(logits_final[sample_to_test])}, {len(targets[sample_to_test])}"
    temp = sum(
        [
            logits_final[sample_to_test][i].shape[:-1]
            != targets[sample_to_test][i].shape
            for i in range(len(targets[sample_to_test]))
        ]
    )
    assert (
        temp == 0
    ), f"none equal positions: {temp}, total number of elements: {len(targets[sample_to_test])}"

    # Every entry of every mask returned by the revert step must be truthy.
    logit_masked = sum(
        [(item == False).any() for cur_mask in logit_masks for item in cur_mask]
    )
    assert logit_masked == 0, logit_masks

    logits = torch.cat(
        [torch.cat(item, dim=1) for item in logits_final], dim=1
    )  # [K, T1+T2+T3+..., card]
    targets = torch.cat(
        [torch.cat(item, dim=1) for item in targets], dim=1
    )  # [K, T1+T2+T3+...]
    assert targets.shape[0] == logits.shape[0], f"{targets.shape}, {logits.shape}"

    # Per-codebook statistics: iterating dim 0 yields one codebook at a time.
    loss = []
    ntokens = []
    top10acc = []
    for k, (logit, target) in enumerate(zip(logits, targets)):
        loss.append(F.cross_entropy(logit, target, reduction="mean"))
        top10acc.append(self.accuracy_metrics[k](logit.detach(), target))
        ntokens.append(len(logit))

    all_ntokens = sum(ntokens)
    if self.args.codebook_weight is not None:
        # SECURITY NOTE(review): eval() executes arbitrary code. It is kept
        # because the accepted syntax of this config string is not visible
        # here; if it is always a plain list literal, ast.literal_eval would
        # be the safe replacement.
        codebook_weight = eval(self.args.codebook_weight)
    else:
        codebook_weight = [1.0] * self.args.n_codebooks
    # Mean losses are re-weighted by token count so the caller can normalise
    # by "effective_ntoken" across batches.
    loss = sum(
        [
            cur_loss * nt * cw
            for cur_loss, nt, cw in zip(loss, ntokens, codebook_weight)
        ]
    )
    top10acc_by_codebook = [t10a * nt for t10a, nt in zip(top10acc, ntokens)]
    top10acc = sum(top10acc_by_codebook)
    ntokens = torch.tensor(all_ntokens).to(logits.device)

    return {
        "loss": loss,
        "top10acc": top10acc,
        "top10acc_by_codebook": top10acc_by_codebook,
        "effective_ntoken": ntokens,
    }
748
+
749
def inference(
    self,
    x: torch.Tensor,
    x_lens: torch.Tensor,
    y: torch.Tensor,
    mask_interval: torch.Tensor,
    top_k: int = -100,
    top_p: float = 1.0,
    temperature: float = 1.0,
    stop_repetition: int = -1,
    kvcache: int = 1,
    silence_tokens: list[int] = [1388, 1898, 131],
) -> torch.Tensor:
    """Regenerate the masked spans of a single utterance (speech editing).

    Only single-sample inference is supported.

    Args:
        x:
            A 2-D tensor of shape (1, L).
        x_lens:
            A 1-D tensor of shape (1,). It contains the number of tokens in `x`
            before padding.
        y:
            A 3-D tensor of shape (1, T, K).
        mask_interval:
            A tensor of shape (1, M, 2) holding M (mask_start, mask_end)
            pairs. The function reads ``mask_interval.shape`` and asserts
            this shape, so a plain Python list is not accepted.
        top_k: (`optional`) int
            The number of highest probability tokens to keep for
            top-k-filtering. Default to -100.
        top_p: (`optional`) float
            For nucleus sampling.
        temperature: (`optional`) float
            The value used to modulate the next token probabilities. Must be
            strictly positive. Default to 1.0.
        stop_repetition: (`optional`) int
            If positive: once the previous first-codebook token is one of
            `silence_tokens` and has repeated more than this many times in a
            row, its logit is scaled down (multiplied when negative, divided
            otherwise) by ``count - (stop_repetition - 1)``. This only
            applies to the first codebook.
        kvcache: (`optional`) int
            If truthy, a key/value cache is threaded through `dec_forward`.
        silence_tokens: (`optional`) list of ints
            First-codebook token ids treated as silence. The default list is
            only read, never mutated.

    Returns:
        A tensor of shape (1, K, new_T): the unmasked parts of `y`
        interleaved with the generated spans.
    """
    assert x.ndim == 2, x.shape
    assert x_lens.ndim == 1, x_lens.shape
    assert y.ndim == 3, y.shape
    if self.args.special_first:
        y = y + int(self.args.n_special)
    y = y.transpose(2, 1)  # [1,T,K] -> [1,K,T]
    assert (
        y.shape[0] == 1 and y.shape[1] == self.args.n_codebooks
    ), y.shape  # there is no padding
    assert mask_interval.shape == torch.Size(
        (1, mask_interval.shape[1], 2)
    ), mask_interval

    # make x attention mask and x_input
    x_attention_mask = (
        torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1)
        .bool()
        .to(x.device)
    )
    x_input = self.text_embedding(x)
    x_input = self.text_positional_embedding(x_input)

    # make mask_interval and non_mask_interval
    y_len = y.shape[2]
    y_lens = torch.LongTensor([y_len]).to(y.device)
    mask_interval = mask_interval[0]
    starts = [item[0].item() for item in mask_interval] + [y_len]
    ends = [0] + [item[1].item() for item in mask_interval]
    # `mask_interval` is the input; `mask_intervals` is the same data as
    # Python tuples with one extra (sample) dimension.
    mask_intervals = [[(item[0].item(), item[1].item()) for item in mask_interval]]
    non_mask_intervals = [[(ns, ne) for ns, ne in zip(ends, starts)]]

    # rearrange y
    # each section gets an EOG (SOG will be generated by the pattern class),
    # the mask is inserted later, after the input has been shifted.
    # y could be rearranged in this way:
    # [
    # [tensor[4, 12], tensor[4, 45], tensor[4, 102], tensor[4, 32]], tensor[4, 22]],
    # [tensor[4, 44], tensor[4, 56], tensor[4, 19]],
    # ...
    # ]
    # for the first list of tensors, the first 3 tensors are the non-masked
    # part, the last 2 are the masked part.
    # NOTE #non_masked_part = #masked_part + 1
    rearranged_y = self.rearrange(y, non_mask_intervals, mask_intervals)
    assert rearranged_y[0][0].shape[0] == self.args.n_codebooks, rearranged_y[0][
        0
    ].shape

    # shift each element of y
    # apply pattern shifting to each tensor, after which the starting tokens
    # of each section are replaced with a token that is different from the
    # special padding token
    # [
    # [empty, 1, 2, 3, eog, empty, empty, empty],
    # [empty, empty, 1, 2, 3, eog, empty, empty],
    # [empty, empty, empty, 1, 2, 3, eog, empty],
    # [empty, empty, empty, empty, 1, 2, 3, eog]
    # ]
    # each element is [K S]; `patterns` is not used because the original
    # input y is reused directly when assembling the result.
    shifted_y, patterns = self.shift(rearranged_y)
    assert shifted_y[0][0].shape[0] == self.args.n_codebooks, shifted_y[0][0].shape

    # insert mask token at the intersection of each tensor, but *eog is
    # actually inserted as a place holder*; the position of each inserted
    # mask and the index of its mask embedding are recorded
    inserted_y, mask_position, mask_value = self.insert_mask(shifted_y)
    assert inserted_y[0][0].shape[0] == self.args.n_codebooks, inserted_y[0][
        0
    ].shape[0]
    assert inserted_y[0][1].shape == torch.Size(
        (self.args.n_codebooks, 1)
    ), f"this should be a mask, so should have shape {(self.args.n_codebooks, 1)}, but it's {inserted_y[0][1].shape}"

    # concat the tensors that belong to the same sample (in order), get the
    # length of each sample, stack in the batch dimension, pad with pad_token
    cated_y, new_y_lens = self.cat_y(inserted_y, mask_position, y_lens)  # KTB
    assert cated_y.shape == torch.Size(
        (self.args.n_codebooks, cated_y.shape[1], len(inserted_y))
    )
    assert not (cated_y == self.args.audio_pad_token).any(), cated_y

    ### NOTE this is different from forward, as we will remove the masked tokens
    ### say there are two masked region
    ### the cated_y should be like
    ### [empty a a a a mask0 empty b b b mask1 empty c c mask0 empty]
    ### which means we need to take the part after the last empty out
    num_mask = len(mask_position[0]) // 2
    assert num_mask == len(mask_position[0]) / 2, mask_position
    cated_y = cated_y[:, : mask_position[0][num_mask] + 2]  # of shape [K,T,B]
    # Mask-embedding indices for the spans after the first one; consumed in
    # the generation loop whenever a new span is started.
    more_mask_value = mask_value[0][num_mask + 1 :]
    new_y_lens[0] = mask_position[0][num_mask] + 2
    mask_position[0] = mask_position[0][: num_mask + 1]
    assert (
        mask_position[0][num_mask] + 2 == cated_y.shape[1]
    ), f"num_mask: {num_mask}, mask_position: {mask_position}, cated_y.shape: {cated_y.shape}"

    # embed: the mask tokens are embedded separately
    embedded_y = self.embed_y(
        cated_y, mask_position, [mask_value[0][: num_mask + 1]]
    )  # BTD

    # positional embedding
    y_input = self.audio_positional_embedding(embedded_y)

    # make attention mask and padding mask
    y_attention_mask = (
        torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1)
        .bool()
        .to(y.device)
    )

    x_padding_mask = torch.full((1, x_lens[0]), False).to(x.device)
    y_padding_mask = torch.full((1, new_y_lens[0]), False).to(y.device)

    codebook_eog = [False] * self.args.n_codebooks
    generated = []  # doesn't contain any empty_token, contains eog
    cur_generated = []
    # say 0 is empty, 4 is eog
    # tensor([[ 1, 2, 3, 4, 0, 0],
    #         [ 0, 1, 2, 3, 4, 0],
    #         [ 0, 0, 1, 2, 3, 4]])
    num_gen = []
    cur_num_gen = 0
    ##################### silence repetition handling #####################
    logging.info(
        f"silence tokens: {silence_tokens}, note that if you are not using the pretrained encodec 6f79c6a8, make sure you specified it yourself, rather than using the default"
    )
    consec_silence_count = 0
    prev_token = None
    ##################### silence repetition handling #####################
    # prepare the cache placeholder
    # n_layers, 2, bsz, num_heads, src_len, head_dim
    # The placeholder only has 3 dims; dec_forward treats a cache as populated
    # once it has more than 3.
    past = (
        torch.ones(
            [self.args.num_decoder_layers, 2, x.shape[0]],
            device=x.device,
            dtype=torch.float32,
        )
        if kvcache
        else None
    )
    # handle multi-span kv-cache
    new_masked_span = False

    def sample_helper(
        n_eog,
        logits,
        codebook_eog,
        top_k,
        top_p,
        temperature,
        prev_token,
        consec_silence_count,
        stop_repetition,
        silence_tokens,
        cur_num_gen,
    ):
        """Sample one token per codebook and update the end-of-generation state.

        `logits` is modified in place. Reads `y_input` and `x_lens` from the
        enclosing scope for the length cut-off.
        """
        if n_eog == 0:
            logits_adjust = logits
            # Only the first codebook may emit eog/empty at this point.
            for jj in range(1, self.args.n_codebooks):
                logits_adjust[jj][self.args.eog] = -10000
                logits_adjust[jj][self.args.empty_token] = -10000
            ##################### silence repetition handling #####################
            if (
                stop_repetition > 0
                and prev_token in silence_tokens
                and consec_silence_count > stop_repetition
            ):
                # Make the repeated silence token less likely regardless of
                # the sign of its logit.
                if logits_adjust[0, prev_token] < 0:
                    logits_adjust[0, prev_token] = logits_adjust[0, prev_token] * (
                        consec_silence_count - (stop_repetition - 1)
                    )
                else:
                    logits_adjust[0, prev_token] = logits_adjust[0, prev_token] / (
                        consec_silence_count - (stop_repetition - 1)
                    )
            ##################### silence repetition handling #####################
            if isinstance(logits_adjust, list):
                samples_list = []
                for logit in logits_adjust:
                    cur_sample = topk_sampling(
                        logit.unsqueeze(0),
                        top_k=top_k,
                        top_p=top_p,
                        temperature=temperature,
                    )  # [1, 1]
                    samples_list.append(cur_sample)
                samples = torch.cat(samples_list, dim=0)  # [K, 1]
            else:
                samples = topk_sampling(
                    logits_adjust, top_k=top_k, top_p=top_p, temperature=temperature
                )  # [K, 1]
            assert samples.shape == torch.Size(
                (self.args.n_codebooks, 1)
            ), f"samples.shape: {samples.shape}"
            # During the first K-1 steps of a span the later codebooks have
            # not started yet (delay pattern), so they are forced to empty.
            if cur_num_gen < self.args.n_codebooks - 1:
                for jj in range(1, self.args.n_codebooks - cur_num_gen):
                    samples[-jj, 0] = self.args.empty_token

            if (
                samples[0, 0] == self.args.eog
                or torch.argmax(logits[0], dim=-1) == self.args.eog
                or y_input.shape[1] > x_lens[0] * 10
            ):  # last one means y is already too long, shouldn't happen, but put it here
                samples[0, 0] = self.args.eog
                codebook_eog[0] = True
            ##################### silence repetition handling #####################
            if samples[0, 0] in silence_tokens and samples[0, 0] == prev_token:
                consec_silence_count += 1
            else:
                consec_silence_count = 0
            prev_token = samples[0, 0]
            ##################### silence repetition handling #####################
            return samples, codebook_eog, prev_token, consec_silence_count
        else:
            assert (
                sum(codebook_eog[i] for i in range(n_eog)) == n_eog
            ), f"codebook_eog: {codebook_eog}, but n_eog: {n_eog}"
            logits_adjust = logits
            for jj in range(n_eog + 1, self.args.n_codebooks):
                logits_adjust[jj][self.args.eog] = -10000
                logits_adjust[jj][self.args.empty_token] = -10000
            if isinstance(logits_adjust, list):
                samples_list = []
                for logit in logits_adjust:
                    cur_sample = topk_sampling(
                        logit.unsqueeze(0),
                        top_k=top_k,
                        top_p=top_p,
                        temperature=temperature,
                    )  # [1, 1]
                    samples_list.append(cur_sample)
                samples = torch.cat(samples_list, dim=0)  # [K, 1]
            else:
                samples = topk_sampling(
                    logits_adjust, top_k=top_k, top_p=top_p, temperature=temperature
                )  # [K, 1]
            # Codebooks that already ended emit empty; the next one ends now.
            for jj in range(n_eog):
                samples[jj, 0] = self.args.empty_token
            samples[n_eog, 0] = self.args.eog
            codebook_eog[n_eog] = True
            return samples, codebook_eog, prev_token, consec_silence_count

    while True:
        y_out, present = self.dec_forward(
            x_input,
            x_lens,
            x_attention_mask,
            x_padding_mask,
            y_input,
            new_y_lens,
            y_attention_mask,
            y_padding_mask,
            past=past,
            last_3_tokens=new_masked_span,
        )
        if new_masked_span:
            new_masked_span = False

        # Identity check: `past` is either None or a tensor.
        if past is not None:
            past = (
                torch.cat([past, present.to(past.dtype)], dim=-2)
                if past.ndim > 3
                else present.to(past.dtype)
            )

        y_out = y_out[:, -1:]  # only take the last one

        logits = torch.stack(
            [self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)],
            dim=1,
        )  # [B K S card], B==S==1, so [1 K 1 card]
        logits = logits.squeeze(0).squeeze(1)  # [K card]
        assert logits.shape == torch.Size(
            (self.args.n_codebooks, self.n_audio_tokens[0])
        ), f"{logits.shape}"

        n_eog = sum(codebook_eog)
        assert n_eog < self.args.n_codebooks
        if (
            self.args.eos > 0
        ):  # eos stands for end-of-sentence, which shouldn't be used as we are doing speech editing
            for jj in range(self.args.n_codebooks):
                logits[jj][self.args.eos] = -10000.0
        # a helper function handles the different n_eog cases
        samples, codebook_eog, prev_token, consec_silence_count = sample_helper(
            n_eog,
            logits,
            codebook_eog,
            top_k,
            top_p,
            temperature,
            prev_token,
            consec_silence_count,
            stop_repetition,
            silence_tokens,
            cur_num_gen,
        )
        cur_num_gen += 1
        cur_generated.append(samples.squeeze(-1))  # [K,1] -> [K]
        # get samples_emb
        samples_emb = torch.stack(
            [
                self.audio_embedding[k](samples[k])
                for k in range(self.args.n_codebooks)
            ],
            dim=0,
        )  # [K,1,D]
        samples_emb = samples_emb.sum(dim=0, keepdim=True)  # [1,1,D]

        if (
            sum(codebook_eog) == self.args.n_codebooks
        ):  # generation for the current span is done
            # re-init
            codebook_eog = [False] * self.args.n_codebooks
            num_gen.append(cur_num_gen)
            cur_num_gen = 0
            generated.append(cur_generated)
            cur_generated = []

            # if the current mask span is the last span, then all done
            # else append the next mask token and the empty token to start
            # the next generation
            if len(more_mask_value) > 0:
                next_mask_ind = more_mask_value.pop(0)
                mask_emb = (
                    self.mask_embedding[next_mask_ind].unsqueeze(0).unsqueeze(0)
                )  # [1,1,D]
                assert mask_emb.shape == torch.Size(
                    (1, 1, self.args.d_model)
                ), mask_emb.shape
                empty_token = torch.LongTensor([self.args.empty_token]).to(y.device)
                empty_emb = torch.stack(
                    [
                        self.audio_embedding[k](empty_token)
                        for k in range(self.args.n_codebooks)
                    ],
                    dim=0,
                ).sum(
                    dim=0, keepdim=True
                )  # [1,1,D]
                assert empty_emb.shape == torch.Size(
                    (1, 1, self.args.d_model)
                ), empty_emb.shape
                extra_emb = torch.cat([mask_emb, empty_emb], dim=1)  # [1,2,D]
                samples_emb = torch.cat(
                    [samples_emb, extra_emb], dim=1
                )  # [1,3,D] # prev_last_token, mask_token, empty token
                assert samples_emb.shape == torch.Size(
                    (1, 3, self.args.d_model)
                ), f"samples_emb.shape: {samples_emb.shape}"
                ##################### silence repetition handling #####################
                consec_silence_count = 0
                prev_token = None
                ##################### silence repetition handling #####################

                # handling kv-caching for multi-span editing: three new
                # positions were appended, so the next decoder call must be
                # fed the last three tokens.
                new_masked_span = True
            else:
                break
        else:
            assert samples_emb.shape == torch.Size(
                (1, 1, self.args.d_model)
            ), f"samples_emb.shape: {samples_emb.shape}"

        embedded_y = torch.cat([embedded_y, samples_emb], dim=1)
        # positional embedding
        y_input = self.audio_positional_embedding(embedded_y)  # [B T D]
        # make attention mask and padding mask
        y_attention_mask = (
            torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1)
            .bool()
            .to(y.device)
        )
        new_y_lens = torch.LongTensor([y_input.shape[1]]).to(y.device)
        y_padding_mask = torch.full((1, new_y_lens[0]), False).to(y.device)

    assert (
        len(generated) == num_mask
    ), f"len(generated): {len(generated)}, num_mask: {num_mask}"

    # combine non_masked_span with generated spans
    # first need to shift the generated part back
    flatten_gen = []
    for l, orig_span in enumerate(generated):
        span = torch.stack(orig_span, dim=0)  # [T K]
        span = span.transpose(1, 0)  # [K, T]
        assert span.shape[0] == self.args.n_codebooks, span.shape
        unshifted_span = []
        for j, s in enumerate(span):
            # Codebook j was delayed by j steps: drop its j leading entries
            # and the (K - j) trailing ones.
            start_from = j
            end_at = -(self.args.n_codebooks - start_from)
            unshifted_span.append(s[start_from:end_at])
        unshifted_span = torch.stack(unshifted_span, dim=0)

        assert (
            unshifted_span.shape[1] == num_gen[l] - self.args.n_codebooks
        ), f"len(unshifted_spans[0]): {len(unshifted_span[0])}, num_gen[l]: {num_gen[l]}"
        flatten_gen.append(unshifted_span)
    assert len(non_mask_intervals[0]) - 1 == len(
        flatten_gen
    ), f"len(non_mask_intervals[0]): {len(non_mask_intervals[0])}, len(flatten_gen): {len(flatten_gen)}"
    # Interleave: original, generated, original, ..., original.
    res = []
    for orig_interval, gen in zip(non_mask_intervals[0], flatten_gen):
        res.append(y[0, :, orig_interval[0] : orig_interval[1]])
        res.append(gen)
    res.append(y[0, :, non_mask_intervals[0][-1][0] : non_mask_intervals[0][-1][1]])
    res = torch.cat(res, dim=1).unsqueeze(0)  # [K,new_T] -> [1, K, new_T]

    expected_y_len = (
        y_len
        - sum([item[1] - item[0] for item in mask_intervals[0]])
        + sum([item - self.args.n_codebooks for item in num_gen])
    )
    assert res.shape == torch.Size(
        (1, self.args.n_codebooks, expected_y_len)
    ), f"res.shape: {res.shape}, expected_y_len: {expected_y_len}. y_len - sum([item[1] - item[0] for item in mask_interval]) + sum([item - self.args.n_codebooks for item in num_gen]): {y_len}-{sum([item[1] - item[0] for item in mask_interval])} + {sum([item - self.args.n_codebooks for item in num_gen])}"

    if self.args.special_first:
        res = res - int(self.args.n_special)

    return res
1227
+
1228
+ def inference_tts(
1229
+ self,
1230
+ x: torch.Tensor,
1231
+ x_lens: torch.Tensor,
1232
+ y: torch.Tensor,
1233
+ top_k: int = -100,
1234
+ top_p: float = 1.0,
1235
+ temperature: float = 1.0,
1236
+ stop_repetition: int = 3,
1237
+ kvcache: int = 1,
1238
+ silence_tokens: list[int] = [1388, 1898, 131],
1239
+ *kargs,
1240
+ ) -> torch.Tensor:
1241
+ """
1242
+ different from inference_tts, this implementation uses kvcache, which should have significant speed up
1243
+ Args:
1244
+ x:
1245
+ A 2-D tensor of shape (1, L).
1246
+ x_lens:
1247
+ A 1-D tensor of shape (1,). It contains the number of tokens in `x`
1248
+ before padding.
1249
+ y:
1250
+ A 3-D tensor of shape (1, T, K).
1251
+ top_k: (`optional`) int
1252
+ The number of highest probability tokens to keep for top-k-filtering. Default to -100.
1253
+ top_p: (`optional`) float
1254
+ For Neucleus sampling
1255
+ temperature: (`optional`) float
1256
+ The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
1257
+ """
1258
+ eog_inference = self.args.eos if self.args.eos > 0 else self.args.eog
1259
+ assert x.ndim == 2, x.shape
1260
+ assert x_lens.ndim == 1, x_lens.shape
1261
+ assert y.ndim == 3, y.shape
1262
+ if self.args.special_first:
1263
+ y = y + int(self.args.n_special)
1264
+ y = y.transpose(2, 1) # [1,T,K] -> [1,K,T]
1265
+ assert (
1266
+ y.shape[0] == 1 and y.shape[1] == self.args.n_codebooks
1267
+ ), y.shape # there is no padding
1268
+
1269
+ # make x attention mask and x_input
1270
+ x_attention_mask = (
1271
+ torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1)
1272
+ .bool()
1273
+ .to(x.device)
1274
+ )
1275
+ # x_attention_mask = torch.zeros(x.shape[1], x.shape[1]).bool().to(x.device)
1276
+ x_input = self.text_embedding(x)
1277
+ x_input = self.text_positional_embedding(x_input)
1278
+
1279
+ y_len = y.shape[2]
1280
+ y_lens = torch.LongTensor([y_len]).to(y.device)
1281
+
1282
+ # rearrange y, we don't add eog to the end, this doesn't actually do anything in the tts scenario
1283
+ rearranged_y = [[y[0]]]
1284
+ assert rearranged_y[0][0].shape[0] == self.args.n_codebooks, rearranged_y[0][
1285
+ 0
1286
+ ].shape
1287
+
1288
+ # shift y to create the delayed pattern
1289
+ shifted_y, patterns = self.shift(
1290
+ rearranged_y
1291
+ ) # each element [K S], patterns is not used, as we directly use the original input y
1292
+ assert shifted_y[0][0].shape[0] == self.args.n_codebooks, shifted_y[0][0].shape
1293
+ assert len(shifted_y[0]) == 1, len(shifted_y[0])
1294
+
1295
+ # below is different from forward or inference
1296
+ # where we cut this shifted part
1297
+ shifted_y[0][0] = shifted_y[0][0][:, : -(self.args.n_codebooks - 1)]
1298
+ assert (
1299
+ not (
1300
+ shifted_y[0][0][self.args.n_codebooks :] == self.args.empty_token
1301
+ ).any()
1302
+ and not (shifted_y[0][0][self.args.n_codebooks :] == self.args.eog).any()
1303
+ ), shifted_y[0][0]
1304
+
1305
+ # next section in inference is insert mask at the intersection of each tensor in a sample, but we don't need to do that
1306
+ # next section is concate tensors of each sample to one tensor, which we also don't need
1307
+ cated_y = shifted_y[0][0].unsqueeze(-1) # [K,S]->[K,S,B]
1308
+ new_y_lens = torch.LongTensor([cated_y.shape[1]]).to(cated_y.device)
1309
+ assert cated_y.shape == torch.Size((self.args.n_codebooks, cated_y.shape[1], 1))
1310
+ assert not (cated_y == self.args.audio_pad_token).any(), cated_y
1311
+
1312
+ # replace tokens in y with the embeddings, add sum codebooks up
1313
+ embedded_y = torch.stack(
1314
+ [self.audio_embedding[k](cated_y[k]) for k in range(self.args.n_codebooks)],
1315
+ dim=0,
1316
+ ) # [K, S, B, D]
1317
+ assert embedded_y.shape[0] == self.args.n_codebooks, embedded_y.shape
1318
+ assert embedded_y.shape[-1] == self.args.d_model, embedded_y.shape
1319
+ embedded_y = embedded_y.sum(dim=0) # [K,S,B,D]->[S,B,D]
1320
+ embedded_y = embedded_y.transpose(1, 0) # [S,B,D]->[B,S,D]
1321
+
1322
+ # positional embedding
1323
+ y_input = self.audio_positional_embedding(embedded_y)
1324
+
1325
+ # make attention mask and padding mask
1326
+ y_attention_mask = (
1327
+ torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1)
1328
+ .bool()
1329
+ .to(y.device)
1330
+ )
1331
+
1332
+ x_padding_mask = torch.full((1, x_lens[0]), False).to(x.device)
1333
+ y_padding_mask = torch.full((1, new_y_lens[0]), False).to(y.device)
1334
+
1335
+ # entering the generation stage
1336
+ # starting from line 708
1337
+ codebook_eog = [False] * self.args.n_codebooks
1338
+ generated = [] # doesn't contain any empty token, contain eog
1339
+ cur_generated = []
1340
+ # say 0 is empty, 4 is eog
1341
+ # tensor([[ 1, 2, 3, 4, 0, 0],
1342
+ # [ 0, 1, 2, 3, 4, 0],
1343
+ # [ 0, 0, 1, 2, 3, 4]])
1344
+ num_gen = []
1345
+ cur_num_gen = 0
1346
+ ##################### silence repetition handling #####################
1347
+ ##################### silence repetition handling #####################
1348
+ logging.info(
1349
+ f"silence tokens: {silence_tokens}, note that if you are not using the pretrained encodec 6f79c6a8, make sure you specified it yourself, rather than using the default"
1350
+ )
1351
+ consec_silence_count = 0
1352
+ prev_token = None
1353
+ ##################### silence repetition handling #####################
1354
+ ##################### silence repetition handling #####################
1355
+
1356
+ # prepare the cache placeholder
1357
+ # n_layers, 2, bsz, num_heads, src_len, head_dim
1358
+ past = (
1359
+ torch.ones(
1360
+ [self.args.num_decoder_layers, 2, x.shape[0]],
1361
+ device=x.device,
1362
+ dtype=torch.float32,
1363
+ )
1364
+ if kvcache
1365
+ else None
1366
+ )
1367
+
1368
+ # logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
1369
+ # logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
1370
+ # logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
1371
+ def sample_helper(
1372
+ n_eog,
1373
+ logits,
1374
+ codebook_eog,
1375
+ top_k,
1376
+ top_p,
1377
+ temperature,
1378
+ prev_token,
1379
+ consec_silence_count,
1380
+ stop_repetition,
1381
+ silence_tokens,
1382
+ cur_num_gen,
1383
+ ):
1384
+ if n_eog == 0:
1385
+ logits_adjust = logits
1386
+ for jj in range(1, self.args.n_codebooks):
1387
+ logits_adjust[jj][eog_inference] = -10000
1388
+ logits_adjust[jj][self.args.empty_token] = -10000
1389
+ if (
1390
+ cur_num_gen <= self.args.encodec_sr // 5
1391
+ ): # this shouldn't happen, but just in case the model stopped too early
1392
+ logits_adjust[0][eog_inference] = -10000
1393
+ ##################### silence repetition handling #####################
1394
+ if (
1395
+ stop_repetition > 0
1396
+ and prev_token in silence_tokens
1397
+ and consec_silence_count > stop_repetition
1398
+ ):
1399
+ if logits_adjust[0, prev_token] < 0:
1400
+ logits_adjust[0, prev_token] = logits_adjust[0, prev_token] * (
1401
+ consec_silence_count - (stop_repetition - 1)
1402
+ )
1403
+ else:
1404
+ logits_adjust[0, prev_token] = logits_adjust[0, prev_token] / (
1405
+ consec_silence_count - (stop_repetition - 1)
1406
+ )
1407
+ ##################### silence repetition handling #####################
1408
+ samples = topk_sampling(
1409
+ logits_adjust, top_k=top_k, top_p=top_p, temperature=temperature
1410
+ ) # [K, 1]
1411
+ assert samples.shape == torch.Size(
1412
+ (self.args.n_codebooks, 1)
1413
+ ), f"samples.shape: {samples.shape}"
1414
+ if cur_num_gen < self.args.n_codebooks - 1:
1415
+ for jj in range(1, self.args.n_codebooks - cur_num_gen):
1416
+ samples[-jj, 0] = self.args.empty_token
1417
+
1418
+ if (
1419
+ samples[0, 0] == eog_inference
1420
+ or torch.argmax(logits[0], dim=-1) == eog_inference
1421
+ or y_input.shape[1] > x_lens[0] * (self.args.encodec_sr // 5)
1422
+ ): # last one means y is already too long, shouldn't happen, but put it here
1423
+ samples[0, 0] = eog_inference
1424
+ codebook_eog[0] = True
1425
+ ##################### silence repetition handling #####################
1426
+ if samples[0, 0] in silence_tokens and samples[0, 0] == prev_token:
1427
+ consec_silence_count += 1
1428
+ else:
1429
+ consec_silence_count = 0
1430
+ prev_token = samples[0, 0]
1431
+ ##################### silence repetition handling #####################
1432
+ return samples, codebook_eog, prev_token, consec_silence_count
1433
+ else:
1434
+ assert (
1435
+ sum(codebook_eog[i] for i in range(n_eog)) == n_eog
1436
+ ), f"codebook_eog: {codebook_eog}, but n_eog: {n_eog}"
1437
+ logits_adjust = logits
1438
+ for jj in range(n_eog + 1, self.args.n_codebooks):
1439
+ logits_adjust[jj][eog_inference] = -10000
1440
+ logits_adjust[jj][self.args.empty_token] = -10000
1441
+ samples = topk_sampling(
1442
+ logits_adjust, top_k=top_k, top_p=top_p, temperature=temperature
1443
+ ) # [K, 1]
1444
+ for jj in range(n_eog):
1445
+ samples[jj, 0] = self.args.empty_token
1446
+ samples[n_eog, 0] = eog_inference
1447
+ codebook_eog[n_eog] = True
1448
+ return samples, codebook_eog, prev_token, consec_silence_count
1449
+
1450
+ while True:
1451
+ y_out, present = self.dec_forward(
1452
+ x_input,
1453
+ x_lens,
1454
+ x_attention_mask,
1455
+ x_padding_mask,
1456
+ y_input,
1457
+ new_y_lens,
1458
+ y_attention_mask,
1459
+ y_padding_mask,
1460
+ past=past,
1461
+ )
1462
+ if past != None:
1463
+ past = (
1464
+ torch.cat([past, present.to(past.dtype)], dim=-2)
1465
+ if past.ndim > 3
1466
+ else present.to(past.dtype)
1467
+ )
1468
+
1469
+ y_out = y_out[:, -1:] # only take the last token
1470
+ logits = torch.stack(
1471
+ [self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)],
1472
+ dim=1,
1473
+ ) # [B K S card], B==S==1, so [1 K 1 card]
1474
+ logits = logits.squeeze(0).squeeze(1) # [K card]
1475
+ assert logits.shape == torch.Size(
1476
+ (self.args.n_codebooks, self.n_audio_tokens[0])
1477
+ ), f"{logits.shape}"
1478
+
1479
+ n_eog = sum(codebook_eog)
1480
+ assert n_eog < self.args.n_codebooks
1481
+ if (
1482
+ self.args.eos > 0
1483
+ ): # if we are using end-of-sentence token (which is used by default), eog shouldn't be used here, as there is no masked spans
1484
+ for jj in range(self.args.n_codebooks):
1485
+ logits[jj][self.args.eog] = -10000.0
1486
+
1487
+ samples, codebook_eog, prev_token, consec_silence_count = sample_helper(
1488
+ n_eog,
1489
+ logits,
1490
+ codebook_eog,
1491
+ top_k,
1492
+ top_p,
1493
+ temperature,
1494
+ prev_token,
1495
+ consec_silence_count,
1496
+ stop_repetition,
1497
+ silence_tokens,
1498
+ cur_num_gen,
1499
+ )
1500
+
1501
+ cur_num_gen += 1
1502
+ cur_generated.append(samples.squeeze(-1)) # [K,1] -> [K]
1503
+
1504
+ # samples.shape is [K,1]
1505
+ # ge samples_emb
1506
+ samples_emb = torch.stack(
1507
+ [
1508
+ self.audio_embedding[k](samples[k])
1509
+ for k in range(self.args.n_codebooks)
1510
+ ],
1511
+ dim=0,
1512
+ ) # [K,1,D]
1513
+ samples_emb = samples_emb.sum(dim=0, keepdim=True) # [1,1,D]
1514
+
1515
+ if (
1516
+ sum(codebook_eog) == self.args.n_codebooks
1517
+ ): # generation for the current span is done
1518
+ codebook_eog = [False] * self.args.n_codebooks
1519
+ num_gen.append(cur_num_gen)
1520
+ cur_num_gen = 0
1521
+ generated.append(cur_generated)
1522
+ cur_generated = []
1523
+ break
1524
+ else:
1525
+ assert samples_emb.shape == torch.Size(
1526
+ (1, 1, self.args.d_model)
1527
+ ), f"samples_emb.shape: {samples_emb.shape}"
1528
+
1529
+ embedded_y = torch.cat([embedded_y, samples_emb], dim=1)
1530
+ y_input = self.audio_positional_embedding(embedded_y) # [B T D]
1531
+ # make attention mask and padding mask
1532
+ y_attention_mask = (
1533
+ torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1)
1534
+ .bool()
1535
+ .to(y.device)
1536
+ )
1537
+ new_y_lens = torch.LongTensor([y_input.shape[1]]).to(y.device)
1538
+ y_padding_mask = torch.full((1, new_y_lens[0]), False).to(y.device)
1539
+
1540
+ assert len(generated) == 1, f"len(generated): {len(generated)}"
1541
+
1542
+ # revert the pattern
1543
+ flatten_gen = []
1544
+ for l, orig_span in enumerate(generated):
1545
+ span = torch.stack(orig_span, dim=0) # [T, K]
1546
+ span = span.transpose(1, 0) # [K, T]
1547
+ assert span.shape[0] == self.args.n_codebooks, span.shape
1548
+ unshifted_span = []
1549
+ for j, s in enumerate(span):
1550
+ start_from = j
1551
+ end_at = -(self.args.n_codebooks - start_from)
1552
+ unshifted_span.append(s[start_from:end_at])
1553
+ unshifted_span = torch.stack(unshifted_span, dim=0)
1554
+
1555
+ assert (
1556
+ unshifted_span.shape[1] == num_gen[l] - self.args.n_codebooks
1557
+ ), f"len(unshifted_spans[0]): {len(unshifted_span[0])}, num_gen[l]: {num_gen[l]}"
1558
+
1559
+ flatten_gen.append(unshifted_span)
1560
+ assert len(flatten_gen) == 1, len(flatten_gen)
1561
+
1562
+ # combine
1563
+ res = [y[0], flatten_gen[0]]
1564
+ res = torch.cat(res, dim=1).unsqueeze(0) # [K, new_t] -> [1, K, new_T]
1565
+
1566
+ expected_y_len = y_len + sum([item - self.args.n_codebooks for item in num_gen])
1567
+ assert res.shape == torch.Size(
1568
+ (1, self.args.n_codebooks, expected_y_len)
1569
+ ), f"res.shape: {res.shape}, expected_y_len: {expected_y_len}. y_len + sum([item - self.args.n_codebooks for item in num_gen]): {y_len} + {sum([item - self.args.n_codebooks for item in num_gen])}"
1570
+
1571
+ if self.args.special_first:
1572
+ res = res - int(self.args.n_special)
1573
+ flatten_gen = flatten_gen - int(self.args.n_special)
1574
+
1575
+ return res, flatten_gen[0].unsqueeze(0)
1576
+
1577
def inference_tts_batch(
    self,
    x: torch.Tensor,
    x_lens: torch.Tensor,
    y: torch.Tensor,
    top_k: int = -100,
    top_p: float = 1.0,
    temperature: float = 1.0,
    stop_repetition: int = 3,
    kvcache: int = 1,
    batch_size: int = 5,
    silence_tokens: list[int] = [1388, 1898, 131],
    *kargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Batched TTS decoding: the single input example is repeated `batch_size`
    times, all copies are sampled in parallel, and as soon as one copy emits
    the end-of-generation token only that copy (`keep`) is finished and
    returned. If several copies emit it in the same step, the last one wins.

    Args:
        x:
            A 2-D tensor of shape (1, L).
        x_lens:
            A 1-D tensor of shape (1,). It contains the number of tokens in `x`
            before padding.
        y:
            A 3-D tensor of shape (1, T, K).
        top_k: (`optional`) int
            The number of highest probability tokens to keep for top-k-filtering. Default to -100.
        top_p: (`optional`) float
            For nucleus sampling.
        temperature: (`optional`) float
            The value used to modulate the next token probabilities. Must be strictly positive. Default to 1.0.
        stop_repetition: (`optional`) int
            When > 0, the logit of a silence token that has been repeated more
            than this many consecutive times is progressively down-weighted.
        kvcache: (`optional`) int
            Truthy to allocate and reuse the decoder key/value cache.
        batch_size: (`optional`) int
            Number of parallel copies of the example that are sampled.
        silence_tokens: (`optional`) list[int]
            First-codebook token ids treated as silence. Only read, never
            mutated, so the shared default list is safe here.

    Returns:
        A tuple `(res, generated)`: `res` is the prompt codes concatenated
        with the newly generated codes, shape (1, K, T + new_T); `generated`
        is only the newly generated codes, shape (1, K, new_T).
    """
    eog_inference = self.args.eos if self.args.eos > 0 else self.args.eog
    assert x.ndim == 2, x.shape
    assert x_lens.ndim == 1, x_lens.shape
    assert y.ndim == 3, y.shape
    if self.args.special_first:
        y = y + int(self.args.n_special)
    y = y.transpose(2, 1)  # [1,T,K] -> [1,K,T]
    assert (
        y.shape[0] == 1 and y.shape[1] == self.args.n_codebooks
    ), y.shape  # there is no padding

    # make x attention mask and x_input
    x_attention_mask = (
        torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1)
        .bool()
        .to(x.device)
    )
    x_input = self.text_embedding(x)
    x_input = self.text_positional_embedding(x_input)

    y_len = y.shape[2]

    # rearrange y, we don't add eog to the end, this doesn't actually do anything in the tts scenario
    rearranged_y = [[y[0]]]
    assert (
        rearranged_y[0][0].shape[0] == self.args.n_codebooks
    ), rearranged_y[0][0].shape

    # shift y to create the delayed pattern
    # each element is [K S]; the returned patterns are not used, as we directly use the original input y
    shifted_y, _patterns = self.shift(rearranged_y)
    assert shifted_y[0][0].shape[0] == self.args.n_codebooks, shifted_y[0][0].shape
    assert len(shifted_y[0]) == 1, len(shifted_y[0])

    # below is different from forward or inference
    # where we cut this shifted part
    shifted_y[0][0] = shifted_y[0][0][:, : -(self.args.n_codebooks - 1)]
    assert (
        not (
            shifted_y[0][0][self.args.n_codebooks :] == self.args.empty_token
        ).any()
        and not (shifted_y[0][0][self.args.n_codebooks :] == self.args.eog).any()
    ), shifted_y[0][0]

    # no mask insertion / per-sample concatenation is needed in the tts scenario
    cated_y = shifted_y[0][0].unsqueeze(-1)  # [K,S]->[K,S,B]
    new_y_lens = torch.LongTensor([cated_y.shape[1]]).to(cated_y.device)
    assert cated_y.shape == torch.Size((self.args.n_codebooks, cated_y.shape[1], 1))
    assert not (cated_y == self.args.audio_pad_token).any(), cated_y

    # replace tokens in y with the embeddings, and sum the codebooks up
    embedded_y = torch.stack(
        [self.audio_embedding[k](cated_y[k]) for k in range(self.args.n_codebooks)],
        dim=0,
    )  # [K, S, B, D]
    assert embedded_y.shape[0] == self.args.n_codebooks, embedded_y.shape
    assert embedded_y.shape[-1] == self.args.d_model, embedded_y.shape
    embedded_y = embedded_y.sum(dim=0)  # [K,S,B,D]->[S,B,D]
    embedded_y = embedded_y.transpose(1, 0)  # [S,B,D]->[B,S,D]

    # positional embedding
    y_input = self.audio_positional_embedding(embedded_y)

    # make attention mask and padding mask
    y_attention_mask = (
        torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1)
        .bool()
        .to(y.device)
    )

    x_padding_mask = torch.full((1, x_lens[0]), False).to(x.device)
    y_padding_mask = torch.full((1, new_y_lens[0]), False).to(y.device)

    # entering the generation stage
    codebook_eog = [False] * self.args.n_codebooks
    generated = []  # doesn't contain any empty token, contain eog
    cur_generated = [[] for _ in range(batch_size)]
    # say 0 is empty, 4 is eog
    # tensor([[ 1,  2,  3,  4,  0,  0],
    #         [ 0,  1,  2,  3,  4,  0],
    #         [ 0,  0,  1,  2,  3,  4]])
    num_gen = []
    cur_num_gen = 0
    ##################### silence repetition handling #####################
    logging.info(
        f"silence tokens: {silence_tokens}, note that if you are not using the pretrained encodec 6f79c6a8, make sure you specified it yourself, rather than using the default"
    )
    consec_silence_counts = [0 for _ in range(batch_size)]
    prev_tokens = [None for _ in range(batch_size)]
    ##################### silence repetition handling #####################

    # prepare the cache placeholder
    # n_layers, 2, bsz, num_heads, src_len, head_dim
    past = (
        torch.ones(
            [self.args.num_decoder_layers, 2, x.shape[0]],
            device=x.device,
            dtype=torch.float32,
        )
        if kvcache
        else None
    )
    keep = None  # NOTE: this very important, tells which sample to keep

    def sample_helper(
        n_eog,
        logits,
        codebook_eog,
        top_k,
        top_p,
        temperature,
        prev_tokens,
        consec_silence_counts,
        stop_repetition,
        silence_tokens,
        cur_num_gen,
        keep,
    ):
        """Sample one step for every batch copy and update the eog/silence state.

        `logits` is modified in place (`logits_adjust` is an alias, not a copy).
        """
        if n_eog == 0:
            logits_adjust = logits
            for jj in range(1, self.args.n_codebooks):
                logits_adjust[:, jj, eog_inference] = -10000
                logits_adjust[:, jj, self.args.empty_token] = -10000
            if (
                cur_num_gen <= self.args.encodec_sr // 5
            ):  # this shouldn't happen, but just in case the model stopped too early
                logits_adjust[:, :, eog_inference] = -10000
            ##################### silence repetition handling #####################
            for b in range(batch_size):
                prev_token = prev_tokens[b]
                consec_silence_count = consec_silence_counts[b]
                if (
                    stop_repetition > 0
                    and prev_token in silence_tokens
                    and consec_silence_count > stop_repetition
                ):
                    # scale so the repeated silence token always becomes less likely
                    if logits_adjust[b, 0, prev_token] < 0:
                        logits_adjust[b, 0, prev_token] = logits_adjust[
                            b, 0, prev_token
                        ] * (consec_silence_count - (stop_repetition - 1))
                    else:
                        logits_adjust[b, 0, prev_token] = logits_adjust[
                            b, 0, prev_token
                        ] / (consec_silence_count - (stop_repetition - 1))
            ##################### silence repetition handling #####################
            samples = topk_sampling(
                logits_adjust.reshape(
                    batch_size * self.args.n_codebooks, logits_adjust.shape[-1]
                ),
                top_k=top_k,
                top_p=top_p,
                temperature=temperature,
            )  # [B*K, 1]
            samples = samples.reshape(batch_size, self.args.n_codebooks, 1)
            assert samples.shape == torch.Size(
                (batch_size, self.args.n_codebooks, 1)
            ), f"samples.shape: {samples.shape}"
            for b in range(batch_size):
                if cur_num_gen < self.args.n_codebooks - 1:
                    # codebooks still inside the delay window must stay empty
                    for jj in range(1, self.args.n_codebooks - cur_num_gen):
                        samples[b, -jj, 0] = self.args.empty_token

                if (
                    samples[b, 0, 0] == eog_inference
                    or torch.argmax(logits[b, 0], dim=-1) == eog_inference
                    or y_input.shape[1] > x_lens[b] * (self.args.encodec_sr // 5)
                ):  # last one means y is already too long, shouldn't happen, but put it here
                    samples[b, 0, 0] = eog_inference
                    codebook_eog[0] = True
                    keep = b  # NOTE keep is a very important variable, we only return this one, note that if eog shows up in two samples, keep will be overwritten by the later one (or the last one)
                ##################### silence repetition handling #####################
                if (
                    samples[b, 0, 0] in silence_tokens
                    and samples[b, 0, 0] == prev_tokens[b]
                ):
                    consec_silence_counts[b] += 1
                else:
                    consec_silence_counts[b] = 0
                prev_tokens[b] = samples[b, 0, 0]
                ##################### silence repetition handling #####################
            return samples, codebook_eog, prev_tokens, consec_silence_counts, keep
        else:
            assert (
                sum(codebook_eog[i] for i in range(n_eog)) == n_eog
            ), f"codebook_eog: {codebook_eog}, but n_eog: {n_eog}"
            logits_adjust = logits
            for jj in range(n_eog + 1, self.args.n_codebooks):
                logits_adjust[:, jj, eog_inference] = -10000
                logits_adjust[:, jj, self.args.empty_token] = -10000
            samples = topk_sampling(
                logits_adjust.reshape(
                    batch_size * self.args.n_codebooks, logits_adjust.shape[-1]
                ),
                top_k=top_k,
                top_p=top_p,
                temperature=temperature,
            )  # [B*K, 1]
            samples = samples.reshape(batch_size, self.args.n_codebooks, 1)
            # only the kept sample is forced through the eog staircase
            for jj in range(n_eog):
                samples[keep, jj, 0] = self.args.empty_token
            samples[keep, n_eog, 0] = eog_inference
            codebook_eog[n_eog] = True
            return samples, codebook_eog, prev_tokens, consec_silence_counts, keep

    while True:
        # in the first generation step, we repeat each tensor to make their first dimension of length the batch size
        if cur_num_gen == 0:
            assert x_input.ndim == 3 and x_input.shape[0] == 1, x_input.shape
            assert (
                x_padding_mask.ndim == 2 and x_padding_mask.shape[0] == 1
            ), x_padding_mask.shape
            assert (
                y_input.ndim == 3
                and y_input.shape[0] == 1
                and y_input.shape[1] == new_y_lens[0]
            ), y_input.shape
            assert (
                embedded_y.ndim == 3
                and embedded_y.shape[0] == 1
                and embedded_y.shape[1] == new_y_lens[0]
            ), embedded_y.shape
            x_input = x_input.repeat(batch_size, 1, 1)
            x_lens = x_lens.repeat(batch_size)
            # the attention masks carry no batch dimension, so they are not repeated
            x_padding_mask = x_padding_mask.repeat(batch_size, 1)
            y_input = y_input.repeat(batch_size, 1, 1)
            new_y_lens = new_y_lens.repeat(batch_size)
            y_padding_mask = y_padding_mask.repeat(batch_size, 1)
            embedded_y = embedded_y.repeat(
                batch_size, 1, 1
            )  # will be used to concat with newly generated token embedding
            past = past.repeat(1, 1, batch_size) if past is not None else None
        else:
            assert (
                x_input.shape[0] == batch_size
                and x_padding_mask.shape[0] == batch_size
                and y_input.shape[0] == batch_size
                and new_y_lens.shape[0] == batch_size
            ), f"x_input.shape: {x_input.shape}, x_padding_mask.shape: {x_padding_mask.shape}, y_input.shape: {y_input.shape}, new_y_lens.shape: {new_y_lens.shape}"
        y_out, present = self.dec_forward(
            x_input,
            x_lens,
            x_attention_mask,
            x_padding_mask,
            y_input,
            new_y_lens,
            y_attention_mask,
            y_padding_mask,
            past=past,
        )
        if past is not None:
            # the 3-D placeholder is replaced on the first step, then the cache grows along dim -2
            past = (
                torch.cat([past, present.to(past.dtype)], dim=-2)
                if past.ndim > 3
                else present.to(past.dtype)
            )

        # if no eog emerges, y_out should have batch size of batch_size
        if sum(codebook_eog) == 0:
            assert y_out.shape[0] == batch_size and y_out.ndim == 3, y_out.shape
        y_out = y_out[:, -1:]  # only take the last token
        logits = torch.stack(
            [self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)],
            dim=1,
        )  # [B K S card], S==1, so [B K 1 card]
        logits = logits.squeeze(2)  # [B K card]
        assert logits.shape == torch.Size(
            (batch_size, self.args.n_codebooks, self.n_audio_tokens[0])
        ), f"{logits.shape}"

        n_eog = sum(codebook_eog)
        if self.args.eos > 0:
            # when eos is the stop token, eog must never be sampled
            for jj in range(self.args.n_codebooks):
                logits[:, jj, self.args.eog] = -10000.0
        samples, codebook_eog, prev_tokens, consec_silence_counts, keep = (
            sample_helper(
                n_eog,
                logits,
                codebook_eog,
                top_k,
                top_p,
                temperature,
                prev_tokens,
                consec_silence_counts,
                stop_repetition,
                silence_tokens,
                cur_num_gen,
                keep,
            )
        )

        cur_num_gen += 1
        if sum(codebook_eog) == 0:  # no eog yet, keep batch_size of samples
            assert keep is None
            for b in range(batch_size):
                cur_generated[b].append(samples[b].squeeze(-1))
        elif sum(codebook_eog) == 1:  # the first eog just showed up in this step
            assert keep is not None
            cur_generated = cur_generated[keep]
            cur_generated.append(samples[keep].squeeze(-1))
        else:  # we are generating the rest eogs for the 'keep' sample
            cur_generated.append(samples[keep].squeeze(-1))

        # samples.shape is [B,K,1]
        # get samples_emb
        samples_emb = torch.stack(
            [
                self.audio_embedding[k](samples[:, k])
                for k in range(self.args.n_codebooks)
            ],
            dim=1,
        )  # [B, K,1,D]
        assert samples_emb.shape == torch.Size(
            [batch_size, self.args.n_codebooks, 1, self.args.d_model]
        )
        samples_emb = samples_emb.sum(dim=1, keepdim=False)  # [B,1,D]
        if (
            sum(codebook_eog) == self.args.n_codebooks
        ):  # generation for the current span is done
            codebook_eog = [False] * self.args.n_codebooks
            num_gen.append(cur_num_gen)
            cur_num_gen = 0
            generated.append(cur_generated)
            cur_generated = [[] for _ in range(batch_size)]
            break
        else:
            assert samples_emb.shape == torch.Size(
                (batch_size, 1, self.args.d_model)
            ), f"samples_emb.shape: {samples_emb.shape}"

        embedded_y = torch.cat([embedded_y, samples_emb], dim=1)
        y_input = self.audio_positional_embedding(embedded_y)  # [B T D]
        # make attention mask and padding mask
        y_attention_mask = (
            torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1)
            .bool()
            .to(y.device)
        )
        new_y_lens = (
            torch.LongTensor([y_input.shape[1]]).to(y.device).repeat(batch_size)
        )
        y_padding_mask = torch.full((batch_size, new_y_lens[0]), False).to(y.device)

    assert len(generated) == 1, f"len(generated): {len(generated)}"

    # revert the pattern
    flatten_gen = []
    for span_idx, orig_span in enumerate(generated):
        span = torch.stack(orig_span, dim=0)  # [T, K]
        span = span.transpose(1, 0)  # [K, T]
        assert span.shape[0] == self.args.n_codebooks, span.shape
        unshifted_span = []
        for j, s in enumerate(span):
            # codebook j was delayed by j steps; drop the delay and the eog tail
            start_from = j
            end_at = -(self.args.n_codebooks - start_from)
            unshifted_span.append(s[start_from:end_at])
        unshifted_span = torch.stack(unshifted_span, dim=0)

        assert (
            unshifted_span.shape[1] == num_gen[span_idx] - self.args.n_codebooks
        ), f"len(unshifted_spans[0]): {len(unshifted_span[0])}, num_gen[l]: {num_gen[span_idx]}"

        flatten_gen.append(unshifted_span)
    assert len(flatten_gen) == 1, len(flatten_gen)

    # combine
    res = [y[0], flatten_gen[0]]
    res = torch.cat(res, dim=1).unsqueeze(0)  # [K, new_t] -> [1, K, new_T]

    expected_y_len = y_len + sum([item - self.args.n_codebooks for item in num_gen])
    assert res.shape == torch.Size(
        (1, self.args.n_codebooks, expected_y_len)
    ), f"res.shape: {res.shape}, expected_y_len: {expected_y_len}. y_len + sum([item - self.args.n_codebooks for item in num_gen]): {y_len} + {sum([item - self.args.n_codebooks for item in num_gen])}"

    if self.args.special_first:
        res = res - int(self.args.n_special)
        # flatten_gen is a Python list of tensors: shift each tensor, a list
        # itself cannot be subtracted from (the original raised TypeError here)
        flatten_gen = [span - int(self.args.n_special) for span in flatten_gen]

    return res, flatten_gen[0].unsqueeze(0)
src/model/modules/voicecraftconfig.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
class VoiceCraftConfig:
    """Plain attribute container for the VoiceCraft speech-generation settings.

    Every constructor argument is stored unchanged on the instance under the
    same name. Unknown keyword arguments are accepted and ignored.
    """

    # Default silence token ids; kept immutable so instances never share state.
    _DEFAULT_SILENCE_TOKENS = (1388, 1898, 131)

    def __init__(
        self,
        model_name="330M_TTSEnhanced.pth",  # "gigaHalfLibri330M_TTSEnhanced_max16s.pth",
        encodec="encodec_4cb2048_giga.th",
        top_k=0,
        top_p=0.9,
        temperature=1,
        kvcache=1,
        codec_sr=50,
        codec_audio_sr=16000,
        silence_tokens=None,
        stop_repetition=3,
        sample_batch_size=2,
        seed=1,
        cut_off_sec=7.87,
        voice_audio_path="84_121550_000074_000000.wav",
        voice_audio_transcript="But when I had approached so near to them The common object, which the sense deceives, Lost not by distance any of its marks",
        **kwargs,
    ):
        """Store the given settings.

        Args:
            silence_tokens: Token ids treated as silence. When omitted (or
                ``None``) a fresh ``[1388, 1898, 131]`` list is created for
                this instance, so mutating it cannot leak into other configs.
            **kwargs: Accepted for call-site compatibility and ignored.

        All other arguments are stored as-is under the same attribute name.
        """
        super().__init__()
        if silence_tokens is None:
            # build a new list per instance instead of sharing one default object
            silence_tokens = list(self._DEFAULT_SILENCE_TOKENS)
        self.model_name = model_name
        self.encodec = encodec
        self.top_k = top_k
        self.top_p = top_p
        self.temperature = temperature
        self.kvcache = kvcache
        self.codec_sr = codec_sr
        self.codec_audio_sr = codec_audio_sr
        self.silence_tokens = silence_tokens
        self.stop_repetition = stop_repetition
        self.sample_batch_size = sample_batch_size
        self.seed = seed
        self.cut_off_sec = cut_off_sec
        self.voice_audio_path = voice_audio_path
        self.voice_audio_transcript = voice_audio_transcript
src/utils/__init__.py ADDED
File without changes
src/utils/image_utils.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import base64
17
+ import logging
18
+ import os
19
+ from io import BytesIO
20
+ from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
21
+
22
+ import PIL
23
+ import numpy as np
24
+ import requests
25
+ from packaging import version
26
+
27
+
28
+ def _is_numpy(x):
29
+ return isinstance(x, np.ndarray)
30
+
31
+
32
def is_numpy_array(x):
    """
    Report whether `x` is a numpy array (an instance of `np.ndarray`).
    """
    return isinstance(x, np.ndarray)
37
+
38
+
39
def is_pil_image(img):
    """Return True when `img` is a `PIL.Image.Image` instance."""
    # A bare `import PIL` does not load the `PIL.Image` submodule, so the
    # attribute lookup could fail unless something else imported it first.
    # Importing the submodule here makes the check self-sufficient.
    import PIL.Image

    return isinstance(img, PIL.Image.Image)
41
+
42
+
43
def is_valid_image(img):
    """Return True when `img` is either a PIL image or a numpy array."""
    if is_pil_image(img):
        return True
    return is_numpy_array(img)
45
+
46
+
47
def valid_images(imgs):
    """Return True when `imgs` is a valid image or an arbitrarily nested
    list/tuple whose leaves are all valid images (an empty list is valid)."""
    if isinstance(imgs, (list, tuple)):
        # every element, itself possibly a nested list, must validate
        return all(valid_images(item) for item in imgs)
    # not a list or tuple: a single image or a batched array of images
    return is_valid_image(imgs)
57
+
58
+
59
def is_batched(img):
    """Return True when `img` is a non-empty list/tuple whose first element is a valid image.

    An empty list or tuple is reported as not batched instead of raising
    `IndexError` on the `img[0]` lookup.
    """
    if isinstance(img, (list, tuple)) and img:
        return is_valid_image(img[0])
    return False
63
+
64
+
65
def is_scaled_image(image: np.ndarray) -> bool:
    """
    Tell whether the pixel values already lie in the rescaled range [0, 1].

    `uint8` arrays are always treated as unscaled; for any other dtype the
    decision is made from the actual minimum and maximum values.
    """
    if image.dtype == np.uint8:
        return False

    # A floating-point image may still hold values in [0, 255], so inspect the data
    lowest, highest = np.min(image), np.max(image)
    return lowest >= 0 and highest <= 1
74
+
75
+
76
def make_batched_images(images):
    """
    Accepts images in list or nested list format, and makes a list of images for preprocessing.

    Args:
        images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
            The input image.

    Returns:
        list: A list of images. A nested list is flattened, a flat list is
        returned unchanged, and a single image is wrapped in a list.

    Raises:
        ValueError: If `images` is not a valid image, a list of valid images,
            or a list of lists of valid images. This includes empty lists,
            which previously surfaced as an `IndexError`.
    """
    if isinstance(images, (list, tuple)):
        if images:
            first = images[0]
            if (
                isinstance(first, (list, tuple))
                and first
                and is_valid_image(first[0])
            ):
                # nested list: flatten one level
                return [img for img_list in images for img in img_list]
            if is_valid_image(first):
                return images
    elif is_valid_image(images):
        return [images]

    raise ValueError(f"Could not make batched images from {images}")
src/utils/model_utils.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from typing import Optional
4
+ from PIL import Image
5
+
6
+
7
+ from src.model.modules.imagecraftconfig import ImageCraftConfig
8
+ from src.model.modules.imagecraftprocessor import (
9
+ ImageCraftProcessor,
10
+ )
11
+
12
+
13
def move_inputs_to_device(model_inputs: dict, device: str):
    """
    Return a new dict in which every value of `model_inputs` has been moved
    to `device` via its `.to(device)` method. The input dict is not modified.
    """
    moved = {}
    for name, value in model_inputs.items():
        moved[name] = value.to(device)
    return moved
16
+
17
+
18
def get_model_inputs(
    processor: ImageCraftProcessor,
    prompt: str,
    image: Image,
    suffix: Optional[str] = None,
    device: str = "cuda",
):
    """
    Run the processor on a single prompt/image pair and move the result to `device`.

    The prompt and image are each wrapped in a one-element list before being
    passed to `processor(text=..., images=...)`; the processor output is then
    handed to `move_inputs_to_device`.

    NOTE(review): `suffix` is accepted but is not forwarded to the processor,
    so it currently has no effect on the returned inputs. Confirm whether the
    processor is expected to receive it.
    """
    batch = processor(text=[prompt], images=[image])
    return move_inputs_to_device(batch, device)
32
+
33
+
34
def get_config(config_file="config.json"):
    """
    Build an `ImageCraftConfig` from a JSON file.

    The file is parsed with `json.load` and its top-level keys are passed to
    `ImageCraftConfig` as keyword arguments.
    """
    with open(config_file, "r") as handle:
        settings = json.load(handle)
    return ImageCraftConfig(**settings)
41
+
42
+
43
+ # def load_hf_model(model_path: str, device: str) -> Tuple[ImageCraft, AutoTokenizer]:
44
+
45
+ # # Load the tokenizer
46
+ # tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="right")
47
+ # assert tokenizer.padding_side == "right"
48
+
49
+ # # Find all the *.safetensors files
50
+ # safetensors_files = glob.glob(os.path.join(model_path, "*.safetensors"))
51
+
52
+ # # ... and load them one by one in the tensors dictionary
53
+ # tensors = {}
54
+ # for safetensors_file in safetensors_files:
55
+ # with safe_open(safetensors_file, framework="pt", device="cpu") as f:
56
+ # for key in f.keys():
57
+ # tensors[key] = f.get_tensor(key)
58
+
59
+ # # Load the model's config
60
+ # with open(os.path.join(model_path, "config.json"), "r") as f:
61
+ # model_config_file = json.load(f)
62
+ # config = ImageCraftConfig(**model_config_file)
63
+
64
+ # # Create the model using the configuration
65
+ # model = ImageCraft(config).to(device)
66
+
67
+ # # Load the state dict of the model
68
+ # model.load_state_dict(tensors, strict=False)
69
+
70
+ # # Tie weights
71
+ # model.tie_weights()
72
+
73
+ # return (model, tokenizer)
src/utils/tools.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import yaml
2
+ import pickle
3
+
4
+
5
def load_config():
    """
    Parse `config.yaml` with `yaml.safe_load` and return the result.

    The path is relative, so the file is looked up in the current working
    directory.
    """
    with open("config.yaml") as stream:
        return yaml.safe_load(stream)
10
+
11
+
12
def pickle_dump(path, variable):
    """
    Pickle `variable` and write it to the file at `path`.

    The file is opened in binary write mode, so any existing content is
    replaced.
    """
    with open(path, "wb") as out_file:
        pickle.dump(variable, out_file)
16
+
17
+
18
def pickle_load(path):
    """
    Read the file at `path` and return the object unpickled from it.

    NOTE: unpickling can execute arbitrary code, so only use this on files
    from a trusted source.
    """
    with open(path, "rb") as in_file:
        return pickle.load(in_file)
src/utils/util.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ from pathlib import Path
4
+ from tempfile import TemporaryDirectory
5
+ import torch
6
+ import torchaudio
7
+ import random
8
+ import numpy as np
9
+ from PIL import Image
10
+ from urllib.parse import urlparse
11
+ from os.path import exists
12
+ import re
13
+ from num2words import num2words
14
+ import uuid
15
+
16
+ from typing import List, Optional, Dict, Union, Tuple, Iterable
17
+
18
+ from src.utils.image_utils import is_valid_image
19
+
20
+
21
+ IMAGENET_STANDARD_MEAN = [0.5, 0.5, 0.5]
22
+ IMAGENET_STANDARD_STD = [0.5, 0.5, 0.5]
23
+
24
+
25
def is_local(url):
    """
    Report whether `url` refers to an existing local path.

    The value is parsed with `urlparse`; only an empty scheme or the `file`
    scheme can be local, and then the parsed path must exist on disk.
    """
    parts = urlparse(url)
    return parts.scheme in ("file", "") and exists(parts.path)
30
+
31
+
32
def replace_numbers_with_words(sentence):
    """
    Spell out every run of digits in `sentence` using `num2words`.

    Each digit run is first padded with a space on both sides (so "a1b"
    becomes "a 1 b"), then every standalone number is replaced by the value
    returned from `num2words`. The padding spaces remain in the result.

    Args:
        sentence (str): Text that may contain digit runs.

    Returns:
        str: The text with numbers replaced by words. A number that cannot
        be converted is left as digits.
    """

    def replace_with_words(match):
        num = match.group(0)
        try:
            return num2words(num)
        except Exception:
            # Best effort: keep the digits when conversion fails. Catching
            # `Exception` (not a bare `except`) lets KeyboardInterrupt and
            # SystemExit propagate.
            return num

    spaced = re.sub(r"(\d+)", r" \1 ", sentence)
    return re.sub(r"\b\d+\b", replace_with_words, spaced)
43
+
44
+
45
def save_to_buffer(audio_tensors, codec_audio_sr):
    """
    Concatenate `audio_tensors` along dimension 1 and return the result
    encoded as WAV bytes.

    The audio is written with `torchaudio.save` into an in-memory buffer at
    sample rate `int(codec_audio_sr)`; the full buffer content is returned.
    """
    waveform = torch.cat(audio_tensors, 1)
    wav_stream = io.BytesIO()
    torchaudio.save(wav_stream, waveform, int(codec_audio_sr), format="wav")
    # Rewind so the read below starts at the first byte.
    wav_stream.seek(0)
    return wav_stream.read()
52
+
53
+
54
def save_to_file(audio_tensors, codec_audio_sr):
    """
    Concatenate `audio_tensors` along dimension 1, write the result as a WAV
    file and return the file path.

    Files go to `media/voicecraft/generated` (created if missing, relative to
    the current working directory) under a random UUID4 file name.
    """
    output_dir = "media/voicecraft/generated"
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    wav_path = f"{output_dir}/{uuid.uuid4()}.wav"
    waveform = torch.cat(audio_tensors, 1)
    torchaudio.save(wav_path, waveform, int(codec_audio_sr), format="wav")
    return wav_path
61
+
62
+
63
def split_line_to_sentences(line):
    """
    Split `line` into a list of sentences.

    The line is stripped and passed through `str.capitalize` (first character
    upper-cased, the rest lower-cased). A trailing "." is appended when the
    text is non-empty and does not already end in ".", "!" or "?". Newlines
    are replaced by spaces, and each sentence runs from a word character up
    to the next ".", "?" or "!".
    """
    text = line.strip().capitalize()
    if text and not text.endswith((".", "!", "?")):
        text = text + "."
    flattened = text.replace("\n", " ")
    return re.findall(r"\w+.*?[.?!]", flattened, flags=re.S)
68
+
69
+
70
def seed_everything(seed=1):
    """
    Seed the random number generators used in this module for reproducibility.

    Sets `PYTHONHASHSEED` in the environment, seeds Python's `random`, numpy
    and torch (CPU and the current CUDA device), and switches cuDNN to
    deterministic mode with benchmarking disabled.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Seed each library's generator with the same value.
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    ):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
78
+
79
+
80
def add_image_tokens_to_prompt(prefix_prompt, bos_token, image_seq_length, image_token):
    """
    Build the model input string for one image.

    The result is `image_token` repeated `image_seq_length` times, followed
    by `bos_token`, the prompt text and a trailing newline.
    """
    image_placeholder = image_token * image_seq_length
    return f"{image_placeholder}{bos_token}{prefix_prompt}\n"
82
+
83
+
84
def rescale(
    image: np.ndarray, scale: float, dtype: np.dtype = np.float32
) -> np.ndarray:
    """
    Multiply every value of `image` by `scale` and cast the result to `dtype`.
    """
    return (image * scale).astype(dtype)
90
+
91
+
92
def resize(
    image: Image.Image,
    size: Tuple[int, int],
    resample: Optional[Image.Resampling] = None,
    reducing_gap: Optional[int] = None,
) -> Image.Image:
    """
    Resize a PIL image to the requested size.

    Args:
        image: The PIL image to resize.
        size: Target size as `(height, width)`. The order is swapped to
            `(width, height)` when calling `image.resize`.
        resample: Resampling filter forwarded to `image.resize`.
        reducing_gap: Forwarded unchanged to `image.resize`.

    Returns:
        The resized image, exactly as returned by `image.resize` (a PIL
        image, not a numpy array).
    """
    height, width = size
    return image.resize(
        (width, height), resample=resample, reducing_gap=reducing_gap
    )
103
+
104
+
105
def normalize(
    image: np.ndarray,
    mean: Union[float, Iterable[float]],
    std: Union[float, Iterable[float]],
) -> np.ndarray:
    """
    Return `(image - mean) / std`.

    `mean` and `std` are first converted to arrays of the image's dtype, so
    per-channel values broadcast against the image following numpy rules.
    """
    offset = np.array(mean, dtype=image.dtype)
    spread = np.array(std, dtype=image.dtype)
    centered = image - offset
    return centered / spread
114
+
115
+
116
def process_images(
    images: List[Image.Image],
    size: Optional[Tuple[int, int]] = None,
    resample: Optional[Image.Resampling] = None,
    rescale_factor: Optional[float] = None,
    image_mean: Optional[Union[float, List[float]]] = None,
    image_std: Optional[Union[float, List[float]]] = None,
) -> List[np.ndarray]:
    """
    Resize, rescale, normalize and transpose a list of PIL images.

    Args:
        images: The images to preprocess.
        size: Target size, indexed positionally as `(height, width)`.
            Required in practice: the `None` default cannot be indexed.
        resample: Resampling filter forwarded to `resize`.
        rescale_factor: Multiplier applied to the raw pixel values
            (presumably 1/255 to map uint8 pixels to [0, 1]; confirm
            against the caller).
        image_mean: Mean subtracted from the rescaled pixels.
        image_std: Value the centered pixels are divided by.

    Returns:
        One array per input image with its axes reordered by
        `transpose(2, 0, 1)`, i.e. [Channel, Height, Width] for an
        image array laid out as [Height, Width, Channel].
    """
    height, width = size[0], size[1]

    def _prepare(image):
        resized = resize(image=image, size=(height, width), resample=resample)
        pixels = np.array(resized)
        pixels = rescale(pixels, scale=rescale_factor)
        pixels = normalize(pixels, mean=image_mean, std=image_std)
        # Move the channel axis first; assumes the array has three axes.
        return pixels.transpose(2, 0, 1)

    return [_prepare(image) for image in images]
137
+
138
+
139
def sample_top_p(probs: torch.Tensor, p: float):
    """
    Draw one token index per row using top-p (nucleus) sampling.

    Probabilities are sorted in descending order along the last dimension.
    A token is dropped when the cumulative probability of the tokens ranked
    before it already exceeds `p`, so the top-ranked token is always kept.
    The surviving probabilities are renormalized, one position is sampled
    with `torch.multinomial`, and it is mapped back to the original index.

    Returns a tensor with one sampled index per row, shape `(..., 1)`.
    """
    sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Subtracting the token's own probability gives the mass ranked before it.
    outside_nucleus = (cumulative - sorted_probs) > p
    kept = sorted_probs.masked_fill(outside_nucleus, 0.0)
    kept = kept / kept.sum(dim=-1, keepdim=True)
    sampled_rank = torch.multinomial(kept, num_samples=1)
    return torch.gather(sorted_indices, -1, sampled_rank)
156
+
157
+
158
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """
    Build a boolean padding mask from a batch of sequence lengths.

    Args:
      lengths:
        A 1-D tensor containing sentence lengths.
      max_len:
        Minimum width of the mask; the width actually used is the larger of
        `max_len` and the longest length in `lengths`.
    Returns:
      A 2-D bool tensor with one row per entry of `lengths`, where padded
      positions are `True` and valid positions are `False`.

    >>> lengths = torch.tensor([1, 3, 2, 5])
    >>> make_pad_mask(lengths)
    tensor([[False,  True,  True,  True,  True],
            [False, False, False,  True,  True],
            [False, False,  True,  True,  True],
            [False, False, False, False, False]])
    """
    assert lengths.ndim == 1, lengths.ndim
    width = max(max_len, lengths.max())
    positions = torch.arange(0, width, device=lengths.device)
    # Broadcast (1, width) against (n, 1): a position is padding once its
    # index reaches the length of that row's sequence.
    return positions.unsqueeze(0) >= lengths.unsqueeze(-1)
183
+
184
+
185
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    min_dtype: float,
    cache_position: torch.Tensor,
    batch_size: int,
    is_training: bool = False,
    token_type_ids: torch.Tensor = None,
) -> torch.Tensor:
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

    In the returned mask, `0` marks a position that can be attended to and `min_dtype` marks a blocked position.

    Args:
        attention_mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
            May be `None`, in which case no padding is applied.
        sequence_length (`int`):
            The sequence length being processed.
        target_length (`int`):
            The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
        dtype (`torch.dtype`):
            The dtype to use for the 4D attention mask.
        device (`torch.device`):
            The device to place the 4D attention mask on.
        min_dtype (`float`):
            The minimum value representable with the dtype `dtype`.
        cache_position (`torch.Tensor`):
            Indices depicting the position of the input sequence tokens in the sequence.
        batch_size (`int`):
            Batch size.
        is_training (`bool`):
            Whether the model is in training mode or in inference. The condition is checked by presence/absence of `token_type_ids/labels`
        token_type_ids (`torch.Tensor`, *optional*):
            Only read when `is_training` is `True` and `attention_mask` is a non-4D tensor; positions whose value is `0`
            are made attendable for every query. Presumably shaped like the 2D `attention_mask` — TODO confirm.

    Returns:
        `torch.Tensor`: The 4D `attention_mask` unchanged if it was already 4D, otherwise the newly built mask of shape
        `(batch_size, 1, sequence_length, target_length)`.
    """
    if attention_mask is not None and attention_mask.dim() == 4:
        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
        causal_mask = attention_mask
    else:
        # Start from a fully blocked (sequence_length, target_length) mask.
        causal_mask = torch.full(
            (sequence_length, target_length),
            fill_value=min_dtype,
            dtype=dtype,
            device=device,
        )
        # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
        if sequence_length != 1:
            if is_training:
                # Keep min_dtype strictly above the diagonal; zero the rest.
                causal_mask = torch.triu(causal_mask, diagonal=1)
            else:
                # Open up the first `sequence_length` key columns for every query row.
                causal_mask[:, :sequence_length] = 0.0

        # Zero every entry whose key index is not beyond the query's cache position.
        causal_mask *= torch.arange(
            target_length, device=cache_position.device
        ) > cache_position.reshape(-1, 1)
        # Add batch and head dimensions: (batch_size, 1, sequence_length, target_length).
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        if attention_mask is not None:
            causal_mask = (
                causal_mask.clone()
            )  # copy to contiguous memory for in-place edit
            mask_length = attention_mask.shape[-1]
            # The sum is 0 only where the causal mask allows attention (0) and the
            # 2D mask value is 0; those positions get blocked below.
            # Assumes the 2D mask uses 1 for real tokens and 0 for padding — TODO confirm.
            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
                :, None, None, :
            ].to(causal_mask.device)
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[
                :, :, :, :mask_length
            ].masked_fill(padding_mask, min_dtype)
            # we are training thus we need to create a full mask on the image + prefix but causal on suffix
            if is_training:
                # Key positions with token_type_ids == 0 become attendable for all queries.
                causal_mask[:, :, :, :mask_length] = causal_mask[
                    :, :, :, :mask_length
                ].masked_fill(
                    token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
                )
    return causal_mask
262
+
263
+
264
+ # Copied from transformers.models.idefics2.processing_idefics2.is_url
265
def is_url(val) -> bool:
    """
    Report whether `val` is a string that starts with "http".
    """
    if not isinstance(val, str):
        return False
    return val.startswith("http")
267
+
268
+
269
+ # Copied from transformers.models.idefics2.processing_idefics2.is_image_or_image_url
270
def is_image_or_image_url(elem):
    """
    Report whether `elem` is a URL string (see `is_url`) or a valid image
    (see `is_valid_image`).
    """
    if is_url(elem):
        return True
    return is_valid_image(elem)
272
+
273
+
274
def _is_str_or_image(elem):
    """
    Report whether `elem` is a string, an image URL or a valid image.
    """
    if isinstance(elem, str):
        return True
    return is_image_or_image_url(elem)
276
+
277
+
278
def generate_partial_autoregressive_mask(sz, start, end):
    """
    Build an `(sz, sz)` boolean mask with a causal block on `[start, end)`.

    Entries set to `True`:
      * inside the `[start, end)` x `[start, end)` block, everything strictly
        above the diagonal;
      * for every row outside `[start, end)`, all columns in `[start, end)`.
    All other entries are `False`.
    """
    span = end - start
    mask = torch.zeros(sz, sz, dtype=torch.bool)
    mask[start:end, start:end] = torch.ones(span, span, dtype=torch.bool).triu(
        diagonal=1
    )
    # Rows before and after the block get the whole block column range set.
    mask[:start, start:end] = True
    mask[end:, start:end] = True
    return mask
286
+
287
+
288
def build_string_from_input(prompt, bos_token, image_seq_len, image_token, num_images):
    """
    Build the model input string for `num_images` images.

    The result is `image_token` repeated `image_seq_len * num_images` times,
    followed by `bos_token`, the prompt text and a trailing newline.
    """
    image_placeholder = image_token * image_seq_len * num_images
    return f"{image_placeholder}{bos_token}{prompt}\n"
291
+
292
+
293
def is_torchdynamo_compiling():
    """
    Report whether the code is currently being compiled by TorchDynamo.

    Tries `torch.compiler.is_compiling()` first; if that raises, falls back
    to `torch._dynamo.is_compiling()`; if that raises too, returns `False`.
    """
    try:
        import torch

        return torch.compiler.is_compiling()
    except Exception:
        pass

    # Fallback entry point, used when the call above is unavailable.
    try:
        import torch._dynamo as dynamo  # noqa: F401

        return dynamo.is_compiling()
    except Exception:
        return False