ylacombe committed
Commit 0810335 · 1 Parent(s): 63486c7

Create vocos_bark.py

Browse files
Files changed (1)
  1. vocos_bark.py +214 -0
vocos_bark.py ADDED
@@ -0,0 +1,214 @@
+ from typing import Dict, Optional
+
+ import torch
+
+ from vocos import Vocos
+
+ from transformers import BarkConfig
+ from transformers.models.bark import BarkSemanticModel, BarkCoarseModel, BarkFineModel, BarkPreTrainedModel
+ from transformers.models.bark.generation_configuration_bark import (
+     BarkCoarseGenerationConfig,
+     BarkFineGenerationConfig,
+     BarkSemanticGenerationConfig,
+ )
+ from transformers.modeling_utils import get_parameter_device
+ from transformers.utils import is_accelerate_available
+
+ class BarkModel(BarkPreTrainedModel):
+     config_class = BarkConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.semantic = BarkSemanticModel(config.semantic_config)
+         self.coarse_acoustics = BarkCoarseModel(config.coarse_acoustics_config)
+         self.fine_acoustics = BarkFineModel(config.fine_acoustics_config)
+
+         # Vocos replaces Bark's EnCodec decoder as the vocoder
+         self.vocos = Vocos.from_pretrained("hubertsiuzdak/vocos-encodec-24khz-v2")
+         self.config = config
+
+     @property
+     def device(self) -> torch.device:
+         """
+         `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
+         device).
+         """
+         # for BarkModel, the device must be checked on its sub-models
+         # if a sub-model has `_hf_hook`, it has been offloaded, so the device has to be found in the hook
+         if not hasattr(self.semantic, "_hf_hook"):
+             return get_parameter_device(self)
+         for module in self.semantic.modules():
+             if (
+                 hasattr(module, "_hf_hook")
+                 and hasattr(module._hf_hook, "execution_device")
+                 and module._hf_hook.execution_device is not None
+             ):
+                 return torch.device(module._hf_hook.execution_device)
+
+     def enable_cpu_offload(self, gpu_id: Optional[int] = 0):
+         r"""
+         Offloads all sub-models to CPU using accelerate, reducing memory usage with a low impact on performance.
+         This method moves one whole sub-model at a time to the GPU when it is used, and the sub-model remains on
+         the GPU until the next sub-model runs.
+
+         Args:
+             gpu_id (`int`, *optional*, defaults to 0):
+                 GPU id on which the sub-models will be loaded and offloaded.
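+
+         Example (a minimal sketch, assuming `accelerate` is installed and a CUDA device is available):
+
+         ```python
+         >>> model = BarkModel.from_pretrained("suno/bark-small")
+
+         >>> # sub-models are moved to `cuda:0` one at a time when used, then sent back to CPU
+         >>> model.enable_cpu_offload()
+         ```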
+         """
+         if is_accelerate_available():
+             from accelerate import cpu_offload_with_hook
+         else:
+             raise ImportError("`enable_cpu_offload` requires `accelerate`.")
+
+         device = torch.device(f"cuda:{gpu_id}")
+
+         if self.device.type != "cpu":
+             self.to("cpu")
+             torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
+         # this layer is used outside the first forward pass of semantic, so it needs to be loaded before semantic
+         self.semantic.input_embeds_layer, _ = cpu_offload_with_hook(self.semantic.input_embeds_layer, device)
+
+         hook = None
+         for cpu_offloaded_model in [
+             self.semantic,
+             self.coarse_acoustics,
+             self.fine_acoustics,
+         ]:
+             _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+         self.fine_acoustics_hook = hook
+
+         _, hook = cpu_offload_with_hook(self.vocos, device, prev_module_hook=hook)
+
+         # We'll offload the last model (vocos) manually.
+         self.codec_model_hook = hook
+
+     @torch.no_grad()
+     def generate(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         history_prompt: Optional[Dict[str, torch.Tensor]] = None,
+         **kwargs,
+     ) -> torch.LongTensor:
+         """
+         Generates audio from an input prompt and an additional optional `Bark` speaker prompt.
+
+         Args:
+             input_ids (`Optional[torch.Tensor]` of shape (batch_size, seq_len), *optional*):
+                 Input ids. Will be truncated up to 256 tokens. Note that the output audios will be as long as the
+                 longest generation among the batch.
+             history_prompt (`Optional[Dict[str, torch.Tensor]]`, *optional*):
+                 Optional `Bark` speaker prompt. Note that for now, this model takes only one speaker prompt per batch.
+             kwargs (*optional*): Remaining dictionary of keyword arguments. Keyword arguments are of two types:
+
+                 - Without a prefix, they will be entered as `**kwargs` for the `generate` method of each sub-model.
+                 - With a *semantic_*, *coarse_* or *fine_* prefix, they will be input for the `generate` method of
+                   the semantic, coarse and fine sub-models respectively. Prefixed keywords take priority over the
+                   ones without a prefix.
+
+                 This means you can, for example, specify a generation strategy for all sub-models except one.
+         Returns:
+             torch.LongTensor: Output generated audio.
+
+         Example:
+
+         ```python
+         >>> from transformers import AutoProcessor, BarkModel
+
+         >>> processor = AutoProcessor.from_pretrained("suno/bark-small")
+         >>> model = BarkModel.from_pretrained("suno/bark-small")
+
+         >>> # To add a voice preset, you can pass `voice_preset` to `BarkProcessor.__call__(...)`
+         >>> voice_preset = "v2/en_speaker_6"
+
+         >>> inputs = processor("Hello, my dog is cute, I need him in my life", voice_preset=voice_preset)
+
+         >>> audio_array = model.generate(**inputs, semantic_max_new_tokens=100)
+         >>> audio_array = audio_array.cpu().numpy().squeeze()
+         ```
+         """
+         # TODO (joao): workaround until nested generation config is compatible with PreTrainedModel
+         # TODO: dict
+         semantic_generation_config = BarkSemanticGenerationConfig(**self.generation_config.semantic_config)
+         coarse_generation_config = BarkCoarseGenerationConfig(**self.generation_config.coarse_acoustics_config)
+         fine_generation_config = BarkFineGenerationConfig(**self.generation_config.fine_acoustics_config)
+
+         kwargs_semantic = {
+             # if "attention_mask" is set, it should not be passed to CoarseModel and FineModel
+             "attention_mask": kwargs.pop("attention_mask", None)
+         }
+         kwargs_coarse = {}
+         kwargs_fine = {}
+         for key, value in kwargs.items():
+             if key.startswith("semantic_"):
+                 key = key[len("semantic_") :]
+                 kwargs_semantic[key] = value
+             elif key.startswith("coarse_"):
+                 key = key[len("coarse_") :]
+                 kwargs_coarse[key] = value
+             elif key.startswith("fine_"):
+                 key = key[len("fine_") :]
+                 kwargs_fine[key] = value
+             else:
+                 # if the key is already in a sub-model-specific dict, it was set with a
+                 # prefixed, sub-model-specific value and we don't override it
+                 if key not in kwargs_semantic:
+                     kwargs_semantic[key] = value
+                 if key not in kwargs_coarse:
+                     kwargs_coarse[key] = value
+                 if key not in kwargs_fine:
+                     kwargs_fine[key] = value
+
+         # 1. Generate from the semantic model
+         semantic_output = self.semantic.generate(
+             input_ids,
+             history_prompt=history_prompt,
+             semantic_generation_config=semantic_generation_config,
+             **kwargs_semantic,
+         )
+
+         # 2. Generate from the coarse model
+         coarse_output = self.coarse_acoustics.generate(
+             semantic_output,
+             history_prompt=history_prompt,
+             semantic_generation_config=semantic_generation_config,
+             coarse_generation_config=coarse_generation_config,
+             codebook_size=self.generation_config.codebook_size,
+             **kwargs_coarse,
+         )
+
+         # 3. "Generate" from the fine model
+         output = self.fine_acoustics.generate(
+             coarse_output,
+             history_prompt=history_prompt,
+             semantic_generation_config=semantic_generation_config,
+             coarse_generation_config=coarse_generation_config,
+             fine_generation_config=fine_generation_config,
+             codebook_size=self.generation_config.codebook_size,
+             **kwargs_fine,
+         )
+
+         if getattr(self, "fine_acoustics_hook", None) is not None:
+             # manually offload fine_acoustics to CPU and load vocos to GPU,
+             # since vocos is not called through its forward pass
+             self.fine_acoustics_hook.offload()
+             self.vocos = self.vocos.to(self.device)
+
+         # 4. Decode the output and generate the audio array with vocos
+         # bandwidth_id 2 selects the 6 kbps EnCodec bandwidth, i.e. Bark's 8 codebooks
+         bandwidth_id = torch.tensor([2]).to(self.device)
+         # (batch_size, n_codebooks, seq_len) -> (n_codebooks, batch_size, seq_len), as vocos expects
+         value = output.transpose(0, 1)
+         value = self.vocos.codes_to_features(value)
+         value = self.vocos.decode(value, bandwidth_id=bandwidth_id)
+
+         if getattr(self, "codec_model_hook", None) is not None:
+             # offload vocos back to CPU
+             self.codec_model_hook.offload()
+
+         return value
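
A minimal end-to-end sketch of how this class might be used (assumptions: the file above is saved as `vocos_bark.py` on the import path, `accelerate` is installed for the optional offload step, and output is written at 24 kHz, the rate of the Vocos EnCodec checkpoint):

```python
import scipy.io.wavfile
from transformers import AutoProcessor

from vocos_bark import BarkModel  # the class defined above

processor = AutoProcessor.from_pretrained("suno/bark-small")
model = BarkModel.from_pretrained("suno/bark-small")
model.enable_cpu_offload()  # optional: keeps only the active sub-model on the GPU

voice_preset = "v2/en_speaker_6"
inputs = processor("Hello, my dog is cute, I need him in my life", voice_preset=voice_preset)

audio = model.generate(**inputs)
audio = audio.cpu().numpy().squeeze()

# the Vocos EnCodec checkpoint decodes at 24 kHz
scipy.io.wavfile.write("bark_vocos.wav", rate=24000, data=audio)
```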