KingNish committed
Commit b7475bc · verified · 1 Parent(s): 98d025d

Upload 6 files

codecmanipulator.py ADDED
@@ -0,0 +1,203 @@
+ import json
+ import numpy as np
+ import einops
+
+
+ class CodecManipulator(object):
+     r"""
+     **mm tokenizer v0.1**
+     see codeclm/hf/mm_tokenizer_v0.1_hf/id2vocab.json
+
+     text tokens:
+         llama tokenizer 0~31999
+
+     special tokens: "32000": "<EOD>", "32001": "<SOA>", "32002": "<EOA>", "32003": "<SOI>", "32004": "<EOI>", "32005": "<SOV>", "32006": "<EOV>", "32007": "<s_local>", "32008": "<e_local>", "32009": "<s_global>", "32010": "<e_global>", "32011": "<semantic>", "32012": "<acoustic>", "32013": "<low_level>", "32014": "<dac_16k>", "32015": "<dac_44k>", "32016": "<xcodec>", "32017": "<placeholder>", "32018": "<semantic_mert>", "32019": "<semantic_hubert>", "32020": "<visual>", "32021": "<semanticodec>"
+
+     mm tokens:
+         dac_16k: 4 codebook, 1024 vocab, 32022 - 36117
+         dac_44k: 9 codebook, 1024 vocab, 36118 - 45333
+         xcodec: 12 codebook, 1024 vocab, 45334 - 57621
+         semantic mert: 1024, 57622 - 58645
+         semantic hubert: 512, 58646 - 59157
+         visual: 64000, not included in v0.1
+         semanticodec 100tps 16384: semantic=16384, 59158 - 75541, acoustic=8192, 75542 - 83733
+     """
+     def __init__(self, codec_type, quantizer_begin=None, n_quantizer=None, teacher_forcing=False, data_feature="codec"):
+         self.codec_type = codec_type
+         self.mm_v0_2_cfg = {
+             "dac16k": {"codebook_size": 1024, "num_codebooks": 4, "global_offset": 32022, "sep": ["<dac_16k>"], "fps": 50},
+             "dac44k": {"codebook_size": 1024, "num_codebooks": 9, "global_offset": 36118, "sep": ["<dac_44k>"]},
+             "xcodec": {"codebook_size": 1024, "num_codebooks": 12, "global_offset": 45334, "sep": ["<xcodec>"], "fps": 50},
+             "mert": {"codebook_size": 1024, "global_offset": 57622, "sep": ["<semantic_mert>"]},
+             "hubert": {"codebook_size": 512, "global_offset": 58646, "sep": ["<semantic_hubert>"]},
+             "semantic/s": {"codebook_size": 16384, "num_codebooks": 1, "global_offset": 59158, "sep": ["<semanticodec>", "<semantic>"]},
+             "semantic/a": {"codebook_size": 8192, "num_codebooks": 1, "global_offset": 75542, "sep": ["<semanticodec>", "<acoustic>"]},
+             "semanticodec": {"codebook_size": [16384, 8192], "num_codebooks": 2, "global_offset": 59158, "sep": ["<semanticodec>"], "fps": 50},
+             "special_tokens": {
+                 '<EOD>': 32000, '<SOA>': 32001, '<EOA>': 32002, '<SOI>': 32003, '<EOI>': 32004, '<SOV>': 32005, '<EOV>': 32006, '<s_local>': 32007, '<e_local>': 32008, '<s_global>': 32009, '<e_global>': 32010, '<semantic>': 32011, '<acoustic>': 32012, '<stage_1>': 32013, '<dac_16k>': 32014, '<dac_44k>': 32015, '<xcodec>': 32016, '<stage_2>': 32017, '<semantic_mert>': 32018, '<semantic_hubert>': 32019, '<visual>': 32020, '<semanticodec>': 32021
+             },
+             "metadata": {
+                 "len": 83734,
+                 "text_range": [0, 31999],
+                 "special_range": [32000, 32021],
+                 "mm_range": [32022, 83733]
+             },
+             "codec_range": {
+                 "dac16k": [32022, 36117],
+                 "dac44k": [36118, 45333],
+                 "xcodec": [45334, 57621],
+                 # "hifi16k": [53526, 57621],
+                 "mert": [57622, 58645],
+                 "hubert": [58646, 59157],
+                 "semantic/s": [59158, 75541],
+                 "semantic/a": [75542, 83733],
+                 "semanticodec": [59158, 83733]
+             }
+         }
+         self.sep = self.mm_v0_2_cfg[self.codec_type]["sep"]
+         self.sep_ids = [self.mm_v0_2_cfg["special_tokens"][s] for s in self.sep]
+         self.codebook_size = self.mm_v0_2_cfg[self.codec_type]["codebook_size"]
+         self.num_codebooks = self.mm_v0_2_cfg[self.codec_type]["num_codebooks"]
+         self.global_offset = self.mm_v0_2_cfg[self.codec_type]["global_offset"]
+         self.fps = self.mm_v0_2_cfg[self.codec_type]["fps"] if "fps" in self.mm_v0_2_cfg[self.codec_type] else None
+
+         self.quantizer_begin = quantizer_begin if quantizer_begin is not None else 0
+         self.n_quantizer = n_quantizer if n_quantizer is not None else self.num_codebooks
+         self.teacher_forcing = teacher_forcing
+         self.data_feature = data_feature
+
+     def offset_tok_ids(self, x, global_offset=0, codebook_size=2048, num_codebooks=4):
+         """
+         x: (K, T)
+         """
+         if isinstance(codebook_size, int):
+             assert x.max() < codebook_size, f"max(x)={x.max()}, codebook_size={codebook_size}"
+         elif isinstance(codebook_size, list):
+             for i, cs in enumerate(codebook_size):
+                 assert x[i].max() < cs, f"max(x)={x[i].max()}, codebook_size={cs}, layer_id={i}"
+         else:
+             raise ValueError(f"codebook_size={codebook_size}")
+         assert x.min() >= 0, f"min(x)={x.min()}"
+         assert x.shape[0] == num_codebooks or x.shape[0] == self.n_quantizer, \
+             f"x.shape[0]={x.shape[0]}, num_codebooks={num_codebooks}, n_quantizer={self.n_quantizer}"
+
+         _x = x.copy()
+         _x = _x.astype(np.uint32)
+         cum_offset = 0
+         quantizer_begin = self.quantizer_begin
+         quantizer_end = quantizer_begin + self.n_quantizer
+         for k in range(self.quantizer_begin, quantizer_end):  # k: quantizer_begin to quantizer_end - 1
+             if isinstance(codebook_size, int):
+                 _x[k] += global_offset + k * codebook_size
+             elif isinstance(codebook_size, list):
+                 _x[k] += global_offset + cum_offset
+                 cum_offset += codebook_size[k]
+             else:
+                 raise ValueError(f"codebook_size={codebook_size}")
+         return _x[quantizer_begin:quantizer_end]
+
+     def unoffset_tok_ids(self, x, global_offset=0, codebook_size=2048, num_codebooks=4):
+         """
+         x: (K, T)
+         """
+         if isinstance(codebook_size, int):
+             assert x.max() < global_offset + codebook_size * num_codebooks, f"max(x)={x.max()}, codebook_size={codebook_size}"
+         elif isinstance(codebook_size, list):
+             assert x.max() < global_offset + sum(codebook_size), f"max(x)={x.max()}, codebook_size={codebook_size}"
+         assert x.min() >= global_offset, f"min(x)={x.min()}, global_offset={global_offset}"
+         assert x.shape[0] == num_codebooks or x.shape[0] == self.n_quantizer, \
+             f"x.shape[0]={x.shape[0]}, num_codebooks={num_codebooks}, n_quantizer={self.n_quantizer}"
+
+         _x = x.copy()
+         _x = _x.astype(np.uint32)
+         cum_offset = 0
+         quantizer_begin = self.quantizer_begin
+         quantizer_end = quantizer_begin + self.n_quantizer
+         for k in range(quantizer_begin, quantizer_end):
+             if isinstance(codebook_size, int):
+                 _x[k - quantizer_begin] -= global_offset + k * codebook_size
+             elif isinstance(codebook_size, list):
+                 _x[k - quantizer_begin] -= global_offset + cum_offset
+                 cum_offset += codebook_size[k]
+             else:
+                 raise ValueError(f"codebook_size={codebook_size}")
+         return _x
+
+     def flatten(self, x):
+         if len(x.shape) > 2:
+             x = x.squeeze()
+         assert x.shape[0] == self.num_codebooks or x.shape[0] == self.n_quantizer, \
+             f"x.shape[0]={x.shape[0]}, num_codebooks={self.num_codebooks}, n_quantizer={self.n_quantizer}"
+         return einops.rearrange(x, 'K T -> (T K)')
+
+     def unflatten(self, x, n_quantizer=None):
+         x = x.squeeze()
+         assert len(x.shape) == 1
+         assert x.shape[0] % self.num_codebooks == 0 or x.shape[0] % self.n_quantizer == 0, \
+             f"x.shape[0]={x.shape[0]}, num_codebooks={self.num_codebooks}, n_quantizer={self.n_quantizer}"
+         # guard against the default n_quantizer=None, which would otherwise reach rearrange as K=None
+         if n_quantizer is not None and n_quantizer != self.num_codebooks:
+             return einops.rearrange(x, '(T K) -> K T', K=n_quantizer)
+         return einops.rearrange(x, '(T K) -> K T', K=self.num_codebooks)
+
+     # def check_codec_type_from_path(self, path):
+     #     if self.codec_type == "hifi16k":
+     #         assert "academicodec_hifi_16k_320d_large_uni" in path
+
+     def get_codec_type_from_range(self, ids):
+         ids_range = [ids.min(), ids.max()]
+         codec_range = self.mm_v0_2_cfg["codec_range"]
+         for codec_type, r in codec_range.items():
+             if ids_range[0] >= r[0] and ids_range[1] <= r[1]:
+                 return codec_type
+         raise ValueError(f"ids_range={ids_range}, codec_range={codec_range}")
+
+     def npy2ids(self, npy):
+         if isinstance(npy, str):
+             data = np.load(npy)
+         elif isinstance(npy, np.ndarray):
+             data = npy
+         else:
+             raise ValueError(f"not supported type: {type(npy)}")
+         # data = data.squeeze()
+
+         assert len(data.shape) == 2, f'data shape: {data.shape} is not (n_codebook, seq_len)'
+         data = self.offset_tok_ids(
+             data,
+             global_offset=self.global_offset,
+             codebook_size=self.codebook_size,
+             num_codebooks=self.num_codebooks,
+         )
+         data = self.flatten(data)
+         codec_range = self.get_codec_type_from_range(data)
+         assert codec_range == self.codec_type, f"get_codec_type_from_range(data)={codec_range}, self.codec_type={self.codec_type}"
+         data = data.tolist()
+         return data
+
+     def ids2npy(self, token_ids):
+         # make sure token_ids starts with codebook 0
+         if isinstance(self.codebook_size, int):
+             codebook_0_range = (self.global_offset + self.quantizer_begin * self.codebook_size, self.global_offset + (self.quantizer_begin + 1) * self.codebook_size)
+         elif isinstance(self.codebook_size, list):
+             codebook_0_range = (self.global_offset, self.global_offset + self.codebook_size[0])
+         assert token_ids[0] >= codebook_0_range[0] \
+             and token_ids[0] < codebook_0_range[1], f"token_ids[0]={token_ids[0]}, codebook_0_range={codebook_0_range}"
+         data = np.array(token_ids)
+         data = self.unflatten(data, n_quantizer=self.n_quantizer)
+         data = self.unoffset_tok_ids(
+             data,
+             global_offset=self.global_offset,
+             codebook_size=self.codebook_size,
+             num_codebooks=self.num_codebooks,
+         )
+         return data
+
+     def npy_to_json_str(self, npy_path):
+         data = self.npy2ids(npy_path)
+         return json.dumps({"text": data, "src": npy_path, "codec": self.codec_type})
+
+     # NOTE: the two methods below are shadowed by the instance attributes of the same
+     # name assigned in __init__, so callers see the attribute values instead.
+     def sep(self):
+         return ''.join(self.sep)
+
+     def sep_ids(self):
+         return self.sep_ids
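A minimal round-trip sketch of the class above (not part of the commit; it assumes codecmanipulator.py is on the import path with numpy and einops installed, and the dummy token array is purely illustrative):

import numpy as np
from codecmanipulator import CodecManipulator

tool = CodecManipulator("xcodec", 0, 1)            # use only the first of the 12 xcodec codebooks
codes = np.random.randint(0, 1024, size=(1, 8))    # (K, T) raw codebook indices
ids = tool.npy2ids(codes)                          # offset into the 45334-57621 xcodec range, then flatten to (T K)
assert all(45334 <= t <= 57621 for t in ids)
restored = tool.ids2npy(ids)                       # unflatten and remove the offset again
assert np.array_equal(restored, codes)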
infer.py ADDED
@@ -0,0 +1,456 @@
+ import os
+ import sys
+ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
+ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))
+ import argparse
+ import torch
+ import numpy as np
+ import json
+ from omegaconf import OmegaConf
+ import torchaudio
+ from torchaudio.transforms import Resample
+ import soundfile as sf
+
+ import uuid
+ from tqdm import tqdm
+ from einops import rearrange
+ from codecmanipulator import CodecManipulator
+ from mmtokenizer import _MMSentencePieceTokenizer
+ from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
+ import glob
+ import time
+ import copy
+ from collections import Counter
+ from models.soundstream_hubert_new import SoundStream
+ from vocoder import build_codec_model, process_audio
+ from post_process_audio import replace_low_freq_with_energy_matched
+ import re
+
+
+ parser = argparse.ArgumentParser()
+ # Model Configuration:
+ parser.add_argument("--stage1_model", type=str, default="m-a-p/YuE-s1-7B-anneal-en-cot", help="The model checkpoint path or identifier for the Stage 1 model.")
+ parser.add_argument("--stage2_model", type=str, default="m-a-p/YuE-s2-1B-general", help="The model checkpoint path or identifier for the Stage 2 model.")
+ parser.add_argument("--max_new_tokens", type=int, default=3000, help="The maximum number of new tokens to generate in one pass during text generation.")
+ parser.add_argument("--run_n_segments", type=int, default=2, help="The number of segments to process during the generation.")
+ parser.add_argument("--stage2_batch_size", type=int, default=4, help="The batch size used in Stage 2 inference.")
+ # Prompt
+ parser.add_argument("--genre_txt", type=str, required=True, help="The file path to a text file containing genre tags that describe the musical style or characteristics (e.g., instrumental, genre, mood, vocal timbre, vocal gender). This is used as part of the generation prompt.")
+ parser.add_argument("--lyrics_txt", type=str, required=True, help="The file path to a text file containing the lyrics for the music generation. These lyrics will be processed and split into structured segments to guide the generation process.")
+ parser.add_argument("--use_audio_prompt", action="store_true", help="If set, the model will use an audio file as a prompt during generation. The audio file should be specified using --audio_prompt_path.")
+ parser.add_argument("--audio_prompt_path", type=str, default="", help="The file path to an audio file to use as a reference prompt when --use_audio_prompt is enabled.")
+ parser.add_argument("--prompt_start_time", type=float, default=0.0, help="The start time in seconds to extract the audio prompt from the given audio file.")
+ parser.add_argument("--prompt_end_time", type=float, default=30.0, help="The end time in seconds to extract the audio prompt from the given audio file.")
+ # Output
+ parser.add_argument("--output_dir", type=str, default="./output", help="The directory where generated outputs will be saved.")
+ parser.add_argument("--keep_intermediate", action="store_true", help="If set, intermediate outputs will be saved during processing.")
+ parser.add_argument("--disable_offload_model", action="store_true", help="If set, the model will not be offloaded from the GPU to CPU after Stage 1 inference.")
+ parser.add_argument("--cuda_idx", type=int, default=0)
+ # Config for xcodec and upsampler
+ parser.add_argument('--basic_model_config', default='./xcodec_mini_infer/final_ckpt/config.yaml', help='YAML file for the xcodec configuration.')
+ parser.add_argument('--resume_path', default='./xcodec_mini_infer/final_ckpt/ckpt_00360000.pth', help='Path to the xcodec checkpoint.')
+ parser.add_argument('--config_path', type=str, default='./xcodec_mini_infer/decoders/config.yaml', help='Path to the Vocos config file.')
+ parser.add_argument('--vocal_decoder_path', type=str, default='./xcodec_mini_infer/decoders/decoder_131000.pth', help='Path to the Vocos vocal decoder weights.')
+ parser.add_argument('--inst_decoder_path', type=str, default='./xcodec_mini_infer/decoders/decoder_151000.pth', help='Path to the Vocos instrumental decoder weights.')
+ parser.add_argument('-r', '--rescale', action='store_true', help='Rescale output to avoid clipping.')
+
+
+ args = parser.parse_args()
+ if args.use_audio_prompt and not args.audio_prompt_path:
+     raise FileNotFoundError("Please provide an audio prompt file via '--audio_prompt_path' when '--use_audio_prompt' is enabled!")
+ stage1_model = args.stage1_model
+ stage2_model = args.stage2_model
+ cuda_idx = args.cuda_idx
+ max_new_tokens = args.max_new_tokens
+ stage1_output_dir = os.path.join(args.output_dir, "stage1")
+ stage2_output_dir = stage1_output_dir.replace('stage1', 'stage2')
+ os.makedirs(stage1_output_dir, exist_ok=True)
+ os.makedirs(stage2_output_dir, exist_ok=True)
+
+ # load tokenizer and model
+ device = torch.device(f"cuda:{cuda_idx}" if torch.cuda.is_available() else "cpu")
+ mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
+ model = AutoModelForCausalLM.from_pretrained(
+     stage1_model,
+     torch_dtype=torch.bfloat16,
+     attn_implementation="flash_attention_2",  # To enable flash attention, you have to install flash-attn
+ )
+ # to device, if gpu is available
+ model.to(device)
+ model.eval()
+
+ codectool = CodecManipulator("xcodec", 0, 1)
+ codectool_stage2 = CodecManipulator("xcodec", 0, 8)
+ model_config = OmegaConf.load(args.basic_model_config)
+ codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
+ parameter_dict = torch.load(args.resume_path, map_location='cpu')
+ codec_model.load_state_dict(parameter_dict['codec_model'])
+ codec_model.to(device)
+ codec_model.eval()
+
+ class BlockTokenRangeProcessor(LogitsProcessor):
+     def __init__(self, start_id, end_id):
+         self.blocked_token_ids = list(range(start_id, end_id))
+
+     def __call__(self, input_ids, scores):
+         scores[:, self.blocked_token_ids] = -float("inf")
+         return scores
+
+ def load_audio_mono(filepath, sampling_rate=16000):
+     audio, sr = torchaudio.load(filepath)
+     # Convert to mono
+     audio = torch.mean(audio, dim=0, keepdim=True)
+     # Resample if needed
+     if sr != sampling_rate:
+         resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
+         audio = resampler(audio)
+     return audio
+
+ def split_lyrics(lyrics):
+     pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
+     segments = re.findall(pattern, lyrics, re.DOTALL)
+     structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
+     return structured_lyrics
+
+ # Collect stage 1 outputs
+ stage1_output_set = []
+ # Tips:
+ # genre tags support instrumental, genre, mood, vocal timbre and vocal gender
+ # all kinds of tags are needed
+ with open(args.genre_txt) as f:
+     genres = f.read().strip()
+ with open(args.lyrics_txt) as f:
+     lyrics = split_lyrics(f.read())
+ # instruction
+ full_lyrics = "\n".join(lyrics)
+ prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
+ prompt_texts += lyrics
+
+
+ random_id = uuid.uuid4()
+ output_seq = None
+ # Here is suggested decoding config
+ top_p = 0.93
+ temperature = 1.0
+ repetition_penalty = 1.2
+ # special tokens
+ start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
+ end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
+ # Format text prompt
+ run_n_segments = min(args.run_n_segments + 1, len(lyrics))
+ for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
+     section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
+     guidance_scale = 1.5 if i <= 1 else 1.2
+     if i == 0:
+         continue
+     if i == 1:
+         if args.use_audio_prompt:
+             audio_prompt = load_audio_mono(args.audio_prompt_path)
+             audio_prompt.unsqueeze_(0)
+             with torch.no_grad():
+                 raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
+             raw_codes = raw_codes.transpose(0, 1)
+             raw_codes = raw_codes.cpu().numpy().astype(np.int16)
+             # Format audio prompt
+             code_ids = codectool.npy2ids(raw_codes[0])
+             audio_prompt_codec = code_ids[int(args.prompt_start_time * 50): int(args.prompt_end_time * 50)]  # 50 is tps of xcodec
+             audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [mmtokenizer.eoa]
+             sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
+             head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
+         else:
+             head_id = mmtokenizer.tokenize(prompt_texts[0])
+         prompt_ids = head_id + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
+     else:
+         prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
+
+     prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
+     input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
+     # Use window slicing in case output sequence exceeds the context of model
+     max_context = 16384 - max_new_tokens - 1
+     if input_ids.shape[-1] > max_context:
+         print(f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
+         input_ids = input_ids[:, -(max_context):]
+     with torch.no_grad():
+         output_seq = model.generate(
+             input_ids=input_ids,
+             max_new_tokens=max_new_tokens,
+             min_new_tokens=100,
+             do_sample=True,
+             top_p=top_p,
+             temperature=temperature,
+             repetition_penalty=repetition_penalty,
+             eos_token_id=mmtokenizer.eoa,
+             pad_token_id=mmtokenizer.eoa,
+             logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
+             guidance_scale=guidance_scale,
+         )
+     if output_seq[0][-1].item() != mmtokenizer.eoa:
+         tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
+         output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
+     if i > 1:
+         raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
+     else:
+         raw_output = output_seq
+
+ # save raw output and check sanity
+ ids = raw_output[0].cpu().numpy()
+ soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
+ eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
+ if len(soa_idx) != len(eoa_idx):
+     raise ValueError(f'invalid pairs of soa and eoa, Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}')
+
+ vocals = []
+ instrumentals = []
+ range_begin = 1 if args.use_audio_prompt else 0
+ for i in range(range_begin, len(soa_idx)):
+     codec_ids = ids[soa_idx[i] + 1:eoa_idx[i]]
+     if codec_ids[0] == 32016:
+         codec_ids = codec_ids[1:]
+     codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
+     vocals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[0])
+     vocals.append(vocals_ids)
+     instrumentals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[1])
+     instrumentals.append(instrumentals_ids)
+ vocals = np.concatenate(vocals, axis=1)
+ instrumentals = np.concatenate(instrumentals, axis=1)
+ vocal_save_path = os.path.join(stage1_output_dir, f"cot_{genres.replace(' ', '-')}_tp{top_p}_T{temperature}_rp{repetition_penalty}_maxtk{max_new_tokens}_vocal_{random_id}".replace('.', '@') + '.npy')
+ inst_save_path = os.path.join(stage1_output_dir, f"cot_{genres.replace(' ', '-')}_tp{top_p}_T{temperature}_rp{repetition_penalty}_maxtk{max_new_tokens}_instrumental_{random_id}".replace('.', '@') + '.npy')
+ np.save(vocal_save_path, vocals)
+ np.save(inst_save_path, instrumentals)
+ stage1_output_set.append(vocal_save_path)
+ stage1_output_set.append(inst_save_path)
+
+
+ # offload model
+ if not args.disable_offload_model:
+     model.cpu()
+     del model
+     torch.cuda.empty_cache()
+
+ print("Stage 2 inference...")
+ model_stage2 = AutoModelForCausalLM.from_pretrained(
+     stage2_model,
+     torch_dtype=torch.float16,
+     attn_implementation="flash_attention_2"
+ )
+ model_stage2.to(device)
+ model_stage2.eval()
+
+ def stage2_generate(model, prompt, batch_size=16):
+     codec_ids = codectool.unflatten(prompt, n_quantizer=1)
+     codec_ids = codectool.offset_tok_ids(
+         codec_ids,
+         global_offset=codectool.global_offset,
+         codebook_size=codectool.codebook_size,
+         num_codebooks=codectool.num_codebooks,
+     ).astype(np.int32)
+
+     # Prepare prompt_ids based on batch size or single input
+     if batch_size > 1:
+         codec_list = []
+         for i in range(batch_size):
+             idx_begin = i * 300
+             idx_end = (i + 1) * 300
+             codec_list.append(codec_ids[:, idx_begin:idx_end])
+
+         codec_ids = np.concatenate(codec_list, axis=0)
+         prompt_ids = np.concatenate(
+             [
+                 np.tile([mmtokenizer.soa, mmtokenizer.stage_1], (batch_size, 1)),
+                 codec_ids,
+                 np.tile([mmtokenizer.stage_2], (batch_size, 1)),
+             ],
+             axis=1
+         )
+     else:
+         prompt_ids = np.concatenate([
+             np.array([mmtokenizer.soa, mmtokenizer.stage_1]),
+             codec_ids.flatten(),  # Flatten the 2D array to 1D
+             np.array([mmtokenizer.stage_2])
+         ]).astype(np.int32)
+         prompt_ids = prompt_ids[np.newaxis, ...]
+
+     codec_ids = torch.as_tensor(codec_ids).to(device)
+     prompt_ids = torch.as_tensor(prompt_ids).to(device)
+     len_prompt = prompt_ids.shape[-1]
+
+     block_list = LogitsProcessorList([BlockTokenRangeProcessor(0, 46358), BlockTokenRangeProcessor(53526, mmtokenizer.vocab_size)])
+
+     # Teacher forcing generate loop
+     for frames_idx in range(codec_ids.shape[1]):
+         cb0 = codec_ids[:, frames_idx:frames_idx + 1]
+         prompt_ids = torch.cat([prompt_ids, cb0], dim=1)
+         input_ids = prompt_ids
+
+         with torch.no_grad():
+             stage2_output = model.generate(
+                 input_ids=input_ids,
+                 min_new_tokens=7,
+                 max_new_tokens=7,
+                 eos_token_id=mmtokenizer.eoa,
+                 pad_token_id=mmtokenizer.eoa,
+                 logits_processor=block_list,
+             )
+
+         assert stage2_output.shape[1] - prompt_ids.shape[1] == 7, f"output new tokens={stage2_output.shape[1] - prompt_ids.shape[1]}"
+         prompt_ids = stage2_output
+
+     # Return output based on batch size
+     if batch_size > 1:
+         output = prompt_ids.cpu().numpy()[:, len_prompt:]
+         output_list = [output[i] for i in range(batch_size)]
+         output = np.concatenate(output_list, axis=0)
+     else:
+         output = prompt_ids[0].cpu().numpy()[len_prompt:]
+
+     return output
+
+ def stage2_inference(model, stage1_output_set, stage2_output_dir, batch_size=4):
+     stage2_result = []
+     for i in tqdm(range(len(stage1_output_set))):
+         output_filename = os.path.join(stage2_output_dir, os.path.basename(stage1_output_set[i]))
+
+         if os.path.exists(output_filename):
+             print(f'{output_filename} stage2 has already been done.')
+             continue
+
+         # Load the prompt
+         prompt = np.load(stage1_output_set[i]).astype(np.int32)
+
+         # Only accept 6s segments
+         output_duration = prompt.shape[-1] // 50 // 6 * 6
+         num_batch = output_duration // 6
+
+         if num_batch <= batch_size:
+             # If num_batch is less than or equal to batch_size, we can infer the entire prompt at once
+             output = stage2_generate(model, prompt[:, :output_duration * 50], batch_size=num_batch)
+         else:
+             # If num_batch is greater than batch_size, process in chunks of batch_size
+             segments = []
+             num_segments = (num_batch // batch_size) + (1 if num_batch % batch_size != 0 else 0)
+
+             for seg in range(num_segments):
+                 start_idx = seg * batch_size * 300
+                 # Ensure the end_idx does not exceed the available length
+                 end_idx = min((seg + 1) * batch_size * 300, output_duration * 50)  # Adjust the last segment
+                 current_batch_size = batch_size if seg != num_segments - 1 or num_batch % batch_size == 0 else num_batch % batch_size
+                 segment = stage2_generate(
+                     model,
+                     prompt[:, start_idx:end_idx],
+                     batch_size=current_batch_size
+                 )
+                 segments.append(segment)
+
+             # Concatenate all the segments
+             output = np.concatenate(segments, axis=0)
+
+         # Process the ending part of the prompt
+         if output_duration * 50 != prompt.shape[-1]:
+             ending = stage2_generate(model, prompt[:, output_duration * 50:], batch_size=1)
+             output = np.concatenate([output, ending], axis=0)
+         output = codectool_stage2.ids2npy(output)
+
+         # Fix invalid codes (a dirty solution, which may harm the quality of audio)
+         # We are trying to find a better one
+         fixed_output = copy.deepcopy(output)
+         for row_idx, line in enumerate(output):
+             for col_idx, element in enumerate(line):
+                 if element < 0 or element > 1023:
+                     counter = Counter(line)
+                     most_frequent = sorted(counter.items(), key=lambda x: x[1], reverse=True)[0][0]
+                     fixed_output[row_idx, col_idx] = most_frequent
+         # save output
+         np.save(output_filename, fixed_output)
+         stage2_result.append(output_filename)
+     return stage2_result
+
+ stage2_result = stage2_inference(model_stage2, stage1_output_set, stage2_output_dir, batch_size=args.stage2_batch_size)
+ print(stage2_result)
+ print('Stage 2 DONE.\n')
+
+ # convert audio tokens to audio
+ def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
+     folder_path = os.path.dirname(path)
+     if not os.path.exists(folder_path):
+         os.makedirs(folder_path)
+     limit = 0.99
+     max_val = wav.abs().max()
+     wav = wav * min(limit / max_val, 1) if rescale else wav.clamp(-limit, limit)
+     torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
+
+ # reconstruct tracks
+ recons_output_dir = os.path.join(args.output_dir, "recons")
+ recons_mix_dir = os.path.join(recons_output_dir, 'mix')
+ os.makedirs(recons_mix_dir, exist_ok=True)
+ tracks = []
+ for npy in stage2_result:
+     codec_result = np.load(npy)
+     decodec_rlt = []
+     with torch.no_grad():
+         decoded_waveform = codec_model.decode(torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(device))
+     decoded_waveform = decoded_waveform.cpu().squeeze(0)
+     decodec_rlt.append(torch.as_tensor(decoded_waveform))
+     decodec_rlt = torch.cat(decodec_rlt, dim=-1)
+     save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3")
+     tracks.append(save_path)
+     save_audio(decodec_rlt, save_path, 16000)
+ # mix tracks
+ for inst_path in tracks:
+     try:
+         if (inst_path.endswith('.wav') or inst_path.endswith('.mp3')) \
+             and 'instrumental' in inst_path:
+             # find pair
+             vocal_path = inst_path.replace('instrumental', 'vocal')
+             if not os.path.exists(vocal_path):
+                 continue
+             # mix
+             recons_mix = os.path.join(recons_mix_dir, os.path.basename(inst_path).replace('instrumental', 'mixed'))
+             vocal_stem, sr = sf.read(vocal_path)
+             instrumental_stem, _ = sf.read(inst_path)
+             mix_stem = (vocal_stem + instrumental_stem) / 1
+             sf.write(recons_mix, mix_stem, sr)
+     except Exception as e:
+         print(e)
+
+ # vocoder to upsample audios
+ vocal_decoder, inst_decoder = build_codec_model(args.config_path, args.vocal_decoder_path, args.inst_decoder_path)
+ vocoder_output_dir = os.path.join(args.output_dir, 'vocoder')
+ vocoder_stems_dir = os.path.join(vocoder_output_dir, 'stems')
+ vocoder_mix_dir = os.path.join(vocoder_output_dir, 'mix')
+ os.makedirs(vocoder_mix_dir, exist_ok=True)
+ os.makedirs(vocoder_stems_dir, exist_ok=True)
+ for npy in stage2_result:
+     if 'instrumental' in npy:
+         # Process instrumental
+         instrumental_output = process_audio(
+             npy,
+             os.path.join(vocoder_stems_dir, 'instrumental.mp3'),
+             args.rescale,
+             args,
+             inst_decoder,
+             codec_model
+         )
+     else:
+         # Process vocal
+         vocal_output = process_audio(
+             npy,
+             os.path.join(vocoder_stems_dir, 'vocal.mp3'),
+             args.rescale,
+             args,
+             vocal_decoder,
+             codec_model
+         )
+ # mix tracks
+ try:
+     mix_output = instrumental_output + vocal_output
+     vocoder_mix = os.path.join(vocoder_mix_dir, os.path.basename(recons_mix))
+     save_audio(mix_output, vocoder_mix, 44100, args.rescale)
+     print(f"Created mix: {vocoder_mix}")
+ except RuntimeError as e:
+     print(e)
+     print(f"mix {vocoder_mix} failed! inst: {instrumental_output.shape}, vocal: {vocal_output.shape}")
+
+ # Post process
+ replace_low_freq_with_energy_matched(
+     a_file=recons_mix,  # 16kHz
+     b_file=vocoder_mix,  # 48kHz
+     c_file=os.path.join(args.output_dir, os.path.basename(recons_mix)),
+     cutoff_freq=5500.0
+ )
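A small standalone sketch of the lyric-segmentation step used above (not part of the commit; it reuses the same regex as split_lyrics on a made-up two-segment lyric):

import re

def split_lyrics(lyrics):
    # same pattern as infer.py: a [section] header followed by its lines
    pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
    segments = re.findall(pattern, lyrics, re.DOTALL)
    return [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]

demo = "[verse]\nline one\nline two\n[chorus]\nhook line\n"
print(split_lyrics(demo))
# -> ['[verse]\nline one\nline two\n\n', '[chorus]\nhook line\n\n']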
mm_tokenizer_v0.2_hf/tokenizer.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ee5c7cbf32da93989f14d9ba635e3e1d1ab2cc88a92908a5ed0f149375f6ee49
+ size 1761962
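The diff above only adds a Git LFS pointer; the actual ~1.7 MB SentencePiece model has to be fetched separately. A hedged sketch using huggingface_hub (the repo id is a placeholder, not something stated in this commit):

from huggingface_hub import hf_hub_download

tokenizer_path = hf_hub_download(
    repo_id="<owner>/<this-repo>",                      # placeholder: replace with the repo hosting these files
    filename="mm_tokenizer_v0.2_hf/tokenizer.model",
)
print(tokenizer_path)  # local cached path, usable by _MMSentencePieceTokenizer below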
mmtokenizer.py ADDED
@@ -0,0 +1,367 @@
+ from abc import ABC
+ from abc import abstractmethod
+
+
+ class AbstractTokenizer(ABC):
+     """Abstract class for tokenizer."""
+
+     def __init__(self, name):
+         self.name = name
+         super().__init__()
+
+     @property
+     @abstractmethod
+     def vocab_size(self):
+         pass
+
+     @property
+     @abstractmethod
+     def vocab(self):
+         """Dictionary from vocab text token to id token."""
+         pass
+
+     @property
+     @abstractmethod
+     def inv_vocab(self):
+         """Dictionary from vocab id token to text token."""
+         pass
+
+     @abstractmethod
+     def tokenize(self, text):
+         pass
+
+     def detokenize(self, token_ids):
+         raise NotImplementedError('detokenizer is not implemented for {} '
+                                   'tokenizer'.format(self.name))
+
+     @property
+     def cls(self):
+         raise NotImplementedError('CLS is not provided for {} '
+                                   'tokenizer'.format(self.name))
+
+     @property
+     def sep(self):
+         raise NotImplementedError('SEP is not provided for {} '
+                                   'tokenizer'.format(self.name))
+
+     @property
+     def pad(self):
+         raise NotImplementedError('PAD is not provided for {} '
+                                   'tokenizer'.format(self.name))
+
+     @property
+     def eod(self):
+         raise NotImplementedError('EOD is not provided for {} '
+                                   'tokenizer'.format(self.name))
+
+     @property
+     def mask(self):
+         raise NotImplementedError('MASK is not provided for {} '
+                                   'tokenizer'.format(self.name))
+
+
+ class _SentencePieceTokenizer(AbstractTokenizer):
+     """SentencePieceTokenizer-Megatron wrapper"""
+
+     def __init__(self, model_file, vocab_extra_ids=0):
+         name = 'SentencePieceTokenizer'
+         super().__init__(name)
+
+         import sentencepiece
+         self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=model_file)
+         self._initalize(vocab_extra_ids)
+
+     def _populate_vocab(self):
+         self._vocab = {}
+         self._inv_vocab = {}
+
+         for i in range(len(self.tokenizer)):
+             t = self.tokenizer.id_to_piece(i)
+             self._inv_vocab[i] = t
+             self._vocab[t] = i
+
+     def _initalize(self, vocab_extra_ids):
+         self._populate_vocab()
+         self._special_tokens = {}
+         self._inv_special_tokens = {}
+
+         self._t5_tokens = []
+
+         def _add_special_token(t):
+             if t not in self._vocab:
+                 next_id = len(self._vocab)
+                 self._vocab[t] = next_id
+                 self._inv_vocab[next_id] = t
+             self._special_tokens[t] = self._vocab[t]
+             self._inv_special_tokens[self._vocab[t]] = t
+
+         _add_special_token('<CLS>')
+         self._cls_id = self._vocab['<CLS>']
+         _add_special_token('<SEP>')
+         self._sep_id = self._vocab['<SEP>']
+         _add_special_token('<EOD>')
+         self._eod_id = self._vocab['<EOD>']
+         _add_special_token('<MASK>')
+         self._mask_id = self._vocab['<MASK>']
+
+         pad_id = self.tokenizer.pad_id()
+         try:
+             pad_token = self.tokenizer.id_to_piece(pad_id)
+         except IndexError:
+             pad_token = '<PAD>'
+         _add_special_token(pad_token)
+         self._pad_id = self._vocab[pad_token]
+
+         bos_id = self.tokenizer.bos_id()
+         try:
+             bos_token = self.tokenizer.id_to_piece(bos_id)
+         except IndexError:
+             bos_token = '<BOS>'
+         _add_special_token(bos_token)
+         self._bos_id = self._vocab[bos_token]
+
+         eos_id = self.tokenizer.eos_id()
+         try:
+             eos_token = self.tokenizer.id_to_piece(eos_id)
+         except IndexError:
+             eos_token = '<EOS>'
+         _add_special_token(eos_token)
+         self._eos_id = self._vocab[eos_token]
+
+         for i in range(vocab_extra_ids):
+             t = "<extra_id_{}>".format(i)
+             _add_special_token(t)
+             self._t5_tokens += [t]
+
+     @property
+     def vocab_size(self):
+         return len(self._vocab)
+
+     @property
+     def vocab(self):
+         return self._vocab
+
+     @property
+     def inv_vocab(self):
+         return self._inv_vocab
+
+     @property
+     def decoder(self):
+         return self._inv_vocab
+
+     @property
+     def encoder(self):
+         return self._vocab
+
+     # From:
+     # https://github.com/NVIDIA/NeMo/blob/c8fa217e811d60d11d014827c7f3845ff6c99ae7/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py#L89
+     def tokenize(self, text):
+         ids = []
+         idx = 0
+
+         while 1:
+             indices = {}
+             for token in self._special_tokens:
+                 try:
+                     indices[token] = text[idx:].index(token)
+                 except ValueError:
+                     continue
+             if len(indices) == 0:
+                 break
+
+             next_token = min(indices, key=indices.get)
+             next_idx = idx + indices[next_token]
+
+             ids.extend(self.tokenizer.encode_as_ids(text[idx:next_idx]))
+             ids.append(self._special_tokens[next_token])
+             idx = next_idx + len(next_token)
+
+         ids.extend(self.tokenizer.encode_as_ids(text[idx:]))
+         return ids
+
+     # From:
+     # https://github.com/NVIDIA/NeMo/blob/c8fa217e811d60d11d014827c7f3845ff6c99ae7/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py#L125
+     def detokenize(self, ids):
+         text = ""
+         last_i = 0
+
+         for i, id in enumerate(ids):
+             if id in self._inv_special_tokens:
+                 text += self.tokenizer.decode_ids(ids[last_i:i]) + " "
+                 text += self._inv_special_tokens[id] + " "
+                 last_i = i + 1
+
+         text += self.tokenizer.decode_ids(ids[last_i:])
+         return text
+
+     @property
+     def cls(self):
+         return self._cls_id
+
+     @property
+     def sep(self):
+         return self._sep_id
+
+     @property
+     def pad(self):
+         return self._pad_id
+
+     @property
+     def bos_token_id(self):
+         return self._bos_id
+
+     @property
+     def bos(self):
+         return self._bos_id
+
+     @property
+     def eod(self):
+         return self._eod_id
+
+     @property
+     def eos_token_id(self):
+         return self._eos_id
+
+     @property
+     def eos(self):
+         return self._eos_id
+
+     @property
+     def mask(self):
+         return self._mask_id
+
+     @property
+     def additional_special_tokens_ids(self):
+         return [self.vocab[k] for k in self._t5_tokens]
+
238
+ """SentencePieceTokenizer-Megatron wrapper"""
239
+
240
+ def __init__(self, model_file, vocab_extra_ids=0):
241
+ super().__init__(model_file, vocab_extra_ids)
242
+
243
+
244
+ def _initalize(self, vocab_extra_ids):
245
+ self._populate_vocab()
246
+ self._special_tokens = {}
247
+ self._inv_special_tokens = {}
248
+
249
+ self._t5_tokens = []
250
+
251
+ def _add_special_token(t):
252
+ if t not in self._vocab:
253
+ next_id = len(self._vocab)
254
+ self._vocab[t] = next_id
255
+ self._inv_vocab[next_id] = t
256
+ self._special_tokens[t] = self._vocab[t]
257
+ self._inv_special_tokens[self._vocab[t]] = t
258
+
259
+ _add_special_token('<CLS>')
260
+ self._cls_id = self._vocab['<CLS>']
261
+ _add_special_token('<SEP>')
262
+ self._sep_id = self._vocab['<SEP>']
263
+ _add_special_token('<EOD>')
264
+ self._eod_id = self._vocab['<EOD>']
265
+ _add_special_token('<MASK>')
266
+ self._mask_id = self._vocab['<MASK>']
267
+
268
+ _add_special_token('<SOA>')
269
+ self._soa_id = self._vocab['<SOA>']
270
+ _add_special_token('<EOA>')
271
+ self._eoa_id = self._vocab['<EOA>']
272
+ _add_special_token('<SOV>')
273
+ self._sov_id = self._vocab['<SOV>']
274
+ _add_special_token('<EOV>')
275
+ self._eov_id = self._vocab['<EOV>']
276
+ _add_special_token('<SOI>')
277
+ self._soi_id = self._vocab['<SOI>']
278
+ _add_special_token('<EOI>')
279
+ self._eoi_id = self._vocab['<EOI>']
280
+ _add_special_token('<s_local>')
281
+ self._s_local_id = self._vocab['<s_local>']
282
+ _add_special_token('<e_local>')
283
+ self._e_local_id = self._vocab['<e_local>']
284
+ _add_special_token('<s_global>')
285
+ self._s_global_id = self._vocab['<s_global>']
286
+ _add_special_token('<e_global>')
287
+ self._e_global_id = self._vocab['<e_global>']
288
+ _add_special_token('<stage_1>')
289
+ self._stage_1_id = self._vocab['<stage_1>']
290
+ _add_special_token('<stage_2>')
291
+ self._stage_2_id = self._vocab['<stage_2>']
292
+ pad_id = self.tokenizer.pad_id()
293
+ try:
294
+ pad_token = self.tokenizer.id_to_piece(pad_id)
295
+ except IndexError:
296
+ pad_token = '<PAD>'
297
+ _add_special_token(pad_token)
298
+ self._pad_id = self._vocab[pad_token]
299
+
300
+ bos_id = self.tokenizer.bos_id()
301
+ try:
302
+ bos_token = self.tokenizer.id_to_piece(bos_id)
303
+ except IndexError:
304
+ bos_token = '<BOS>'
305
+ _add_special_token(bos_token)
306
+ self._bos_id = self._vocab[bos_token]
307
+
308
+ eos_id = self.tokenizer.eos_id()
309
+ try:
310
+ eos_token = self.tokenizer.id_to_piece(eos_id)
311
+ except IndexError:
312
+ eos_token = '<EOS>'
313
+ _add_special_token(eos_token)
314
+ self._eos_id = self._vocab[eos_token]
315
+
316
+ for i in range(vocab_extra_ids):
317
+ t = "<extra_id_{}>".format(i)
318
+ _add_special_token(t)
319
+ self._t5_tokens += [t]
320
+
321
+ @property
322
+ def soa(self):
323
+ return self._soa_id
324
+
325
+ @property
326
+ def eoa(self):
327
+ return self._eoa_id
328
+
329
+ @property
330
+ def sov(self):
331
+ return self._sov_id
332
+
333
+ @property
334
+ def eov(self):
335
+ return self._eov_id
336
+
337
+ @property
338
+ def soi(self):
339
+ return self._soi_id
340
+
341
+ @property
342
+ def eoi(self):
343
+ return self._eoi_id
344
+
345
+ @property
346
+ def s_local(self):
347
+ return self._s_local_id
348
+
349
+ @property
350
+ def e_local(self):
351
+ return self._e_local_id
352
+
353
+ @property
354
+ def s_global(self):
355
+ return self._s_global_id
356
+
357
+ @property
358
+ def e_global(self):
359
+ return self._e_global_id
360
+
361
+ @property
362
+ def stage_1(self):
363
+ return self._stage_1_id
364
+
365
+ @property
366
+ def stage_2(self):
367
+ return self._stage_2_id
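A minimal usage sketch of the multimodal tokenizer defined above (not part of the commit; it assumes sentencepiece is installed and that mm_tokenizer_v0.2_hf/tokenizer.model has been pulled from LFS):

from mmtokenizer import _MMSentencePieceTokenizer

mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
ids = mmtokenizer.tokenize("[start_of_segment]Hello world")
print(mmtokenizer.soa, mmtokenizer.eoa)    # ids of <SOA>/<EOA>, used in infer.py to delimit audio token spans
print(mmtokenizer.detokenize(ids))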
prompt_examples/genre.txt ADDED
@@ -0,0 +1 @@
+ inspiring female uplifting pop airy vocal electronic bright vocal vocal
prompt_examples/lyrics.txt ADDED
@@ -0,0 +1,39 @@
+ [verse]
+ Staring at the sunset, colors paint the sky
+ Thoughts of you keep swirling, can't deny
+ I know I let you down, I made mistakes
+ But I'm here to mend the heart I didn't break
+
+ [chorus]
+ Every road you take, I'll be one step behind
+ Every dream you chase, I'm reaching for the light
+ You can't fight this feeling now
+ I won't back down
+ You know you can't deny it now
+ I won't back down
+
+ [verse]
+ They might say I'm foolish, chasing after you
+ But they don't feel this love the way we do
+ My heart beats only for you, can't you see?
+ I won't let you slip away from me
+
+ [chorus]
+ Every road you take, I'll be one step behind
+ Every dream you chase, I'm reaching for the light
+ You can't fight this feeling now
+ I won't back down
+ You know you can't deny it now
+ I won't back down
+
+ [bridge]
+ No, I won't back down, won't turn around
+ Until you're back where you belong
+ I'll cross the oceans wide, stand by your side
+ Together we are strong
+
+ [outro]
+ Every road you take, I'll be one step behind
+ Every dream you chase, love's the tie that binds
+ You can't fight this feeling now
+ I won't back down