unpairedelectron07 commited on
Commit
797349c
·
verified ·
1 Parent(s): 4eb202f

Upload 11 files

Browse files
audiocraft/utils/autocast.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+
9
+
10
class TorchAutocast:
    """Switchable wrapper around `torch.autocast`.

    Acts as a context manager that either delegates to a real
    `torch.autocast` context or does nothing at all. This is handy when the
    same code has to run on architectures and clusters with different
    levels of autocast support.

    Args:
        enabled (bool): Whether to actually use torch.autocast.
        args: Positional arguments forwarded to torch.autocast.
        kwargs: Keyword arguments forwarded to torch.autocast.
    """
    def __init__(self, enabled: bool, *args, **kwargs):
        # `None` marks the disabled state checked by __enter__ / __exit__.
        if enabled:
            self.autocast = torch.autocast(*args, **kwargs)
        else:
            self.autocast = None

    def __enter__(self):
        ctx = self.autocast
        if ctx is not None:
            try:
                ctx.__enter__()
            except RuntimeError:
                # Re-raise with the dtype/device that failed to ease debugging.
                device = ctx.device
                dtype = ctx.fast_dtype
                raise RuntimeError(
                    f"There was an error autocasting with dtype={dtype} device={device}\n"
                    "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16"
                )

    def __exit__(self, *args, **kwargs):
        ctx = self.autocast
        if ctx is not None:
            ctx.__exit__(*args, **kwargs)
audiocraft/utils/best_state.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from collections import defaultdict
8
+ import logging
9
+ import typing as tp
10
+
11
+ import flashy
12
+ import torch
13
+
14
+ from ..optim import ModuleDictEMA
15
+ from .utils import copy_state
16
+
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
class BestStateDictManager(flashy.state.StateDictSource):
    """BestStateDictManager maintains a copy of best state_dict() for registered sources.

    BestStateDictManager has two main attributes:
        states (dict): State dict of the registered StateDictSource.
        param_ids (dict): Dict of parameter ids for registered states from ModuleDictEMA and other sources.

    When registering new sources, the BestStateDictManager will ensure two conflicting sources between
    ModuleDictEMA and original modules are not both registered as it would otherwise create ambiguity about
    what to consider for best state.

    Args:
        device (torch.device or str): Device on which we keep the copy.
        dtype (torch.dtype): Data type for the state parameters.
    """
    def __init__(self, device: tp.Union[torch.device, str] = 'cpu',
                 dtype: tp.Optional[torch.dtype] = None):
        self.device = device
        self.states: dict = {}
        # Maps a group name ('base' or a ModuleDictEMA name) to {id(tensor): param name}.
        self.param_ids: dict = defaultdict(dict)
        self.dtype = dtype

    def _get_parameter_ids(self, state_dict):
        """Return a mapping `id(tensor) -> entry name` for the tensor entries of `state_dict`."""
        return {id(p): name for name, p in state_dict.items() if isinstance(p, torch.Tensor)}

    def _validate_no_parameter_ids_overlap(self, name: str, param_ids: dict):
        """Assert that `param_ids` shares no tensor with any other registered group.

        Args:
            name (str): Group the ids belong to; that group is excluded from the check.
            param_ids (dict): Mapping `id(tensor) -> entry name` to validate.
        Raises:
            AssertionError: If a tensor is already registered under another group.
        """
        for registered_name, registered_param_ids in self.param_ids.items():
            if registered_name == name:
                continue
            # Dict key views support set operators directly; calling the unbound
            # `set.intersection` on them raises TypeError.
            overlap = registered_param_ids.keys() & param_ids.keys()
            # Ids are ints, so report the corresponding entry names instead of joining the ids.
            assert len(overlap) == 0, (
                f"Found {len(overlap)} / {len(param_ids.keys())} overlapping parameters"
                f" in {name} and already registered {registered_name}: "
                f"{' '.join(str(param_ids[pid]) for pid in overlap)}"
            )

    def update(self, name: str, source: flashy.state.StateDictSource):
        """Replace the stored copy for an already registered `name` with the current state of `source`.

        Raises:
            ValueError: If `name` was never registered.
        """
        if name not in self.states:
            raise ValueError(f"{name} missing from registered states.")
        self.states[name] = copy_state(source.state_dict(), device=self.device, dtype=self.dtype)

    def register(self, name: str, source: flashy.state.StateDictSource):
        """Register a new source under `name` and store a copy of its current state.

        Raises:
            ValueError: If `name` is already registered.
            AssertionError: If the source shares tensors with a conflicting registered group.
        """
        if name in self.states:
            raise ValueError(f"{name} already present in states.")
        # Registering parameter ids for EMA and non-EMA states allows us to check that
        # there is no overlap that would create ambiguity about how to handle the best state
        param_ids = self._get_parameter_ids(source.state_dict())
        if isinstance(source, ModuleDictEMA):
            logger.debug(f"Registering to best state: ModuleDictEMA '{name}' with {len(param_ids)} params")
            self._validate_no_parameter_ids_overlap(name, param_ids)
            self.param_ids[name] = param_ids
        else:
            logger.debug(f"Registering to best state: StateDictSource '{name}' with {len(param_ids)} params")
            # All non-EMA sources are pooled in the shared 'base' group.
            self._validate_no_parameter_ids_overlap('base', param_ids)
            self.param_ids['base'].update(param_ids)
        # Register state
        self.states[name] = copy_state(source.state_dict(), device=self.device, dtype=self.dtype)

    def state_dict(self) -> flashy.state.StateDict:
        """Return the stored copies, keyed by registered name."""
        return self.states

    def load_state_dict(self, state: flashy.state.StateDict):
        """Copy the values of `state` in place into the stored entries of the same name and key."""
        for name, sub_state in state.items():
            for k, v in sub_state.items():
                self.states[name][k].copy_(v)
audiocraft/utils/cache.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from concurrent.futures import ThreadPoolExecutor
8
+ from collections import deque
9
+ from functools import partial
10
+ from hashlib import sha1
11
+ import logging
12
+ from pathlib import Path
13
+ import sys
14
+ import typing as tp
15
+ import zipfile
16
+
17
+ import flashy
18
+ import torch
19
+
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
def get_full_embed(full_embed: torch.Tensor, x: tp.Any, idx: int, device: tp.Union[str, torch.device]) -> torch.Tensor:
    """Default extraction function for the EmbeddingCache: keep the whole embedding.

    Use it when no chunk has to be cut out of the full embedding read from the
    cache. `x` and `idx` are accepted only to match the extraction function
    signature and are ignored.

    Args:
        full_embed (torch.Tensor): The full embedding.
        x (any): Batch object from which the full embedding is derived (unused).
        idx (int): Index of object to consider in the batch object (unused).
        device (str or torch.device): Device the embedding is moved to.
    Returns:
        torch.Tensor: The full embedding, on `device`.
    """
    moved = full_embed.to(device)
    return moved
37
+
38
+
39
class EmbeddingCache:
    """Cache around embeddings computation for faster execution.
    The EmbeddingCache is storing pre-computed embeddings on disk and provides a simple API
    to retrieve the pre-computed embeddings on full inputs and extract only a given chunk
    using a user-provided function. When the cache is warm (all embeddings are pre-computed),
    the EmbeddingCache allows for faster training as it removes the need of computing the embeddings.
    Additionally, it provides in-memory cache around the loaded embeddings to limit IO footprint
    and synchronization points in the forward calls.

    Args:
        cache_path (Path): Path to folder where all pre-computed embeddings are saved on disk.
        device (str or torch.device): Device on which the embedding is returned.
        compute_embed_fn (callable[[Path, any, int], torch.Tensor], optional): Function to compute
            the embedding from a given object and path. This user provided function can compute the
            embedding from the provided object or using the provided path as entry point. The last parameter
            specify the index corresponding to the current embedding in the object that can represent batch metadata.
        extract_embed_fn (callable[[torch.Tensor, any, int], torch.Tensor], optional): Function to extract
            the desired embedding chunk from the full embedding loaded from the cache. The last parameter
            specify the index corresponding to the current embedding in the object that can represent batch metadata.
            If not specified, will return the full embedding unmodified.
    """
    def __init__(self, cache_path: tp.Union[str, Path], device: tp.Union[str, torch.device],
                 compute_embed_fn: tp.Callable[[Path, tp.Any, int], torch.Tensor],
                 extract_embed_fn: tp.Optional[tp.Callable[[torch.Tensor, tp.Any, int], torch.Tensor]] = None):
        self.cache_path = Path(cache_path)
        self.device = device
        self._compute_embed_fn = compute_embed_fn
        self._extract_embed_fn: tp.Callable[[torch.Tensor, tp.Any, int], torch.Tensor]
        if extract_embed_fn is not None:
            self._extract_embed_fn = extract_embed_fn
        else:
            # Default: return the full embedding, moved to `device`.
            self._extract_embed_fn = partial(get_full_embed, device=device)
        if self.cache_path is not None:
            self.cache_path.mkdir(exist_ok=True, parents=True)
            logger.info(f"Cache instantiated at: {self.cache_path}")
            # Pool used to read cached embeddings from disk concurrently.
            self.pool = ThreadPoolExecutor(8)
            self.pool.__enter__()
        # Extracted chunks for the batch announced through `populate_embed_cache`.
        self._current_batch_cache: dict = {}
        # Full embeddings already read from disk, keyed by cache file path.
        self._memory_cache: dict = {}

    def _get_cache_path(self, path: tp.Union[Path, str]):
        """Get cache path for the given file path (SHA-1 of the path string, inside `cache_path`)."""
        sig = sha1(str(path).encode()).hexdigest()
        return self.cache_path / sig

    @staticmethod
    def _get_full_embed_from_cache(cache: Path):
        """Loads full pre-computed embedding from the cache, returning None if loading fails."""
        try:
            embed = torch.load(cache, 'cpu')
        except Exception as exc:
            # Best effort: an unreadable entry is logged and treated as a cache miss.
            logger.error("Error loading %s: %r", cache, exc)
            embed = None
        return embed

    def get_embed_from_cache(self, paths: tp.List[Path], x: tp.Any) -> torch.Tensor:
        """Get embedding from cache, computing and storing it to cache if not already cached.
        The EmbeddingCache first tries to load the embedding from the in-memory cache
        containing the pre-computed chunks populated through `populate_embed_cache`.
        If not found, the full embedding is computed and stored on disk to be later accessed
        to populate the in-memory cache, and the desired embedding chunk is extracted and returned.

        Args:
            paths (list[Path or str]): List of paths from where the embeddings can be loaded.
            x (any): Object from which the embedding is extracted.
        Returns:
            torch.Tensor: The per-path embedding chunks stacked along a new first dimension.
        """
        embeds = []
        for idx, path in enumerate(paths):
            cache = self._get_cache_path(path)
            if cache in self._current_batch_cache:
                embed = self._current_batch_cache[cache]
            else:
                full_embed = self._compute_embed_fn(path, x, idx)
                try:
                    with flashy.utils.write_and_rename(cache, pid=True) as f:
                        torch.save(full_embed.cpu(), f)
                except Exception as exc:
                    # Saving is best effort: the computed embedding is still usable.
                    logger.error('Error saving embed %s (%s): %r', cache, full_embed.shape, exc)
                else:
                    logger.info('New embed cache saved: %s (%s)', cache, full_embed.shape)
                # Extraction must not depend on the save succeeding, otherwise `embed`
                # would be unbound or stale from a previous iteration.
                embed = self._extract_embed_fn(full_embed, x, idx)
            embeds.append(embed)
        return torch.stack(embeds, dim=0)

    def populate_embed_cache(self, paths: tp.List[Path], x: tp.Any) -> None:
        """Populate in-memory caches for embeddings reading from the embeddings stored on disk.
        The in-memory caches consist in a cache for the full embedding and another cache for the
        final embedding chunk. Such caches are used to limit the IO access when computing the actual embeddings
        and reduce the IO footprint and synchronization points during forward passes.

        Args:
            paths (list[Path]): List of paths from where the embeddings can be loaded.
            x (any): Object from which the embedding is extracted.
        """
        self._current_batch_cache.clear()
        if self.cache_path is not None:
            # First pass: schedule a disk read for every entry not yet in memory.
            futures: list = []
            for path in paths:
                assert path is not None, "Path is required for computation from cache"
                cache = self._get_cache_path(path)
                if cache in self._memory_cache or not cache.exists():
                    futures.append(None)
                else:
                    futures.append(self.pool.submit(EmbeddingCache._get_full_embed_from_cache, cache))
            # Second pass: collect results and extract the chunk for each item.
            for idx, (path, future) in enumerate(zip(paths, futures)):
                assert path is not None
                cache = self._get_cache_path(path)
                full_embed = None
                if future is None:
                    if cache in self._memory_cache:
                        full_embed = self._memory_cache[cache]
                else:
                    full_embed = future.result()
                    if full_embed is not None:
                        # The CPU copy is what is kept in memory; the moved copy is used below.
                        self._memory_cache[cache] = full_embed
                        full_embed = full_embed.to(self.device)
                if full_embed is not None:
                    embed = self._extract_embed_fn(full_embed, x, idx)
                    self._current_batch_cache[cache] = embed
159
+
160
+
161
class CachedBatchWriter:
    """Write pre computed caches for mini batches. This can
    make loading a lot more efficient depending on your filesystem.

    Args:
        cache_folder (Path): folder in which the cached minibatches
            will be stored.

    Inside cache folder, the structure is the following:
    `epoch_number / update_number.zip`
    And the zip file contains one entry per batch item.

    It is possible to use the cache with a batch size smaller than
    created with but obviously not larger. Make sure to call the
    `start_epoch(epoch)` method for indicating changes of epochs.

    See the grid `audiocraft/grids/musicgen/musicgen_warmup_cache.py`
    for an example of how to warmup the cache.
    """
    def __init__(self, cache_folder: Path):
        self.cache_folder = cache_folder
        self._current_epoch: tp.Optional[int] = None
        self._current_index = 0

    def start_epoch(self, epoch: int):
        """Call at the beginning of each epoch.

        Resets the batch index and creates the folder for that epoch.
        """
        self._current_epoch = epoch
        self._current_index = 0
        epoch_folder = self._zip_path.parent
        epoch_folder.mkdir(exist_ok=True, parents=True)

    @staticmethod
    def _get_zip_path(cache_folder: Path, epoch: int, index: int):
        """Path of the zip for a given epoch (5 digits) and batch index (6 digits)."""
        epoch_dir = f"{epoch:05d}"
        zip_name = f"{index:06d}.zip"
        return cache_folder / epoch_dir / zip_name

    @property
    def _zip_path(self):
        """Path of the zip for the current epoch and batch index."""
        assert self._current_epoch is not None
        return CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch, self._current_index)

    def save(self, *content):
        """Save one mini batch. This function is distributed-aware
        and will automatically merge all the items from the different
        workers.
        """
        # Every rank broadcasts its own content in turn so all ranks see everything.
        gathered = [
            flashy.distrib.broadcast_object(content, src=rank)
            for rank in range(flashy.distrib.world_size())
        ]

        if flashy.distrib.is_rank_zero():
            entry_index = 0
            with flashy.utils.write_and_rename(self._zip_path) as tmp:
                with zipfile.ZipFile(tmp, 'w') as zf:
                    for rank_content in gathered:
                        # One zip entry per batch item, numbered across all ranks.
                        for item in zip(*rank_content):
                            with zf.open(f'{entry_index}', 'w') as f:  # type: ignore
                                torch.save(item, f)
                            entry_index += 1
        flashy.distrib.barrier()
        self._current_index += 1
222
+
223
+
224
class CachedBatchLoader:
    """Loader for cached mini-batches dumped with `CachedBatchWriter`.

    Args:
        cache_folder (Path): folder in which the cached minibatches are stored.
        batch_size (int): batch size (per GPU) expected.
        num_workers (int): number of workers to use for loading.
        min_length (int): minimum expected length for each epoch. If some
            mini-batches are missing, and error is raised.

    This is iterable just like a regular DataLoader.
    """

    def __init__(self, cache_folder: Path, batch_size: int,
                 num_workers: int = 10, min_length: int = 1):
        self.cache_folder = cache_folder
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.min_length = min_length
        self._current_epoch: tp.Optional[int] = None
        self.sampler = None  # for compatibility with the regular DataLoader

    def __len__(self):
        """Number of `.zip` files in the folder of the current epoch (epoch 0 if none was started)."""
        path = CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch or 0, 0).parent
        return len([p for p in path.iterdir() if p.suffix == ".zip"])

    def start_epoch(self, epoch: int):
        """Call at the beginning of each epoch.
        """
        self._current_epoch = epoch

    def _zip_path(self, index: int):
        """Path of the zip holding batch `index` of the current epoch."""
        assert self._current_epoch is not None
        return CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch, index)

    def _load_one(self, index: int):
        """Load this rank's slice of mini-batch `index`.

        Returns:
            list or None: One element per saved field: stacked tensor for tensor fields,
                tuple of per-item values otherwise. None when the zip file does not exist
                and `index` is at least `min_length`.
        Raises:
            RuntimeError: If the zip is missing for `index < min_length`, or if it holds
                fewer items than `batch_size * world_size`.
        """
        zip_path = self._zip_path(index)
        if not zip_path.exists():
            if index < self.min_length:
                raise RuntimeError(f"Cache should have at least {self.min_length} batches, but {index} doesn't exist")

            return None
        mode = "rb" if sys.version_info >= (3, 9) else "r"
        try:
            with zipfile.ZipFile(zip_path, 'r') as zf:
                rank = flashy.distrib.rank()
                world_size = flashy.distrib.world_size()
                root = zipfile.Path(zf)
                items = list(root.iterdir())
                total_batch_size = self.batch_size * world_size
                if len(items) < total_batch_size:
                    raise RuntimeError(
                        f"The cache can handle a max batch size of {len(items)}, "
                        f"but {total_batch_size} is needed.")
                # Each rank reads its own contiguous slice of the stored items.
                start = rank * self.batch_size
                items = items[start: start + self.batch_size]
                assert len(items) == self.batch_size
                entries = []
                for item in items:
                    # Close every zip entry explicitly rather than leaking the handle.
                    with item.open(mode) as f:  # type: ignore
                        entries.append(torch.load(f, 'cpu'))
                # Turn the list of per-item tuples into one tuple per saved field.
                transposed = zip(*entries)
                out = []
                for part in transposed:
                    assert len(part) > 0
                    if isinstance(part[0], torch.Tensor):
                        out.append(torch.stack(part))
                    else:
                        # Non-tensor fields are returned as a tuple of per-item values.
                        # `part` comes from `zip` and is always a tuple, so asserting
                        # that it is a Tensor could never succeed.
                        out.append(part)
                return out
        except Exception:
            logger.error("Error when reading zip path %s", zip_path)
            raise

    def __iter__(self):
        """This will yields tuples, exactly as provided to the
        `CachedBatchWriter.save` method.
        """
        pool = ThreadPoolExecutor(self.num_workers)
        next_index = 0
        queue: deque = deque()

        def _get_next():
            nonlocal next_index
            r = queue.popleft().result()
            if r is None:
                # First missing batch marks the end of the epoch.
                return None
            # Keep the prefetch buffer full by scheduling one more load.
            queue.append(pool.submit(self._load_one, next_index))
            next_index += 1
            return r

        with pool:
            # fill the buffer of fetching jobs.
            for _ in range(2 * self.num_workers):
                queue.append(pool.submit(self._load_one, next_index))
                next_index += 1
            while True:
                batch = _get_next()
                if batch is None:
                    return
                yield batch
audiocraft/utils/checkpoint.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from enum import Enum
8
+ import logging
9
+ from pathlib import Path
10
+ import re
11
+ import typing as tp
12
+
13
+ import flashy
14
+ import torch
15
+
16
+ from ..environment import AudioCraftEnvironment
17
+
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
class CheckpointSource(Enum):
    """Labels identifying where a checkpoint comes from."""
    CURRENT_XP = "current_xp"
    PRETRAINED = "pretrained"
    OTHER = "other"
26
+
27
+
28
def checkpoint_name(name: tp.Optional[str] = None, rank: tp.Optional[int] = None, use_fsdp: bool = False) -> str:
    """Build the checkpoint file name used throughout the AudioCraft codebase.

    The format is `checkpoint_<name>.th(.<rank>)`. By convention, name is expected
    to be empty for last checkpoint, 'best' for the best checkpoint or the epoch number.
    The rank suffix is only appended for non-zero ranks when FSDP is used.

    Args:
        name (str, optional): Name suffix for the checkpoint file stem.
        rank (optional, int): Rank for distributed processing, retrieved with flashy if not provided.
        use_fsdp (bool): Whether the calling solver relies on FSDP.
    Returns:
        str: The checkpoint name.
    """
    if rank is None:
        rank = flashy.distrib.rank()
    rank_suffix = f'.{rank}' if (rank > 0 and use_fsdp) else ''
    stem_suffix = '' if name is None else f'_{name}'
    return f'checkpoint{stem_suffix}.th{rank_suffix}'
49
+
50
+
51
def is_sharded_checkpoint(path: Path) -> bool:
    """Tell whether the file name at `path` ends with `.th.<digits>`, i.e. a per-rank shard."""
    shard_match = re.search(r'\.th\.\d+$', path.name)
    return shard_match is not None
54
+
55
+
56
def resolve_checkpoint_path(sig_or_path: tp.Union[Path, str], name: tp.Optional[str] = None,
                            use_fsdp: bool = False) -> tp.Optional[Path]:
    """Resolve a given checkpoint path for a provided dora sig or path.

    A value starting with `//sig/` is looked up under the dora `xps` folder,
    anything else goes through `AudioCraftEnvironment.resolve_reference_path`.
    When the result is a directory, the checkpoint file name is appended.

    Args:
        sig_or_path (Path or str): Checkpoint path or dora signature.
        name (str, optional): Name suffix for the checkpoint file stem.
        use_fsdp (bool): Whether the calling solver relies on FSDP.
    Returns:
        Path, optional: Resolved checkpoint path, if it exists.
    """
    # Imported here rather than at module level, as in the original code.
    from audiocraft import train
    xps_root = train.main.dora.dir / 'xps'
    reference = str(sig_or_path)
    sig_prefix = '//sig/'
    if reference.startswith(sig_prefix):
        path = xps_root / reference[len(sig_prefix):]
    else:
        path = AudioCraftEnvironment.resolve_reference_path(Path(reference))

    if path.is_dir():
        path = path / checkpoint_name(name, use_fsdp=use_fsdp)

    return path if path.exists() else None
85
+
86
+
87
def load_checkpoint(checkpoint_path: Path, is_sharded: bool = False) -> tp.Any:
    """Load the state stored at `checkpoint_path`, on CPU.

    For sharded checkpoints, the shard is first checked against the rank zero
    checkpoint of the same folder, when that one exists.
    """
    if is_sharded:
        rank0_path = checkpoint_path.parent / checkpoint_name(use_fsdp=False)
        if rank0_path.exists():
            check_sharded_checkpoint(checkpoint_path, rank0_path)
    loaded = torch.load(checkpoint_path, 'cpu')
    logger.info("Checkpoint loaded from %s", checkpoint_path)
    return loaded
96
+
97
+
98
def save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bool = False) -> None:
    """Write `state` to `checkpoint_path` through `_safe_save_checkpoint` and log the destination."""
    _safe_save_checkpoint(state, checkpoint_path, is_sharded=is_sharded)
    logger.info("Checkpoint saved to %s", checkpoint_path)
102
+
103
+
104
def flush_stale_checkpoints(checkpoint_path: Path, keep_last: tp.Optional[int] = None) -> None:
    """Delete the oldest epoch checkpoints so that only the last `keep_last` remain.

    Only files named `checkpoint_<digits>.th` (with this rank's suffix for
    non-zero ranks) in the folder of `checkpoint_path` are considered.
    Nothing is done when `keep_last` is None or not positive.
    """
    if keep_last is None or keep_last <= 0:
        return
    rank = flashy.distrib.rank()
    suffix = f'.{rank}' if rank > 0 else ''
    epoch_checkpoints = []
    for candidate in Path(checkpoint_path.parent).glob(f'checkpoint_*.th{suffix}'):
        # `checkpoint_<epoch>.th...` -> `<epoch>`; non numeric names (e.g. 'best') are skipped.
        epoch_part = candidate.name.split('.', 1)[0].split('_', 1)[1]
        if epoch_part.isdigit():
            epoch_checkpoints.append((int(epoch_part), candidate))
    epoch_checkpoints.sort(key=lambda item: item[0])
    n_stale = max(0, len(epoch_checkpoints) - keep_last)
    for _, stale in epoch_checkpoints[:n_stale]:
        logger.debug("Removing checkpoint: %s", str(stale))
        stale.unlink(missing_ok=True)
123
+
124
+
125
def check_sharded_checkpoint(checkpoint_path: Path, rank0_checkpoint_path: Path) -> None:
    """Check sharded checkpoint state, ensuring the checkpoints are not corrupted.

    Raises:
        RuntimeError: If a `.old` file exists next to the checkpoint.
    """
    # Finish the work of a previous run that got interrupted while dumping.
    legacy_path = Path(str(checkpoint_path) + '.old')
    if legacy_path.exists():
        raise RuntimeError(
            f"Old checkpoint {legacy_path} from previous version of this code exist, cannot safely proceed.")
    done_token = Path(str(rank0_checkpoint_path) + '.tmp.done')
    pending_path = Path(str(checkpoint_path) + '.tmp')
    # The token marks that all shards were fully written: complete the pending rename.
    if done_token.exists() and pending_path.exists():
        pending_path.rename(checkpoint_path)
    flashy.distrib.barrier()
    if flashy.distrib.is_rank_zero() and done_token.exists():
        done_token.unlink()
140
+
141
+
142
def _safe_save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bool = False) -> None:
    """Save checkpoints in a safe manner even with when sharded checkpoints across nodes.

    Rank zero maintains a `<checkpoint_path>.tmp.done` token file around the write;
    `check_sharded_checkpoint` relies on it to finish an interrupted dump.
    """
    def _sync_shards():
        # Barrier only matters when every rank writes its own shard.
        if is_sharded:
            flashy.distrib.barrier()

    if flashy.distrib.is_rank_zero():
        # Drop any token left over from a previous save.
        token = Path(str(checkpoint_path) + '.tmp.done')
        if token.exists():
            token.unlink()
    _sync_shards()
    with flashy.utils.write_and_rename(checkpoint_path) as f:
        torch.save(state, f)
        # Wait for every shard to be written before marking the dump as complete.
        _sync_shards()
        if flashy.distrib.is_rank_zero():
            token.touch()
        # No rank leaves the `with` block before the token exists.
        _sync_shards()
    # All ranks are out of the `with` block: the token can go.
    _sync_shards()
    if flashy.distrib.rank() == 0:
        token.unlink()
audiocraft/utils/cluster.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Utility functions for SLURM configuration and cluster settings.
9
+ """
10
+
11
+ from enum import Enum
12
+ import os
13
+ import socket
14
+ import typing as tp
15
+
16
+ import omegaconf
17
+
18
+
19
class ClusterType(Enum):
    """Kinds of cluster the code distinguishes between."""
    AWS = "aws"
    FAIR = "fair"
    RSC = "rsc"
    LOCAL_DARWIN = "darwin"
    # Fallback used for any other cluster.
    DEFAULT = "default"
25
+
26
+
27
def _guess_cluster_type() -> ClusterType:
    """Infer the cluster type from `os.uname()` and the fully qualified host name.

    Checks are applied in order: AWS, FAIR, RSC, local macOS, then the default.
    """
    system_info = os.uname()
    hostname = socket.getfqdn()
    on_linux = system_info.sysname == "Linux"
    if on_linux and (system_info.release.endswith("-aws") or ".ec2" in hostname):
        return ClusterType.AWS
    elif hostname.endswith(".fair"):
        return ClusterType.FAIR
    elif hostname.endswith(".facebook.com"):
        return ClusterType.RSC
    elif system_info.sysname == "Darwin":
        return ClusterType.LOCAL_DARWIN
    else:
        return ClusterType.DEFAULT
43
+
44
+
45
def get_cluster_type(
    cluster_type: tp.Optional[ClusterType] = None,
) -> tp.Optional[ClusterType]:
    """Return `cluster_type` when given, otherwise the type guessed for the current host."""
    if cluster_type is not None:
        return cluster_type
    return _guess_cluster_type()
52
+
53
+
54
def get_slurm_parameters(
    cfg: omegaconf.DictConfig, cluster_type: tp.Optional[ClusterType] = None
) -> omegaconf.DictConfig:
    """Update SLURM parameters in configuration based on cluster type.
    If the cluster type is not specified, it is inferred automatically.

    The configuration is modified in place and also returned.
    """
    from ..environment import AudioCraftEnvironment
    cluster_type = get_cluster_type(cluster_type)
    # apply cluster-specific adjustments
    overrides: dict = {}
    if cluster_type == ClusterType.AWS:
        overrides = {"mem_per_gpu": None, "constraint": None, "setup": []}
    elif cluster_type == ClusterType.RSC:
        overrides = {"mem_per_gpu": None, "setup": [], "constraint": None, "partition": "learn"}
    for key, value in overrides.items():
        cfg[key] = value
    # Node exclusion applies whatever the cluster type.
    slurm_exclude = AudioCraftEnvironment.get_slurm_exclude()
    if slurm_exclude is not None:
        cfg["exclude"] = slurm_exclude
    return cfg
audiocraft/utils/deadlock.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ import os
9
+ from queue import Queue, Empty
10
+ import signal
11
+ import sys
12
+ import threading
13
+ import traceback
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
class DeadlockDetect:
    """Watchdog context manager that kills the process when no progress is reported.

    While the context is active, a background thread waits for stage names sent
    through `update`. If none arrives within `timeout` seconds, the stack of
    every thread is printed to stderr and the process is killed with SIGKILL.

    Args:
        use (bool): Whether the detector is active. When False, every method is a no-op.
        timeout (float): Maximum time in seconds allowed between two `update` calls.
    """
    def __init__(self, use: bool = False, timeout: float = 120.):
        self.use = use
        self.timeout = timeout
        # Carries stage names to the detector thread; `None` asks it to stop.
        self._queue: Queue = Queue()

    def update(self, stage: str):
        """Report that `stage` was reached, which resets the timeout."""
        if self.use:
            self._queue.put(stage)

    def __enter__(self):
        if self.use:
            self._thread = threading.Thread(target=self._detector_thread)
            self._thread.start()
        # Return the detector so that `with DeadlockDetect(...) as detector:` works.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.use:
            # Ask the detector thread to stop, then wait for it.
            self._queue.put(None)
            self._thread.join()

    def _detector_thread(self):
        """Body of the watchdog thread: wait for stages, kill the process on timeout."""
        logger.debug("Deadlock detector started")
        last_stage = "init"
        while True:
            try:
                stage = self._queue.get(timeout=self.timeout)
            except Empty:
                break
            if stage is None:
                logger.debug("Exiting deadlock detector thread")
                return
            else:
                last_stage = stage
        logger.error("Deadlock detector timed out, last stage was %s", last_stage)
        # Take a single snapshot of the frames; a thread may have exited since
        # `enumerate()`, in which case it has no frame to print.
        frames = sys._current_frames()
        for th in threading.enumerate():
            print(th, file=sys.stderr)
            frame = frames.get(th.ident)
            if frame is not None:
                traceback.print_stack(frame)
            print(file=sys.stderr)
        sys.stdout.flush()
        sys.stderr.flush()
        os.kill(os.getpid(), signal.SIGKILL)
audiocraft/utils/export.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Utility to export a training checkpoint to a lightweight release checkpoint.
9
+ """
10
+
11
+ from pathlib import Path
12
+ import typing as tp
13
+
14
+ from omegaconf import OmegaConf
15
+ import torch
16
+
17
+ from audiocraft import __version__
18
+
19
+
20
def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
    """Export only the best state from the given EnCodec checkpoint. This
    should be used if you trained your own EnCodec model.

    Returns `out_file`, whose parent folder is created if needed.
    """
    checkpoint = torch.load(checkpoint_path, 'cpu')
    exported = {
        'best_state': checkpoint['best_state']['model'],
        'xp.cfg': OmegaConf.to_yaml(checkpoint['xp.cfg']),
        'version': __version__,
        'exported': True,
    }
    out_dir = Path(out_file).parent
    out_dir.mkdir(exist_ok=True, parents=True)
    torch.save(exported, out_file)
    return out_file
34
+
35
+
36
def export_pretrained_compression_model(pretrained_encodec: str, out_file: tp.Union[Path, str]):
    """Export a compression model (potentially EnCodec) from a pretrained model.
    This is required for packaging the audio tokenizer along a MusicGen or AudioGen model.
    Do not include the //pretrained/ prefix. For instance if you trained a model
    with `facebook/encodec_32khz`, just put that as a name. Same for `dac_44khz`.

    In that case, this will not actually include a copy of the model, simply the reference
    to the model used.

    If `pretrained_encodec` is an existing file, it is expected to be an already
    exported package and is copied to `out_file` after a sanity check.

    Args:
        pretrained_encodec (str): Pretrained model name, or path to an exported checkpoint.
        out_file (Path or str): Destination file; its parent folder is created if needed.
    Raises:
        AssertionError: If the existing file lacks one of the exported package keys.
    """
    if Path(pretrained_encodec).exists():
        # Load on CPU like the other export functions, so that a checkpoint saved
        # from a GPU can be exported on a machine without one.
        pkg = torch.load(pretrained_encodec, 'cpu')
        assert 'best_state' in pkg
        assert 'xp.cfg' in pkg
        assert 'version' in pkg
        assert 'exported' in pkg
    else:
        pkg = {
            'pretrained': pretrained_encodec,
            'exported': True,
            'version': __version__,
        }
    Path(out_file).parent.mkdir(exist_ok=True, parents=True)
    torch.save(pkg, out_file)
59
+
60
+
61
def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
    """Export only the best state from the given MusicGen or AudioGen checkpoint.

    The FSDP best state is used when it is not empty, the regular best state otherwise.
    Returns `out_file`, whose parent folder is created if needed.
    """
    checkpoint = torch.load(checkpoint_path, 'cpu')
    fsdp_state = checkpoint['fsdp_best_state']
    if fsdp_state:
        best_state = fsdp_state['model']
    else:
        regular_state = checkpoint['best_state']
        assert regular_state
        best_state = regular_state['model']
    exported = {
        'best_state': best_state,
        'xp.cfg': OmegaConf.to_yaml(checkpoint['xp.cfg']),
        'version': __version__,
        'exported': True,
    }

    out_dir = Path(out_file).parent
    out_dir.mkdir(exist_ok=True, parents=True)
    torch.save(exported, out_file)
    return out_file
audiocraft/utils/export_legacy.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Legacy functions used at the time of the first release, kept for reference.
9
+ """
10
+
11
+ from pathlib import Path
12
+ import typing as tp
13
+
14
+ from omegaconf import OmegaConf, DictConfig
15
+ import torch
16
+
17
+ from audiocraft import __version__
18
+
19
+
20
def _clean_lm_cfg(cfg: DictConfig):
    """Patch a legacy LM experiment config in place so that it can be exported.

    Sets `transformer_lm.card` and `transformer_lm.n_q`, and removes parameters that
    are no longer supported. Parameters that are already absent are skipped, so the
    function can be applied to a config that was cleaned before.

    Args:
        cfg (DictConfig): Experiment config, modified in place.
    Returns:
        DictConfig: The same `cfg` object, with the struct flag set back to True.
    """
    # The struct flag is disabled so that keys can be added and removed.
    OmegaConf.set_struct(cfg, False)
    # This used to be set automatically in the LM solver, need a more robust solution
    # for the future.
    cfg['transformer_lm']['card'] = 2048
    n_q = 4
    stereo_cfg = getattr(cfg, 'interleave_stereo_codebooks', None)
    if stereo_cfg is not None and stereo_cfg.use:
        if 'downsample' in stereo_cfg:
            del stereo_cfg['downsample']
        n_q = 8
    cfg['transformer_lm']['n_q'] = n_q
    # Experimental params no longer supported.
    bad_params = ['spectral_norm_attn_iters', 'spectral_norm_ff_iters',
                  'residual_balancer_attn', 'residual_balancer_ff', 'layer_drop']
    for name in bad_params:
        # Guarded like the `downsample` deletion above: a missing param is not an error.
        if name in cfg['transformer_lm']:
            del cfg['transformer_lm'][name]
    OmegaConf.set_struct(cfg, True)
    return cfg
39
+
40
+
41
def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
    """Export the EMA model state of a legacy EnCodec checkpoint.

    Args:
        checkpoint_path (Path or str): Checkpoint to read, loaded on CPU.
        out_file (Path or str): Destination of the exported package. Missing parent
            folders are created.
    Returns:
        Path or str: `out_file`, as given.
    """
    checkpoint = torch.load(checkpoint_path, 'cpu')
    ema_model_state = checkpoint['ema']['state']['model']
    exported = {
        'best_state': ema_model_state,
        'xp.cfg': OmegaConf.to_yaml(checkpoint['xp.cfg']),
        # The following params were NOT exported for the first release of MusicGen.
        'version': __version__,
        'exported': True,
    }
    target_dir = Path(out_file).parent
    target_dir.mkdir(parents=True, exist_ok=True)
    torch.save(exported, out_file)
    return out_file
53
+
54
+
55
def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
    """Export the best state of a legacy LM checkpoint, with its config cleaned up.

    Args:
        checkpoint_path (Path or str): Checkpoint to read, loaded on CPU.
        out_file (Path or str): Destination of the exported package. Missing parent
            folders are created.
    Returns:
        Path or str: `out_file`, as given.
    """
    checkpoint = torch.load(checkpoint_path, 'cpu')
    # A non-empty FSDP best state takes precedence over the regular best state.
    state_key = 'fsdp_best_state' if checkpoint['fsdp_best_state'] else 'best_state'
    model_state = checkpoint[state_key]['model']
    exported = {
        'best_state': model_state,
        'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(checkpoint['xp.cfg'])),
        # The following params were NOT exported for the first release of MusicGen.
        'version': __version__,
        'exported': True,
    }
    target_dir = Path(out_file).parent
    target_dir.mkdir(parents=True, exist_ok=True)
    torch.save(exported, out_file)
    return out_file
audiocraft/utils/notebook.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ try:
8
+ import IPython.display as ipd # type: ignore
9
+ except ImportError:
10
+ # Not in a notebook...
11
+ pass
12
+
13
+
14
+ import torch
15
+
16
+
17
def display_audio(samples: torch.Tensor, sample_rate: int):
    """Renders an audio player for the given audio samples.

    Args:
        samples (torch.Tensor): a Tensor of decoded audio samples
            with shapes [B, C, T] or [C, T]
        sample_rate (int): sample rate audio should be displayed with.
    """
    n_dims = samples.dim()
    assert n_dims in (2, 3)

    batch = samples.detach().cpu()
    if n_dims == 2:
        # Add a leading batch dimension so a single item is handled like a batch.
        batch = batch.unsqueeze(0)

    for clip in batch:
        player = ipd.Audio(clip, rate=sample_rate)
        ipd.display(player)
audiocraft/utils/profiler.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ import typing as tp
9
+
10
+ import dora
11
+ import torch
12
+
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
class Profiler:
    """Context manager wrapper for xformers profiler.

    When `enabled` is False, no profiler is created and every method is a no-op
    returning None.

    Args:
        module (torch.nn.Module): Module passed to the xformers profiler.
        enabled (bool): Whether to actually create the profiler.
    """
    def __init__(self, module: torch.nn.Module, enabled: bool = False):
        self.profiler: tp.Optional[tp.Any] = None
        if not enabled:
            return
        # Imported lazily: xformers is only needed when profiling is requested.
        from xformers.profiler import profile
        output_dir = dora.get_xp().folder / 'profiler_data'
        logger.info("Profiling activated, results with be saved to %s", output_dir)
        self.profiler = profile(output_dir=output_dir, module=module)

    def step(self):
        """Forward a step to the wrapped profiler, if there is one."""
        if self.profiler is None:
            return
        self.profiler.step()  # type: ignore

    def __enter__(self):
        """Enter the wrapped profiler and return what it returns, or None when disabled."""
        if self.profiler is None:
            return None
        return self.profiler.__enter__()  # type: ignore

    def __exit__(self, exc_type, exc_value, exc_tb):
        """Exit the wrapped profiler and return what it returns, or None when disabled."""
        if self.profiler is None:
            return None
        return self.profiler.__exit__(exc_type, exc_value, exc_tb)  # type: ignore
audiocraft/utils/utils.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from concurrent.futures import ProcessPoolExecutor
8
+ from contextlib import contextmanager
9
+ from functools import wraps, lru_cache
10
+ import hashlib
11
+ import json
12
+ import logging
13
+ from pathlib import Path
14
+ import typing as tp
15
+
16
+ import flashy
17
+ import flashy.distrib
18
+ import omegaconf
19
+ import torch
20
+ from torch.nn.utils.rnn import pad_sequence
21
+
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
def model_hash(model: torch.nn.Module) -> str:
    """Return a model hash. This should allow us to track regressions in model init
    from the logs of past experiments.

    The hash is the SHA-1 hex digest of the raw bytes of every parameter, taken in
    the order given by `model.parameters()`.
    """
    digest = hashlib.sha1()
    for param in model.parameters():
        raw_bytes = param.data.cpu().numpy().tobytes()
        digest.update(raw_bytes)
    return digest.hexdigest()
34
+
35
+
36
def dict_from_config(cfg: omegaconf.DictConfig) -> dict:
    """Convenience function to map an omegaconf configuration to a dictionary.

    Args:
        cfg (omegaconf.DictConfig): Original configuration to map to dict.
    Returns:
        dict: Config as dictionary object, with interpolations resolved.
    """
    container = omegaconf.OmegaConf.to_container(cfg, resolve=True)
    # `to_container` may return other container types; only a dict is expected here.
    assert isinstance(container, dict)
    return container
47
+
48
+
49
def random_subset(dataset, max_samples: int, seed: int = 42) -> torch.utils.data.Subset:
    """Return a random subset of `dataset` with at most `max_samples` items.

    Args:
        dataset: Dataset to sample from, must support `len()`.
        max_samples (int): Maximum number of items to keep.
        seed (int): Seed of the generator used to pick the items.
    Returns:
        The dataset itself when it has no more than `max_samples` items, otherwise
        a `torch.utils.data.Subset` over `max_samples` randomly chosen indices.
    """
    total = len(dataset)
    if total <= max_samples:
        return dataset

    rng = torch.Generator()
    rng.manual_seed(seed)
    shuffled = torch.randperm(total, generator=rng)
    kept_indices = shuffled[:max_samples].tolist()
    return torch.utils.data.Subset(dataset, kept_indices)
56
+
57
+
58
def get_loader(dataset, num_samples: tp.Optional[int], batch_size: int,
               num_workers: int, seed: int, **kwargs) -> torch.utils.data.DataLoader:
    """Convenience function to load dataset into a dataloader with optional subset sampling.

    Args:
        dataset: Dataset to load.
        num_samples (Optional[int]): Number of samples to limit subset size.
            When None, the full dataset is used.
        batch_size (int): Batch size.
        num_workers (int): Number of workers for data loading.
        seed (int): Random seed, used for the subset sampling.
        kwargs: Extra keyword arguments forwarded to `flashy.distrib.loader`.
    Returns:
        The loader built by `flashy.distrib.loader`.
    """
    source = dataset if num_samples is None else random_subset(dataset, num_samples, seed)
    return flashy.distrib.loader(
        source,
        batch_size=batch_size,
        num_workers=num_workers,
        **kwargs
    )
79
+
80
+
81
def get_dataset_from_loader(dataloader):
    """Return the dataset behind a dataloader.

    If the loader's dataset is a `torch.utils.data.Subset`, the dataset wrapped by
    that subset is returned instead (only one level is unwrapped).
    """
    attached = dataloader.dataset
    return attached.dataset if isinstance(attached, torch.utils.data.Subset) else attached
87
+
88
+
89
def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
    """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.

    Args:
        input (torch.Tensor): The input tensor containing probabilities.
        num_samples (int): Number of samples to draw.
        replacement (bool): Whether to draw with replacement or not.
    Keywords args:
        generator (torch.Generator): A pseudorandom number generator for sampling.
    Returns:
        torch.Tensor: Last dimension contains num_samples indices
            sampled from the multinomial probability distribution
            located in the last dimension of tensor input.
    """
    leading_shape = input.shape[:-1]
    # Collapse all leading dimensions into one before calling torch.multinomial.
    flat_probs = input.reshape(-1, input.shape[-1])
    flat_samples = torch.multinomial(
        flat_probs, num_samples=num_samples, replacement=replacement, generator=generator)
    return flat_samples.reshape(*leading_shape, -1)
107
+
108
+
109
def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
    """Sample next token from top K values along the last dimension of the input probs tensor.

    Note that `probs` is modified in place: entries below the k-th largest value
    are zeroed and the rest is renormalized.

    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        k (int): The k in “top-k”.
    Returns:
        torch.Tensor: Sampled tokens.
    """
    top_values, _ = probs.topk(k, dim=-1)
    # Smallest of the k largest values, kept with a trailing dim for broadcasting.
    threshold = top_values[..., -1:]
    keep = (probs >= threshold).float()
    probs.mul_(keep)
    probs.div_(probs.sum(dim=-1, keepdim=True))
    return multinomial(probs, num_samples=1)
124
+
125
+
126
def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    """Sample next token from top P probabilities along the last dimension of the input probs tensor.

    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        p (float): The p in “top-p”.
    Returns:
        torch.Tensor: Sampled tokens.
    """
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # A candidate is dropped when the mass of the candidates before it already exceeds p.
    dropped = cumulative - sorted_probs > p
    sorted_probs.mul_((~dropped).float())
    sorted_probs.div_(sorted_probs.sum(dim=-1, keepdim=True))
    picked = multinomial(sorted_probs, num_samples=1)
    # Map positions in the sorted order back to the original token indices.
    return torch.gather(sorted_idx, -1, picked)
143
+
144
+
145
class DummyPoolExecutor:
    """Dummy pool executor to use when we actually have only 1 worker.
    (e.g. instead of ProcessPoolExecutor).

    Submitted work is not run at submission time: it runs in the calling process
    each time `result()` is called on the returned object.
    """
    class DummyResult:
        """Holds a call and runs it when `result()` is invoked."""
        def __init__(self, func, *args, **kwargs):
            self.func, self.args, self.kwargs = func, args, kwargs

        def result(self):
            call = self.func
            return call(*self.args, **self.kwargs)

    def __init__(self, workers, mp_context=None):
        # Arguments are accepted for signature compatibility only and are unused.
        pass

    def submit(self, func, *args, **kwargs):
        pending = DummyPoolExecutor.DummyResult(func, *args, **kwargs)
        return pending

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Nothing to clean up; returning None lets exceptions propagate.
        return None
169
+
170
+
171
def get_pool_executor(num_workers: int, mp_context=None):
    """Return a ProcessPoolExecutor when `num_workers` is above 1, else a DummyPoolExecutor."""
    if num_workers > 1:
        return ProcessPoolExecutor(num_workers, mp_context)
    return DummyPoolExecutor(1)
173
+
174
+
175
def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> torch.Tensor:
    """Utility function to convert a tensor of sequence lengths to a mask (useful when working on padded sequences).
    For example: [3, 5] => [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]

    Args:
        lengths (torch.Tensor): tensor with lengths
        max_len (int): can set the max length manually. Defaults to None,
            in which case the largest value in `lengths` is used.
    Returns:
        torch.Tensor: boolean mask, False where there is pad tokens else True
    """
    assert lengths.dim() == 1, "Length shape should be 1 dimensional."
    if max_len:
        final_length = max_len
    else:
        final_length = lengths.max().item()
    # if all seqs are of len zero we don't want a zero-size tensor
    final_length = max(final_length, 1)
    positions = torch.arange(final_length, device=lengths.device)
    return positions.unsqueeze(0) < lengths.unsqueeze(1)
189
+
190
+
191
def hash_trick(word: str, vocab_size: int) -> int:
    """Hash trick to pair each word with an index

    Args:
        word (str): word we wish to convert to an index
        vocab_size (int): size of the vocabulary
    Returns:
        int: index of the word in the embedding LUT
    """
    digest = hashlib.sha256(word.encode("utf-8")).digest()
    # The digest is read as one big-endian integer, then folded into the vocabulary range.
    return int.from_bytes(digest, "big") % vocab_size
202
+
203
+
204
def with_rank_rng(base_seed: int = 1234):
    """Decorator for a function so that the function will use a Random Number Generator
    whose state depend on the GPU rank. The state captured with `torch.get_rng_state()`
    before the call is restored upon returning, including when the function raises.

    Args:
        base_seed (int): Random seed, combined (XOR) with the rank.
    """
    def _decorator(fun: tp.Callable):
        @wraps(fun)
        def _decorated(*args, **kwargs):
            saved_state = torch.get_rng_state()
            rank_seed = base_seed ^ flashy.distrib.rank()
            torch.manual_seed(rank_seed)
            logger.debug('Rank dependent seed set to %d', rank_seed)
            try:
                outcome = fun(*args, **kwargs)
            finally:
                torch.set_rng_state(saved_state)
                logger.debug('RNG state restored.')
            return outcome
        return _decorated
    return _decorator
225
+
226
+
227
def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[torch.Tensor, torch.Tensor]:
    """Get a list of tensors and collate them to a single tensor, according to the following logic:
    - `dim` specifies the time dimension which will be stacked and padded.
    - The output will contain 1 new dimension (dimension index 0) which will be the size
    of the original list.

    Args:
        tensors (tp.List[torch.Tensor]): List of tensors to collate.
        dim (int): Dimension which will be stacked and padded.
    Returns:
        tp.Tuple[torch.Tensor, torch.Tensor]:
            torch.Tensor: Stacked and padded tensor. The output will contain 1 new dimension
                (dimension index 0) which will be the size of the original list.
            torch.Tensor: Tensor containing length of original tensor sizes (without padding).
    """
    # Bring the time dimension first, as expected by `pad_sequence`.
    time_first = [item.transpose(0, dim) for item in tensors]
    lens = torch.tensor([len(item) for item in time_first], dtype=torch.long)
    batch = pad_sequence(time_first, batch_first=True)
    # Move the (now padded) time dimension back to its place, after the new batch dim.
    return batch.transpose(1, dim + 1), lens
248
+
249
+
250
+ # TODO: Move to flashy?
251
def copy_state(state: tp.Any, device: tp.Union[torch.device, str] = 'cpu',
               dtype: tp.Optional[torch.dtype] = None) -> tp.Any:
    """Recursively copy a state (e.g. a state dict) to the given device and dtype.

    Args:
        state (tp.Any): A tensor, or a dict / list nesting tensors and other values.
        device (torch.device or str): Device on which the copied tensors are placed.
        dtype (torch.dtype, optional): If provided, floating point tensors are converted
            to this dtype. Other tensors always keep their dtype.
    Returns:
        tp.Any: Same structure as `state`. Tensors are detached copies, dicts are
            returned as plain `dict` and lists as new lists. Any other value
            (number, string, None, ...) is returned as is, without being copied.
    """
    if isinstance(state, torch.Tensor):
        if dtype is None or not state.is_floating_point():
            # Never cast integer or boolean tensors to the requested floating dtype.
            dtype = state.dtype
        return state.detach().to(device=device, dtype=dtype, copy=True)
    elif isinstance(state, dict):
        return {k: copy_state(v, device, dtype) for k, v in state.items()}
    elif isinstance(state, list):
        return [copy_state(v, device, dtype) for v in state]
    # Non-tensor leaves used to fall through and become None, silently dropping
    # values such as step counters or hyper-parameters stored in the state.
    return state
261
+
262
+
263
+ # TODO: Move to flashy?
264
@contextmanager
def swap_state(model, state, **kwargs):
    """Context manager that temporarily loads `state` into `model`.

    A copy of the current state dict is taken first, and loaded back on exit,
    including when the body of the `with` block raises.

    Args:
        model: Object exposing `state_dict()` and `load_state_dict()`.
        state: State to load for the duration of the context.
        kwargs: Extra keyword arguments for the initial `load_state_dict(state, ...)` call.
    """
    backup = copy_state(model.state_dict())
    model.load_state_dict(state, **kwargs)
    try:
        yield
    finally:
        # The backup is restored without the extra kwargs.
        model.load_state_dict(backup)
272
+
273
+
274
@lru_cache(maxsize=None)
def warn_once(logger, msg):
    """Warn about a given message only once.

    Calls are cached on the `(logger, msg)` pair, so both must be hashable and a
    repeated pair does not reach `logger.warning` again.
    """
    logger.warning(msg)
278
+
279
+
280
def is_jsonable(x: tp.Any):
    """Check if an object can be serialized into a json.

    Returns:
        bool: False when `json.dumps` raises TypeError or OverflowError, True otherwise.
    """
    try:
        json.dumps(x)
    except (TypeError, OverflowError):
        return False
    return True
287
+
288
+
289
def load_clap_state_dict(clap_model, path: tp.Union[str, Path]):
    """Wrapper around state dict loading of CLAP model
    addressing compatibility issues between CLAP and AudioCraft
    HuggingFace transformer version.
    See: https://github.com/LAION-AI/CLAP/issues/118

    Args:
        clap_model: CLAP wrapper whose `model` attribute receives the state dict.
        path (str or Path): Location of the CLAP state dict.
    """
    from clap_module.factory import load_state_dict  # type: ignore
    state = load_state_dict(path)
    # Drop the `position_ids` entry if present (see the issue linked above).
    state.pop('text_branch.embeddings.position_ids', None)
    clap_model.model.load_state_dict(state)