unpairedelectron07
committed on
Upload 11 files
Browse files- audiocraft/utils/autocast.py +40 -0
- audiocraft/utils/best_state.py +81 -0
- audiocraft/utils/cache.py +324 -0
- audiocraft/utils/checkpoint.py +161 -0
- audiocraft/utils/cluster.py +75 -0
- audiocraft/utils/deadlock.py +58 -0
- audiocraft/utils/export.py +79 -0
- audiocraft/utils/export_legacy.py +70 -0
- audiocraft/utils/notebook.py +32 -0
- audiocraft/utils/profiler.py +38 -0
- audiocraft/utils/utils.py +298 -0
audiocraft/utils/autocast.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import torch
|
8 |
+
|
9 |
+
|
10 |
+
class TorchAutocast:
    """TorchAutocast utility class.

    Thin context manager around `torch.autocast` that can be switched off.
    This is especially useful when dealing with different architectures and
    clusters with different levels of support.

    Args:
        enabled (bool): Whether to enable torch.autocast or not.
        args: Additional args for torch.autocast.
        kwargs: Additional kwargs for torch.autocast.
    """
    def __init__(self, enabled: bool, *args, **kwargs):
        # When disabled, no autocast object is built and enter/exit are no-ops.
        self.autocast = torch.autocast(*args, **kwargs) if enabled else None

    def __enter__(self):
        """Enter the wrapped autocast context, if any.

        Raises:
            RuntimeError: If the wrapped autocast fails to enter. The original
                error is attached as the explicit cause.
        """
        if self.autocast is None:
            return
        try:
            self.autocast.__enter__()
        except RuntimeError as exc:
            device = self.autocast.device
            dtype = self.autocast.fast_dtype
            # Chain explicitly so the underlying torch error stays visible.
            raise RuntimeError(
                f"There was an error autocasting with dtype={dtype} device={device}\n"
                "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16"
            ) from exc

    def __exit__(self, *args, **kwargs):
        """Exit the wrapped autocast context, if any."""
        if self.autocast is None:
            return
        self.autocast.__exit__(*args, **kwargs)
|
audiocraft/utils/best_state.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from collections import defaultdict
|
8 |
+
import logging
|
9 |
+
import typing as tp
|
10 |
+
|
11 |
+
import flashy
|
12 |
+
import torch
|
13 |
+
|
14 |
+
from ..optim import ModuleDictEMA
|
15 |
+
from .utils import copy_state
|
16 |
+
|
17 |
+
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
+
|
20 |
+
|
21 |
+
class BestStateDictManager(flashy.state.StateDictSource):
    """BestStateDictManager maintains a copy of best state_dict() for registered sources.

    BestStateDictManager has two main attributes:
        states (dict): State dict of the registered StateDictSource.
        param_ids (dict): Dict of parameter ids for registered states from ModuleDictEMA and other sources.

    When registering new sources, the BestStateDictManager will ensure two conflicting sources between
    ModuleDictEMA and original modules are not both registered as it would otherwise create ambiguity about
    what to consider for best state.

    Args:
        device (torch.device or str): Device on which we keep the copy.
        dtype (torch.dtype): Data type for the state parameters.
    """
    def __init__(self, device: tp.Union[torch.device, str] = 'cpu',
                 dtype: tp.Optional[torch.dtype] = None):
        self.device = device
        self.states: dict = {}
        # Maps a group name ('base' or a ModuleDictEMA name) to {id(tensor): key}.
        self.param_ids: dict = defaultdict(dict)
        self.dtype = dtype

    def _get_parameter_ids(self, state_dict):
        """Return `{id(tensor): key}` for every tensor entry of `state_dict`."""
        return {id(p): name for name, p in state_dict.items() if isinstance(p, torch.Tensor)}

    def _validate_no_parameter_ids_overlap(self, name: str, param_ids: dict):
        """Assert that `param_ids` shares no tensor with any other registered group.

        Args:
            name (str): Group the ids belong to; that group itself is not checked.
            param_ids (dict): Mapping `{id(tensor): key}` about to be registered.
        Raises:
            AssertionError: If a tensor id is already registered under another group.
        """
        for registered_name, registered_param_ids in self.param_ids.items():
            if registered_name == name:
                continue
            # dict key views support set operators directly; the unbound
            # `set.intersection` rejects them with a TypeError.
            overlap = registered_param_ids.keys() & param_ids.keys()
            assert len(overlap) == 0, (
                f"Found {len(overlap)} / {len(param_ids.keys())} overlapping parameters"
                f" in {name} and already registered {registered_name}: "
                f"{' '.join(str(param_ids[pid]) for pid in sorted(overlap))}"
            )

    def update(self, name: str, source: flashy.state.StateDictSource):
        """Replace the stored copy for an already registered `name`.

        Raises:
            ValueError: If `name` was never registered.
        """
        if name not in self.states:
            raise ValueError(f"{name} missing from registered states.")
        self.states[name] = copy_state(source.state_dict(), device=self.device, dtype=self.dtype)

    def register(self, name: str, source: flashy.state.StateDictSource):
        """Register a new source under `name` and store a copy of its state.

        Raises:
            ValueError: If `name` is already registered.
            AssertionError: If the source shares tensors with another registered group.
        """
        if name in self.states:
            raise ValueError(f"{name} already present in states.")
        # Registering parameter ids for EMA and non-EMA states allows us to check that
        # there is no overlap that would create ambiguity about how to handle the best state
        param_ids = self._get_parameter_ids(source.state_dict())
        if isinstance(source, ModuleDictEMA):
            logger.debug(f"Registering to best state: ModuleDictEMA '{name}' with {len(param_ids)} params")
            self._validate_no_parameter_ids_overlap(name, param_ids)
            self.param_ids[name] = param_ids
        else:
            # All non-EMA sources are pooled under the single 'base' group.
            logger.debug(f"Registering to best state: StateDictSource '{name}' with {len(param_ids)} params")
            self._validate_no_parameter_ids_overlap('base', param_ids)
            self.param_ids['base'].update(param_ids)
        # Register state
        self.states[name] = copy_state(source.state_dict(), device=self.device, dtype=self.dtype)

    def state_dict(self) -> flashy.state.StateDict:
        """Return the stored states (the internal dict itself, not a copy)."""
        return self.states

    def load_state_dict(self, state: flashy.state.StateDict):
        """Copy the values of `state` in place into the stored entries."""
        for name, sub_state in state.items():
            for k, v in sub_state.items():
                self.states[name][k].copy_(v)
|
audiocraft/utils/cache.py
ADDED
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from concurrent.futures import ThreadPoolExecutor
|
8 |
+
from collections import deque
|
9 |
+
from functools import partial
|
10 |
+
from hashlib import sha1
|
11 |
+
import logging
|
12 |
+
from pathlib import Path
|
13 |
+
import sys
|
14 |
+
import typing as tp
|
15 |
+
import zipfile
|
16 |
+
|
17 |
+
import flashy
|
18 |
+
import torch
|
19 |
+
|
20 |
+
|
21 |
+
logger = logging.getLogger(__name__)
|
22 |
+
|
23 |
+
|
24 |
+
def get_full_embed(full_embed: torch.Tensor, x: tp.Any, idx: int, device: tp.Union[str, torch.device]) -> torch.Tensor:
    """Default extraction function for the EmbeddingCache: no chunking at all.

    Use it when the whole embedding read from the cache is wanted as-is;
    the only thing done is moving it to the requested device.

    Args:
        full_embed (torch.Tensor): The full embedding.
        x (any): Batch object from which the full embedding is derived (unused here).
        idx (int): Index of object to consider in the batch object (unused here).
        device (str or torch.device): Device the embedding is moved to.
    Returns:
        full_embed (torch.Tensor): The full embedding, on `device`.
    """
    on_device = full_embed.to(device)
    return on_device
|
37 |
+
|
38 |
+
|
39 |
+
class EmbeddingCache:
    """Cache around embeddings computation for faster execution.
    The EmbeddingCache is storing pre-computed embeddings on disk and provides a simple API
    to retrieve the pre-computed embeddings on full inputs and extract only a given chunk
    using a user-provided function. When the cache is warm (all embeddings are pre-computed),
    the EmbeddingCache allows for faster training as it removes the need of computing the embeddings.
    Additionally, it provides in-memory cache around the loaded embeddings to limit IO footprint
    and synchronization points in the forward calls.

    Args:
        cache_path (Path or str): Path to folder where all pre-computed embeddings are saved on disk.
        device (str or torch.device): Device on which the embedding is returned.
        compute_embed_fn (callable[[Path, any, int], torch.Tensor]): Function to compute
            the embedding from a given object and path. This user provided function can compute the
            embedding from the provided object or using the provided path as entry point. The last parameter
            specifies the index corresponding to the current embedding in the object that can represent
            batch metadata.
        extract_embed_fn (callable[[torch.Tensor, any, int], torch.Tensor], optional): Function to extract
            the desired embedding chunk from the full embedding loaded from the cache. The last parameter
            specifies the index corresponding to the current embedding in the object that can represent
            batch metadata. If not specified, will return the full embedding unmodified.
    """
    def __init__(self, cache_path: tp.Union[str, Path], device: tp.Union[str, torch.device],
                 compute_embed_fn: tp.Callable[[Path, tp.Any, int], torch.Tensor],
                 extract_embed_fn: tp.Optional[tp.Callable[[torch.Tensor, tp.Any, int], torch.Tensor]] = None):
        self.cache_path = Path(cache_path)
        self.device = device
        self._compute_embed_fn = compute_embed_fn
        self._extract_embed_fn: tp.Callable[[torch.Tensor, tp.Any, int], torch.Tensor]
        if extract_embed_fn is not None:
            self._extract_embed_fn = extract_embed_fn
        else:
            self._extract_embed_fn = partial(get_full_embed, device=device)
        # `Path(cache_path)` can never be None, so the folder is created unconditionally.
        self.cache_path.mkdir(exist_ok=True, parents=True)
        logger.info(f"Cache instantiated at: {self.cache_path}")
        # Pool used to read cached embeddings from disk concurrently; it is
        # kept for the lifetime of the instance.
        self.pool = ThreadPoolExecutor(8)
        # Extracted chunks for the batch last passed to `populate_embed_cache`.
        self._current_batch_cache: dict = {}
        # Full embeddings as loaded from disk, keyed by cache file path.
        self._memory_cache: dict = {}

    def _get_cache_path(self, path: tp.Union[Path, str]):
        """Get cache path for the given file path (SHA-1 of its string form)."""
        sig = sha1(str(path).encode()).hexdigest()
        return self.cache_path / sig

    @staticmethod
    def _get_full_embed_from_cache(cache: Path):
        """Loads full pre-computed embedding from the cache.

        Returns None (after logging the error) if the file cannot be loaded.
        """
        try:
            embed = torch.load(cache, 'cpu')
        except Exception as exc:
            # Best effort: an unreadable entry is treated as a cache miss.
            logger.error("Error loading %s: %r", cache, exc)
            embed = None
        return embed

    def get_embed_from_cache(self, paths: tp.List[tp.Union[Path, str]], x: tp.Any) -> torch.Tensor:
        """Get embedding from cache, computing and storing it to cache if not already cached.
        The EmbeddingCache first tries to load the embedding from the in-memory cache
        containing the pre-computed chunks populated through `populate_embed_cache`.
        If not found, the full embedding is computed and stored on disk to be later accessed
        to populate the in-memory cache, and the desired embedding chunk is extracted and returned.

        Args:
            paths (list[Path or str]): List of paths from where the embeddings can be loaded.
            x (any): Object from which the embedding is extracted.
        Returns:
            torch.Tensor: The per-path embeddings stacked along a new first dimension.
        """
        embeds = []
        for idx, path in enumerate(paths):
            cache = self._get_cache_path(path)
            if cache in self._current_batch_cache:
                embed = self._current_batch_cache[cache]
            else:
                full_embed = self._compute_embed_fn(path, x, idx)
                try:
                    with flashy.utils.write_and_rename(cache, pid=True) as f:
                        torch.save(full_embed.cpu(), f)
                except Exception as exc:
                    # A failed write must not break the forward pass.
                    logger.error('Error saving embed %s (%s): %r', cache, full_embed.shape, exc)
                else:
                    logger.info('New embed cache saved: %s (%s)', cache, full_embed.shape)
                embed = self._extract_embed_fn(full_embed, x, idx)
            embeds.append(embed)
        embed = torch.stack(embeds, dim=0)
        return embed

    def populate_embed_cache(self, paths: tp.List[tp.Union[Path, str]], x: tp.Any) -> None:
        """Populate in-memory caches for embeddings reading from the embeddings stored on disk.
        The in-memory caches consist in a cache for the full embedding and another cache for the
        final embedding chunk. Such caches are used to limit the IO access when computing the actual embeddings
        and reduce the IO footprint and synchronization points during forward passes.

        Args:
            paths (list[Path or str]): List of paths from where the embeddings can be loaded.
            x (any): Object from which the embedding is extracted.
        """
        self._current_batch_cache.clear()
        if self.cache_path is None:
            return
        # First pass: hash each path once and schedule the disk reads that are needed.
        caches: list = []
        futures: list = []
        for path in paths:
            assert path is not None, "Path is required for computation from cache"
            cache = self._get_cache_path(path)
            caches.append(cache)
            if cache in self._memory_cache or not cache.exists():
                futures.append(None)
            else:
                futures.append(self.pool.submit(EmbeddingCache._get_full_embed_from_cache, cache))
        # Second pass: collect results and extract the chunk for each batch item.
        for idx, (cache, future) in enumerate(zip(caches, futures)):
            if future is None:
                full_embed = self._memory_cache.get(cache)
            else:
                full_embed = future.result()
                if full_embed is not None:
                    self._memory_cache[cache] = full_embed
            if full_embed is None:
                continue
            # Move to the target device whether the embedding came from disk or
            # from the in-memory cache, so the extract function always sees the
            # same device.
            full_embed = full_embed.to(self.device)
            self._current_batch_cache[cache] = self._extract_embed_fn(full_embed, x, idx)
|
159 |
+
|
160 |
+
|
161 |
+
class CachedBatchWriter:
    """Write pre computed caches for mini batches. This can
    make loading a lot more efficient depending on your filesystem.

    Args:
        cache_folder (Path or str): folder in which the cached minibatches
            will be stored.

    Inside cache folder, the structure is the following:
    `epoch_number / update_number.zip`
    And the zip file contains one entry per batch item.

    It is possible to use the cache with a batch size smaller than
    created with but obviously not larger. Make sure to call the
    `start_epoch(epoch)` method for indicating changes of epochs.

    See the grid `audiocraft/grids/musicgen/musicgen_warmup_cache.py`
    for an example of how to warmup the cache.
    """
    def __init__(self, cache_folder: tp.Union[str, Path]):
        # Coerce so that a plain string folder works with the `/` joins below.
        self.cache_folder = Path(cache_folder)
        self._current_epoch: tp.Optional[int] = None
        self._current_index = 0

    def start_epoch(self, epoch: int):
        """Call at the beginning of each epoch.

        Resets the batch index and creates the folder for this epoch.
        """
        self._current_epoch = epoch
        self._current_index = 0
        self._zip_path.parent.mkdir(exist_ok=True, parents=True)

    @staticmethod
    def _get_zip_path(cache_folder: Path, epoch: int, index: int):
        """Return `<cache_folder>/<epoch, 5 digits>/<index, 6 digits>.zip`."""
        return cache_folder / f"{epoch:05d}" / f"{index:06d}.zip"

    @property
    def _zip_path(self):
        """Zip path for the current epoch and index; requires `start_epoch` first."""
        assert self._current_epoch is not None
        return CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch, self._current_index)

    def save(self, *content):
        """Save one mini batch. This function is distributed-aware
        and will automatically merge all the items from the different
        workers.

        Args:
            content: Parallel sequences with one element per batch item; they
                are zipped together so each zip entry holds one item's tuple.
        """
        all_contents = []
        # Every rank takes a turn as the broadcast source, so all ranks end up
        # with the content of all the others.
        for rank in range(flashy.distrib.world_size()):
            their_content = flashy.distrib.broadcast_object(content, src=rank)
            all_contents.append(their_content)

        if flashy.distrib.is_rank_zero():
            idx = 0
            with flashy.utils.write_and_rename(self._zip_path) as tmp:
                with zipfile.ZipFile(tmp, 'w') as zf:
                    for rank_content in all_contents:
                        for vals in zip(*rank_content):
                            # Entries are named by their running index across ranks.
                            with zf.open(f'{idx}', 'w') as f:  # type: ignore
                                torch.save(vals, f)
                            idx += 1
        flashy.distrib.barrier()
        self._current_index += 1
|
222 |
+
|
223 |
+
|
224 |
+
class CachedBatchLoader:
    """Loader for cached mini-batches dumped with `CachedBatchWriter`.

    Args:
        cache_folder (Path or str): folder in which the cached minibatches are stored.
        batch_size (int): batch size (per GPU) expected.
        num_workers (int): number of workers to use for loading.
        min_length (int): minimum expected length for each epoch. If some
            mini-batches are missing, an error is raised.

    This is iterable just like a regular DataLoader.
    """

    def __init__(self, cache_folder: tp.Union[str, Path], batch_size: int,
                 num_workers: int = 10, min_length: int = 1):
        # Coerce so that a plain string folder works with the path joins.
        self.cache_folder = Path(cache_folder)
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.min_length = min_length
        self._current_epoch: tp.Optional[int] = None
        self.sampler = None  # for compatibility with the regular DataLoader

    def __len__(self):
        """Number of `.zip` files in the folder of the current epoch (epoch 0 if unset)."""
        path = CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch or 0, 0).parent
        return len([p for p in path.iterdir() if p.suffix == ".zip"])

    def start_epoch(self, epoch: int):
        """Call at the beginning of each epoch.
        """
        self._current_epoch = epoch

    def _zip_path(self, index: int):
        """Zip path of batch `index` in the current epoch; requires `start_epoch` first."""
        assert self._current_epoch is not None
        return CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch, index)

    def _load_one(self, index: int):
        """Load the slice of batch `index` that belongs to the current rank.

        Returns:
            A list with one element per field saved by the writer: tensors are
            stacked along a new first dimension, other values are returned as a
            tuple with one element per batch item. None if the zip file does not
            exist and `index >= min_length`.
        Raises:
            RuntimeError: If the zip is missing for `index < min_length`, or it
                holds fewer items than `batch_size * world_size`.
        """
        zip_path = self._zip_path(index)
        if not zip_path.exists():
            if index < self.min_length:
                raise RuntimeError(f"Cache should have at least {self.min_length} batches, but {index} doesn't exist")

            return None
        # NOTE(review): presumably `zipfile.Path.open` only accepts a binary
        # mode from Python 3.9 on — confirm against the zipfile docs.
        mode = "rb" if sys.version_info >= (3, 9) else "r"
        try:
            with zipfile.ZipFile(zip_path, 'r') as zf:
                rank = flashy.distrib.rank()
                world_size = flashy.distrib.world_size()
                root = zipfile.Path(zf)
                items = list(root.iterdir())
                total_batch_size = self.batch_size * world_size
                if len(items) < total_batch_size:
                    raise RuntimeError(
                        f"The cache can handle a max batch size of {len(items)}, "
                        f"but {total_batch_size} is needed.")
                # Each rank reads its own contiguous slice of the stored items.
                start = rank * self.batch_size
                items = items[start: start + self.batch_size]
                assert len(items) == self.batch_size
                entries = []
                for item in items:
                    # Close each member handle as soon as it is deserialized.
                    with item.open(mode) as f:  # type: ignore
                        entries.append(torch.load(f, 'cpu'))
                out = []
                # Transpose from per-item tuples to per-field tuples.
                for part in zip(*entries):
                    assert len(part) > 0
                    if isinstance(part[0], torch.Tensor):
                        out.append(torch.stack(part))
                    else:
                        # `part` is a tuple of non-tensor values: pass it through.
                        out.append(part)
                return out
        except Exception:
            logger.error("Error when reading zip path %s", zip_path)
            raise

    def __iter__(self):
        """This will yield batches built from the tuples provided to the
        `CachedBatchWriter.save` method, until a batch file is missing.
        """
        pool = ThreadPoolExecutor(self.num_workers)
        next_index = 0
        queue: deque = deque()

        def _get_next():
            nonlocal next_index
            r = queue.popleft().result()
            if r is None:
                return None
            else:
                # Keep the prefetch buffer full: one new job per consumed batch.
                queue.append(pool.submit(self._load_one, next_index))
                next_index += 1
            return r

        with pool:
            # fill the buffer of fetching jobs.
            for _ in range(2 * self.num_workers):
                queue.append(pool.submit(self._load_one, next_index))
                next_index += 1
            while True:
                batch = _get_next()
                if batch is None:
                    return
                yield batch
|
audiocraft/utils/checkpoint.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from enum import Enum
|
8 |
+
import logging
|
9 |
+
from pathlib import Path
|
10 |
+
import re
|
11 |
+
import typing as tp
|
12 |
+
|
13 |
+
import flashy
|
14 |
+
import torch
|
15 |
+
|
16 |
+
from ..environment import AudioCraftEnvironment
|
17 |
+
|
18 |
+
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
|
22 |
+
class CheckpointSource(Enum):
    """Closed set of labels describing where a checkpoint comes from."""
    CURRENT_XP = "current_xp"
    PRETRAINED = "pretrained"
    OTHER = "other"
|
26 |
+
|
27 |
+
|
28 |
+
def checkpoint_name(name: tp.Optional[str] = None, rank: tp.Optional[int] = None, use_fsdp: bool = False) -> str:
    """Build the checkpoint file name used across the AudioCraft codebase.

    The format is `checkpoint_<name>.th(.<rank>)`. By convention, name is expected
    to be empty for last checkpoint, 'best' for the best checkpoint or the epoch number.
    The rank suffix is only appended for non-zero ranks when FSDP is in use.

    Args:
        name (str, optional): Name suffix for the checkpoint file stem.
        rank (optional, int): Rank for distributed processing, retrieved with flashy if not provided.
        use_fsdp (bool): Whether the calling solver relies on FSDP.
    Returns:
        str: The checkpoint name.
    """
    if rank is None:
        rank = flashy.distrib.rank()
    rank_suffix = '.' + str(rank) if (rank > 0 and use_fsdp) else ''
    stem = 'checkpoint' if name is None else f'checkpoint_{name}'
    return f'{stem}.th{rank_suffix}'
|
49 |
+
|
50 |
+
|
51 |
+
def is_sharded_checkpoint(path: Path) -> bool:
    """Tell whether `path` names a per-rank shard, i.e. its file name ends in `.th.<digits>`."""
    shard_suffix = re.search(r'\.th\.\d+$', path.name)
    return shard_suffix is not None
|
54 |
+
|
55 |
+
|
56 |
+
def resolve_checkpoint_path(sig_or_path: tp.Union[Path, str], name: tp.Optional[str] = None,
                            use_fsdp: bool = False) -> tp.Optional[Path]:
    """Resolve a given checkpoint path for a provided dora sig or path.

    A value starting with `//sig/` is treated as a dora signature and looked up
    under the `xps` folder of the dora directory; anything else is treated as a
    path and passed through `AudioCraftEnvironment.resolve_reference_path`.
    If the result is a directory, the checkpoint file name is appended to it.

    Args:
        sig_or_path (Path or str): Checkpoint path or dora signature.
        name (str, optional): Name suffix for the checkpoint file stem.
        use_fsdp (bool): Whether the calling solver relies on FSDP.
    Returns:
        Path, optional: Resolved checkpoint path, if it exists.
    """
    from audiocraft import train
    xps_root = train.main.dora.dir / 'xps'
    reference = str(sig_or_path)
    sig_prefix = '//sig/'
    if reference.startswith(sig_prefix):
        resolved = xps_root / reference[len(sig_prefix):]
    else:
        resolved = AudioCraftEnvironment.resolve_reference_path(Path(reference))

    if resolved.is_dir():
        resolved = resolved / checkpoint_name(name, use_fsdp=use_fsdp)

    return resolved if resolved.exists() else None
|
85 |
+
|
86 |
+
|
87 |
+
def load_checkpoint(checkpoint_path: Path, is_sharded: bool = False) -> tp.Any:
    """Load state from checkpoints at the specified checkpoint path.

    Args:
        checkpoint_path (Path): File to load, always mapped to CPU.
        is_sharded (bool): When True and the default (unnamed) rank-zero
            checkpoint exists next to `checkpoint_path`, the sharded checkpoint
            is first checked with `check_sharded_checkpoint`.
    Returns:
        any: The object stored in the checkpoint file.
    """
    if is_sharded:
        rank0_path = checkpoint_path.parent / checkpoint_name(use_fsdp=False)
        rank0_present = rank0_path.exists()
        if rank0_present:
            check_sharded_checkpoint(checkpoint_path, rank0_path)
    # NOTE(review): `torch.load` unpickles the file; only load trusted checkpoints.
    loaded = torch.load(checkpoint_path, 'cpu')
    logger.info("Checkpoint loaded from %s", checkpoint_path)
    return loaded
|
96 |
+
|
97 |
+
|
98 |
+
def save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bool = False) -> None:
    """Save state to disk to the specified checkpoint_path.

    The write itself is delegated to `_safe_save_checkpoint`; this wrapper
    only adds the log line once the save has returned.
    """
    _safe_save_checkpoint(state, checkpoint_path, is_sharded)
    logger.info("Checkpoint saved to %s", checkpoint_path)
|
102 |
+
|
103 |
+
|
104 |
+
def flush_stale_checkpoints(checkpoint_path: Path, keep_last: tp.Optional[int] = None) -> None:
    """Flush checkpoints to only keep last N checkpoints.

    Only files named `checkpoint_<digits>.th` (plus `.<rank>` on non-zero
    ranks) in the folder of `checkpoint_path` are considered; the ones with
    the lowest epoch numbers are removed first.

    Args:
        checkpoint_path (Path): Any path inside the checkpoint folder.
        keep_last (int, optional): Number of epoch checkpoints to keep.
            Nothing is removed when None or not strictly positive.
    """
    if keep_last is None or keep_last <= 0:
        return
    rank = flashy.distrib.rank()
    suffix = f'.{rank}' if rank > 0 else ''
    epoch_checkpoints = []
    for candidate in Path(checkpoint_path.parent).glob(f'checkpoint_*.th{suffix}'):
        # File name is `checkpoint_<name>.th...`: keep what sits between '_' and the first '.'.
        stem = candidate.name.split('.', 1)[0]
        epoch_part = stem.split('_', 1)[1]
        if epoch_part.isdigit():
            epoch_checkpoints.append((candidate, int(epoch_part)))
    epoch_checkpoints.sort(key=lambda entry: entry[1])
    n_stale = max(0, len(epoch_checkpoints) - keep_last)
    for stale_path, _ in epoch_checkpoints[:n_stale]:
        logger.debug("Removing checkpoint: %s", str(stale_path))
        stale_path.unlink(missing_ok=True)
|
123 |
+
|
124 |
+
|
125 |
+
def check_sharded_checkpoint(checkpoint_path: Path, rank0_checkpoint_path: Path) -> None:
    """Check sharded checkpoint state, ensuring the checkpoints are not corrupted.

    If the rank-zero `.tmp.done` token is still present, a leftover `.tmp`
    file for this shard is renamed to its final name, then rank zero removes
    the token once all ranks have passed the barrier.

    Args:
        checkpoint_path (Path): Checkpoint file of the current rank.
        rank0_checkpoint_path (Path): Checkpoint file of rank zero, next to which the token lives.
    Raises:
        RuntimeError: If a `.old` file from a previous version of the code exists.
    """
    # Finish the work of a previous run that got interrupted while dumping.
    legacy_path = Path(str(checkpoint_path) + '.old')
    if legacy_path.exists():
        raise RuntimeError(
            f"Old checkpoint {legacy_path} from previous version of this code exist, cannot safely proceed.")
    done_token = Path(str(rank0_checkpoint_path) + '.tmp.done')
    pending_path = Path(str(checkpoint_path) + '.tmp')
    if done_token.exists() and pending_path.exists():
        pending_path.rename(checkpoint_path)
    # All ranks must be done looking at the token before rank zero deletes it.
    flashy.distrib.barrier()
    if flashy.distrib.is_rank_zero() and done_token.exists():
        done_token.unlink()
|
140 |
+
|
141 |
+
|
142 |
+
def _safe_save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bool = False) -> None:
    """Save checkpoints in a safe manner, even with sharded checkpoints across nodes.

    Rank zero clears any stale `.tmp.done` token, every rank writes its file,
    rank zero creates the token, and finally removes it again. With
    `is_sharded`, the steps are separated by distributed barriers.

    Args:
        state (any): Object handed to `torch.save`.
        checkpoint_path (Path): Destination file for the current rank.
        is_sharded (bool): Whether to synchronize ranks between the steps.
    """
    def _sync() -> None:
        # Barriers are only used when checkpoints are sharded.
        if is_sharded:
            flashy.distrib.barrier()

    done_token = Path(str(checkpoint_path) + '.tmp.done')
    if flashy.distrib.is_rank_zero() and done_token.exists():
        done_token.unlink()
    _sync()
    # NOTE(review): presumably writes to a temporary file and renames it on
    # success — confirm against flashy's `write_and_rename`.
    with flashy.utils.write_and_rename(checkpoint_path) as f:
        torch.save(state, f)
    _sync()
    if flashy.distrib.is_rank_zero():
        done_token.touch()
    # Two consecutive barriers, kept exactly as in the original sequence.
    _sync()
    _sync()
    if flashy.distrib.rank() == 0:
        done_token.unlink()
|
audiocraft/utils/cluster.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Utility functions for SLURM configuration and cluster settings.
|
9 |
+
"""
|
10 |
+
|
11 |
+
from enum import Enum
|
12 |
+
import os
|
13 |
+
import socket
|
14 |
+
import typing as tp
|
15 |
+
|
16 |
+
import omegaconf
|
17 |
+
|
18 |
+
|
19 |
+
class ClusterType(Enum):
    """Known cluster flavours used to adjust the SLURM configuration."""
    AWS = "aws"
    FAIR = "fair"
    RSC = "rsc"
    LOCAL_DARWIN = "darwin"
    # Fallback used for any other cluster.
    DEFAULT = "default"
|
25 |
+
|
26 |
+
|
27 |
+
def _guess_cluster_type() -> ClusterType:
    """Guess the cluster type from the OS name, kernel release and host FQDN.

    The checks run in order: AWS (Linux with a release ending in `-aws` or
    `.ec2` in the FQDN), FAIR (FQDN ending in `.fair`), RSC (FQDN ending in
    `.facebook.com`), then a local macOS machine. Anything else is DEFAULT.

    Returns:
        ClusterType: The guessed cluster type.
    """
    # `platform.uname()` is available on every OS, unlike `os.uname()` which
    # does not exist on Windows.
    import platform

    uname = platform.uname()
    fqdn = socket.getfqdn()
    if uname.system == "Linux" and (uname.release.endswith("-aws") or ".ec2" in fqdn):
        return ClusterType.AWS

    if fqdn.endswith(".fair"):
        return ClusterType.FAIR

    if fqdn.endswith(".facebook.com"):
        return ClusterType.RSC

    if uname.system == "Darwin":
        return ClusterType.LOCAL_DARWIN

    return ClusterType.DEFAULT
|
43 |
+
|
44 |
+
|
45 |
+
def get_cluster_type(
    cluster_type: tp.Optional[ClusterType] = None,
) -> tp.Optional[ClusterType]:
    """Return `cluster_type` as given, or guess it from the host when it is None."""
    return _guess_cluster_type() if cluster_type is None else cluster_type
|
52 |
+
|
53 |
+
|
54 |
+
def get_slurm_parameters(
    cfg: omegaconf.DictConfig, cluster_type: tp.Optional[ClusterType] = None
) -> omegaconf.DictConfig:
    """Update SLURM parameters in configuration based on cluster type.
    If the cluster type is not specified, it is inferred automatically.

    Args:
        cfg (omegaconf.DictConfig): SLURM configuration, modified in place.
        cluster_type (ClusterType, optional): Cluster to configure for.
    Returns:
        omegaconf.DictConfig: The same `cfg` object, after the updates.
    """
    from ..environment import AudioCraftEnvironment
    resolved_type = get_cluster_type(cluster_type)

    # Cluster-specific adjustments, listed in the order they are written to `cfg`.
    overrides: tp.Dict[str, tp.Any]
    if resolved_type == ClusterType.AWS:
        overrides = {"mem_per_gpu": None, "constraint": None, "setup": []}
    elif resolved_type == ClusterType.RSC:
        overrides = {"mem_per_gpu": None, "setup": [], "constraint": None, "partition": "learn"}
    else:
        overrides = {}
    for key, value in overrides.items():
        cfg[key] = value

    # The exclusion list comes from the environment and applies to every cluster type.
    excluded = AudioCraftEnvironment.get_slurm_exclude()
    if excluded is not None:
        cfg["exclude"] = excluded
    return cfg
|
audiocraft/utils/deadlock.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import logging
|
8 |
+
import os
|
9 |
+
from queue import Queue, Empty
|
10 |
+
import signal
|
11 |
+
import sys
|
12 |
+
import threading
|
13 |
+
import traceback
|
14 |
+
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
+
|
17 |
+
|
18 |
+
class DeadlockDetect:
    """Watchdog killing the process when no progress is reported for too long.

    While used as a context manager with `use=True`, a background thread waits
    for stage names pushed through `update()`. If nothing arrives within
    `timeout` seconds, the stack of every live thread is printed to stderr and
    the current process is killed with SIGKILL.

    Args:
        use (bool): Whether the detector is active. When False, `update()` and
            the context manager do nothing.
        timeout (float): Maximum number of seconds to wait between two updates.
    """
    def __init__(self, use: bool = False, timeout: float = 120.):
        self.use = use
        self.timeout = timeout
        self._queue: Queue = Queue()

    def update(self, stage: str):
        """Report that `stage` was reached; the detector then waits a full timeout again."""
        if self.use:
            self._queue.put(stage)

    def __enter__(self):
        if self.use:
            self._thread = threading.Thread(target=self._detector_thread)
            self._thread.start()
        # Return the detector so that `with DeadlockDetect(...) as detector:` binds
        # a usable object instead of None.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.use:
            # None is the sentinel asking the detector thread to stop.
            self._queue.put(None)
            self._thread.join()

    def _detector_thread(self):
        """Body of the watchdog thread: wait for updates, kill the process on timeout."""
        logger.debug("Deadlock detector started")
        last_stage = "init"
        while True:
            try:
                stage = self._queue.get(timeout=self.timeout)
            except Empty:
                break
            if stage is None:
                logger.debug("Exiting deadlock detector thread")
                return
            last_stage = stage
        logger.error("Deadlock detector timed out, last stage was %s", last_stage)
        # Take a single snapshot of the frames. A thread listed by `enumerate()` may
        # have no entry in it (e.g. it finished in between), which must not abort
        # the dump of the remaining threads.
        frames = sys._current_frames()
        for th in threading.enumerate():
            print(th, file=sys.stderr)
            frame = frames.get(th.ident)
            if frame is not None:
                traceback.print_stack(frame)
            print(file=sys.stderr)
        sys.stdout.flush()
        sys.stderr.flush()
        os.kill(os.getpid(), signal.SIGKILL)
|
audiocraft/utils/export.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Utility to export a training checkpoint to a lightweight release checkpoint.
|
9 |
+
"""
|
10 |
+
|
11 |
+
from pathlib import Path
|
12 |
+
import typing as tp
|
13 |
+
|
14 |
+
from omegaconf import OmegaConf
|
15 |
+
import torch
|
16 |
+
|
17 |
+
from audiocraft import __version__
|
18 |
+
|
19 |
+
|
20 |
+
def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
    """Export only the best state from the given EnCodec checkpoint. This
    should be used if you trained your own EnCodec model.

    Args:
        checkpoint_path (Path or str): Training checkpoint to read.
        out_file (Path or str): Where to write the exported checkpoint. Missing
            parent folders are created.
    Returns:
        Path or str: `out_file`, as it was given.
    """
    checkpoint = torch.load(checkpoint_path, 'cpu')
    best_model_state = checkpoint['best_state']['model']
    cfg_as_yaml = OmegaConf.to_yaml(checkpoint['xp.cfg'])
    exported = {
        'best_state': best_model_state,
        'xp.cfg': cfg_as_yaml,
        'version': __version__,
        'exported': True,
    }
    destination = Path(out_file)
    destination.parent.mkdir(exist_ok=True, parents=True)
    torch.save(exported, out_file)
    return out_file
|
34 |
+
|
35 |
+
|
36 |
+
def export_pretrained_compression_model(pretrained_encodec: str, out_file: tp.Union[Path, str]):
    """Export a compression model (potentially EnCodec) from a pretrained model.
    This is required for packaging the audio tokenizer along a MusicGen or AudioGen model.
    Do not include the //pretrained/ prefix. For instance if you trained a model
    with `facebook/encodec_32khz`, just put that as a name. Same for `dac_44khz`.

    In that case, this will not actually include a copy of the model, simply the reference
    to the model used.

    Args:
        pretrained_encodec (str): Either the path of an existing exported checkpoint,
            which is copied over, or the name of a pretrained model, which is only
            stored as a reference.
        out_file (Path or str): Where to write the package. Missing parent
            folders are created.
    Returns:
        Path or str: `out_file`, as it was given, like the other export functions.
    """
    if Path(pretrained_encodec).exists():
        # Load on CPU, like the other export functions, so that a checkpoint saved
        # from a GPU can be packaged on a machine that does not have that device.
        pkg = torch.load(pretrained_encodec, 'cpu')
        # An existing file must already be an exported checkpoint.
        assert 'best_state' in pkg
        assert 'xp.cfg' in pkg
        assert 'version' in pkg
        assert 'exported' in pkg
    else:
        pkg = {
            'pretrained': pretrained_encodec,
            'exported': True,
            'version': __version__,
        }
    Path(out_file).parent.mkdir(exist_ok=True, parents=True)
    torch.save(pkg, out_file)
    return out_file
|
59 |
+
|
60 |
+
|
61 |
+
def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
    """Export only the best state from the given MusicGen or AudioGen checkpoint.

    The model state is read from `fsdp_best_state` when that entry is not empty,
    and from `best_state` otherwise.

    Args:
        checkpoint_path (Path or str): Training checkpoint to read.
        out_file (Path or str): Where to write the exported checkpoint. Missing
            parent folders are created.
    Returns:
        Path or str: `out_file`, as it was given.
    """
    checkpoint = torch.load(checkpoint_path, 'cpu')
    state_holder = checkpoint['fsdp_best_state']
    if not state_holder:
        state_holder = checkpoint['best_state']
        assert state_holder
    exported = {
        'best_state': state_holder['model'],
        'xp.cfg': OmegaConf.to_yaml(checkpoint['xp.cfg']),
        'version': __version__,
        'exported': True,
    }

    destination = Path(out_file)
    destination.parent.mkdir(exist_ok=True, parents=True)
    torch.save(exported, out_file)
    return out_file
|
audiocraft/utils/export_legacy.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Legacy functions used at the time of the first release, kept for reference.
|
9 |
+
"""
|
10 |
+
|
11 |
+
from pathlib import Path
|
12 |
+
import typing as tp
|
13 |
+
|
14 |
+
from omegaconf import OmegaConf, DictConfig
|
15 |
+
import torch
|
16 |
+
|
17 |
+
from audiocraft import __version__
|
18 |
+
|
19 |
+
|
20 |
+
def _clean_lm_cfg(cfg: DictConfig):
    """Rewrite a legacy LM configuration in place and return it.

    Sets `transformer_lm.card` and `transformer_lm.n_q`, drops the `downsample`
    entry of the stereo interleaving config when stereo is in use, and removes
    experimental `transformer_lm` parameters that are no longer supported.

    Args:
        cfg (DictConfig): Configuration to clean. Struct mode is disabled during
            the edits and enabled again before returning.
    Returns:
        DictConfig: The same `cfg` object.
    """
    OmegaConf.set_struct(cfg, False)
    lm_cfg = cfg['transformer_lm']
    # This used to be set automatically in the LM solver, need a more robust solution
    # for the future.
    lm_cfg['card'] = 2048

    stereo_cfg = getattr(cfg, 'interleave_stereo_codebooks', None)
    uses_stereo = stereo_cfg is not None and stereo_cfg.use
    if uses_stereo and 'downsample' in stereo_cfg:
        del stereo_cfg['downsample']
    # 8 codebooks when stereo interleaving is enabled, 4 otherwise.
    lm_cfg['n_q'] = 8 if uses_stereo else 4

    # Experimental params no longer supported.
    unsupported = ('spectral_norm_attn_iters', 'spectral_norm_ff_iters',
                   'residual_balancer_attn', 'residual_balancer_ff', 'layer_drop')
    for param_name in unsupported:
        del lm_cfg[param_name]
    OmegaConf.set_struct(cfg, True)
    return cfg
|
39 |
+
|
40 |
+
|
41 |
+
def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
    """Legacy export of an EnCodec checkpoint, taking the model state from the EMA state.

    Args:
        checkpoint_path (Path or str): Training checkpoint to read.
        out_file (Path or str): Where to write the exported checkpoint. Missing
            parent folders are created.
    Returns:
        Path or str: `out_file`, as it was given.
    """
    checkpoint = torch.load(checkpoint_path, 'cpu')
    ema_model_state = checkpoint['ema']['state']['model']
    exported = {
        'best_state': ema_model_state,
        'xp.cfg': OmegaConf.to_yaml(checkpoint['xp.cfg']),
        # The following params were NOT exported for the first release of MusicGen.
        'version': __version__,
        'exported': True,
    }
    destination = Path(out_file)
    destination.parent.mkdir(exist_ok=True, parents=True)
    torch.save(exported, out_file)
    return out_file
|
53 |
+
|
54 |
+
|
55 |
+
def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
    """Legacy export of an LM checkpoint, with its configuration cleaned by `_clean_lm_cfg`.

    The model state is read from `fsdp_best_state` when that entry is not empty,
    and from `best_state` otherwise.

    Args:
        checkpoint_path (Path or str): Training checkpoint to read.
        out_file (Path or str): Where to write the exported checkpoint. Missing
            parent folders are created.
    Returns:
        Path or str: `out_file`, as it was given.
    """
    checkpoint = torch.load(checkpoint_path, 'cpu')
    state_holder = checkpoint['fsdp_best_state'] or checkpoint['best_state']
    cleaned_cfg = _clean_lm_cfg(checkpoint['xp.cfg'])
    exported = {
        'best_state': state_holder['model'],
        'xp.cfg': OmegaConf.to_yaml(cleaned_cfg),
        # The following params were NOT exported for the first release of MusicGen.
        'version': __version__,
        'exported': True,
    }
    destination = Path(out_file)
    destination.parent.mkdir(exist_ok=True, parents=True)
    torch.save(exported, out_file)
    return out_file
|
audiocraft/utils/notebook.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
try:
|
8 |
+
import IPython.display as ipd # type: ignore
|
9 |
+
except ImportError:
|
10 |
+
# Not in a notebook...
|
11 |
+
pass
|
12 |
+
|
13 |
+
|
14 |
+
import torch
|
15 |
+
|
16 |
+
|
17 |
+
def display_audio(samples: torch.Tensor, sample_rate: int):
    """Renders one audio player per item of the given audio samples.

    Args:
        samples (torch.Tensor): a Tensor of decoded audio samples
            with shapes [B, C, T] or [C, T]
        sample_rate (int): sample rate audio should be displayed with.
    """
    n_dims = samples.dim()
    assert n_dims == 2 or n_dims == 3

    batch = samples.detach().cpu()
    if n_dims == 2:
        # Treat a single [C, T] item as a batch of one.
        batch = batch[None, ...]

    for clip in batch:
        player = ipd.Audio(clip, rate=sample_rate)
        ipd.display(player)
|
audiocraft/utils/profiler.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import logging
|
8 |
+
import typing as tp
|
9 |
+
|
10 |
+
import dora
|
11 |
+
import torch
|
12 |
+
|
13 |
+
|
14 |
+
logger = logging.getLogger(__name__)
|
15 |
+
|
16 |
+
|
17 |
+
class Profiler:
    """Context manager wrapper for xformers profiler.

    When `enabled` is False no profiler is created and `step()`, `__enter__()`
    and `__exit__()` do nothing and return None.

    Args:
        module (torch.nn.Module): Module handed over to the xformers profiler.
        enabled (bool): Whether to activate profiling.
    """
    def __init__(self, module: torch.nn.Module, enabled: bool = False):
        self.profiler: tp.Optional[tp.Any] = None
        if not enabled:
            return
        # Imported lazily: xformers is only needed when profiling is requested.
        from xformers.profiler import profile
        output_dir = dora.get_xp().folder / 'profiler_data'
        logger.info("Profiling activated, results with be saved to %s", output_dir)
        self.profiler = profile(output_dir=output_dir, module=module)

    def step(self):
        """Forward a step to the wrapped profiler, if there is one."""
        if self.profiler is None:
            return
        self.profiler.step()  # type: ignore

    def __enter__(self):
        if self.profiler is None:
            return None
        return self.profiler.__enter__()  # type: ignore

    def __exit__(self, exc_type, exc_value, exc_tb):
        if self.profiler is None:
            return None
        return self.profiler.__exit__(exc_type, exc_value, exc_tb)  # type: ignore
|
audiocraft/utils/utils.py
ADDED
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from concurrent.futures import ProcessPoolExecutor
|
8 |
+
from contextlib import contextmanager
|
9 |
+
from functools import wraps, lru_cache
|
10 |
+
import hashlib
|
11 |
+
import json
|
12 |
+
import logging
|
13 |
+
from pathlib import Path
|
14 |
+
import typing as tp
|
15 |
+
|
16 |
+
import flashy
|
17 |
+
import flashy.distrib
|
18 |
+
import omegaconf
|
19 |
+
import torch
|
20 |
+
from torch.nn.utils.rnn import pad_sequence
|
21 |
+
|
22 |
+
|
23 |
+
logger = logging.getLogger(__name__)
|
24 |
+
|
25 |
+
|
26 |
+
def model_hash(model: torch.nn.Module) -> str:
    """Return a model hash. This should allow us to track regressions in model init
    from the logs of past experiments.

    Args:
        model (torch.nn.Module): Model whose parameters are hashed, in the order
            given by `model.parameters()`.
    Returns:
        str: Hexadecimal SHA1 digest of the raw bytes of all the parameters.
    """
    digest = hashlib.sha1()
    for param in model.parameters():
        raw_bytes = param.data.cpu().numpy().tobytes()
        digest.update(raw_bytes)
    return digest.hexdigest()
|
34 |
+
|
35 |
+
|
36 |
+
def dict_from_config(cfg: omegaconf.DictConfig) -> dict:
    """Convenience function to map an omegaconf configuration to a dictionary.

    Args:
        cfg (omegaconf.DictConfig): Original configuration to map to dict.
    Returns:
        dict: Config as dictionary object, converted with `resolve=True`.
    """
    container = omegaconf.OmegaConf.to_container(cfg, resolve=True)
    # `to_container` can return other container types; a dict is expected here.
    assert isinstance(container, dict)
    return container
|
47 |
+
|
48 |
+
|
49 |
+
def random_subset(dataset, max_samples: int, seed: int = 42) -> torch.utils.data.Subset:
    """Return a random subset of `dataset` with at most `max_samples` items.

    Args:
        dataset: Dataset to sample from, must support `len()`.
        max_samples (int): Maximum number of items to keep.
        seed (int): Seed of the generator used to draw the permutation.
    Returns:
        The dataset itself when it has no more than `max_samples` items, otherwise
        a `torch.utils.data.Subset` over the first `max_samples` indices of a
        seeded random permutation.
    """
    total = len(dataset)
    if total <= max_samples:
        return dataset

    rng = torch.Generator().manual_seed(seed)
    shuffled = torch.randperm(total, generator=rng)
    kept_indices = shuffled[:max_samples].tolist()
    return torch.utils.data.Subset(dataset, kept_indices)
|
56 |
+
|
57 |
+
|
58 |
+
def get_loader(dataset, num_samples: tp.Optional[int], batch_size: int,
               num_workers: int, seed: int, **kwargs) -> torch.utils.data.DataLoader:
    """Convenience function to load dataset into a dataloader with optional subset sampling.

    Args:
        dataset: Dataset to load.
        num_samples (Optional[int]): Number of samples to limit subset size.
            When None, the whole dataset is used.
        batch_size (int): Batch size.
        num_workers (int): Number of workers for data loading.
        seed (int): Random seed, only used to draw the subset.
        **kwargs: Extra keyword arguments forwarded to `flashy.distrib.loader`.
    Returns:
        The loader built by `flashy.distrib.loader`.
    """
    if num_samples is None:
        source = dataset
    else:
        source = random_subset(dataset, num_samples, seed)

    return flashy.distrib.loader(
        source,
        batch_size=batch_size,
        num_workers=num_workers,
        **kwargs
    )
|
79 |
+
|
80 |
+
|
81 |
+
def get_dataset_from_loader(dataloader):
    """Return the dataset behind `dataloader`, unwrapping one `Subset` level if present.

    Args:
        dataloader: Object exposing a `dataset` attribute.
    Returns:
        The wrapped dataset when `dataloader.dataset` is a `torch.utils.data.Subset`,
        otherwise `dataloader.dataset` itself.
    """
    dataset = dataloader.dataset
    is_subset = isinstance(dataset, torch.utils.data.Subset)
    return dataset.dataset if is_subset else dataset
|
87 |
+
|
88 |
+
|
89 |
+
def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
    """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.

    Args:
        input (torch.Tensor): The input tensor containing probabilities.
        num_samples (int): Number of samples to draw.
        replacement (bool): Whether to draw with replacement or not.
    Keywords args:
        generator (torch.Generator): A pseudorandom number generator for sampling.
    Returns:
        torch.Tensor: Last dimension contains num_samples indices
            sampled from the multinomial probability distribution
            located in the last dimension of tensor input.
    """
    leading_shape = input.shape[:-1]
    n_candidates = input.shape[-1]
    # torch.multinomial is called on a 2D view: collapse all leading dimensions into one.
    flat_probs = input.reshape(-1, n_candidates)
    flat_samples = torch.multinomial(
        flat_probs, num_samples=num_samples, replacement=replacement, generator=generator)
    # Restore the leading dimensions, the last one holds the drawn indices.
    return flat_samples.reshape(*leading_shape, -1)
|
107 |
+
|
108 |
+
|
109 |
+
def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
    """Sample next token from top K values along the last dimension of the input probs tensor.

    The input tensor is left untouched.

    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        k (int): The k in "top-k".
    Returns:
        torch.Tensor: Sampled tokens.
    """
    top_k_value, _ = torch.topk(probs, k, dim=-1)
    # Smallest of the k largest values: candidates below it are discarded.
    min_value_top_k = top_k_value[..., [-1]]
    # Build a new tensor rather than zeroing and renormalizing `probs` in place,
    # so the caller's probabilities are not modified as a side effect. The mask is
    # cast to the dtype of `probs` to keep the result in that dtype.
    keep = (probs >= min_value_top_k).to(probs.dtype)
    filtered = probs * keep
    filtered.div_(filtered.sum(dim=-1, keepdim=True))
    next_token = multinomial(filtered, num_samples=1)
    return next_token
|
124 |
+
|
125 |
+
|
126 |
+
def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    """Sample next token from top P probabilities along the last dimension of the input probs tensor.

    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        p (float): The p in "top-p".
    Returns:
        torch.Tensor: Sampled tokens.
    """
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Probability mass accumulated strictly before each candidate, in sorted order.
    preceding_mass = cumulative - sorted_probs
    keep = ~(preceding_mass > p)
    sorted_probs.mul_(keep.float())
    sorted_probs.div_(sorted_probs.sum(dim=-1, keepdim=True))
    sampled_rank = multinomial(sorted_probs, num_samples=1)
    # Map the position in the sorted order back to the original candidate index.
    return torch.gather(sorted_idx, -1, sampled_rank)
|
143 |
+
|
144 |
+
|
145 |
+
class DummyPoolExecutor:
    """Dummy pool executor to use when we actually have only 1 worker.
    (e.g. instead of ProcessPoolExecutor).

    Submitted calls are not run on submission: they run in the calling
    process each time `result()` is called on the returned object.
    """
    class DummyResult:
        """Deferred call, executed when `result()` is invoked."""
        def __init__(self, func, *args, **kwargs):
            self.func = func
            self.args = args
            self.kwargs = kwargs

        def result(self):
            target, positional, named = self.func, self.args, self.kwargs
            return target(*positional, **named)

    def __init__(self, workers, mp_context=None):
        # Both arguments are accepted but ignored.
        pass

    def submit(self, func, *args, **kwargs):
        """Wrap the call into a `DummyResult` without executing it."""
        return DummyPoolExecutor.DummyResult(func, *args, **kwargs)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        return None
|
169 |
+
|
170 |
+
|
171 |
+
def get_pool_executor(num_workers: int, mp_context=None):
    """Return a pool executor suited to the requested number of workers.

    Args:
        num_workers (int): Number of workers.
        mp_context: Multiprocessing context, only passed to `ProcessPoolExecutor`.
    Returns:
        A `ProcessPoolExecutor` when `num_workers > 1`, otherwise a `DummyPoolExecutor`.
    """
    if num_workers > 1:
        return ProcessPoolExecutor(num_workers, mp_context)
    return DummyPoolExecutor(1)
|
173 |
+
|
174 |
+
|
175 |
+
def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> torch.Tensor:
    """Utility function to convert a tensor of sequence lengths to a mask (useful when working on padded sequences).
    For example: [3, 5] => [[True, True, True, False, False], [True, True, True, True, True]]

    Args:
        lengths (torch.Tensor): tensor with lengths
        max_len (int): can set the max length manually. Defaults to None, in which
            case (as for 0) the largest value in `lengths` is used.
    Returns:
        torch.Tensor: boolean mask, False where there is pad tokens else True
    """
    assert len(lengths.shape) == 1, "Length shape should be 1 dimensional."
    if max_len:
        final_length = max_len
    else:
        final_length = lengths.max().item()
    final_length = max(final_length, 1)  # if all seqs are of len zero we don't want a zero-size tensor
    positions = torch.arange(final_length, device=lengths.device)
    # Broadcast [1, T] positions against [B, 1] lengths.
    return positions[None, :] < lengths[:, None]
|
189 |
+
|
190 |
+
|
191 |
+
def hash_trick(word: str, vocab_size: int) -> int:
    """Hash trick to pair each word with an index

    Args:
        word (str): word we wish to convert to an index
        vocab_size (int): size of the vocabulary
    Returns:
        int: index of the word in the embedding LUT
    """
    encoded = word.encode("utf-8")
    # Interpret the SHA256 hex digest as one big integer, then fold it into the vocabulary.
    digest_value = int(hashlib.sha256(encoded).hexdigest(), 16)
    return digest_value % vocab_size
|
202 |
+
|
203 |
+
|
204 |
+
def with_rank_rng(base_seed: int = 1234):
    """Decorator for a function so that the function will use a Random Number Generator
    whose state depend on the GPU rank. The original RNG state is restored upon returning.

    Args:
        base_seed (int): Random seed, combined with the rank through a XOR.
    """
    def _decorator(fun: tp.Callable):
        @wraps(fun)
        def _decorated(*args, **kwargs):
            # NOTE(review): only the state returned by `torch.get_rng_state()` is
            # saved and restored; confirm whether device RNG states matter to callers.
            saved_state = torch.get_rng_state()
            rank_seed = base_seed ^ flashy.distrib.rank()
            torch.manual_seed(rank_seed)
            logger.debug('Rank dependent seed set to %d', rank_seed)
            try:
                outcome = fun(*args, **kwargs)
            finally:
                # Restored whether `fun` returned or raised.
                torch.set_rng_state(saved_state)
                logger.debug('RNG state restored.')
            return outcome
        return _decorated
    return _decorator
|
225 |
+
|
226 |
+
|
227 |
+
def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[torch.Tensor, torch.Tensor]:
    """Get a list of tensors and collate them to a single tensor. according to the following logic:
    - `dim` specifies the time dimension which will be stacked and padded.
    - The output will contain 1 new dimension (dimension index 0) which will be the size of
    of the original list.

    Args:
        tensors (tp.List[torch.Tensor]): List of tensors to collate.
        dim (int): Dimension which will be stacked and padded.
    Returns:
        tp.Tuple[torch.Tensor, torch.Tensor]:
            torch.Tensor: Stacked and padded tensor. The output will contain 1 new dimension
                (dimension index 0) which will be the size of the original list.
            torch.Tensor: Tensor containing length of original tensor sizes (without padding).
    """
    # Bring the time dimension first, as `pad_sequence` pads along dimension 0.
    time_first = [item.transpose(0, dim) for item in tensors]
    lengths = torch.LongTensor([item.shape[0] for item in time_first])
    stacked = pad_sequence(time_first)
    # `pad_sequence` gives [T, B, ...]: move the batch first, then put the time
    # dimension back where it was, shifted by one because of the new batch dimension.
    batch_first = stacked.transpose(0, 1)
    restored = batch_first.transpose(1, dim + 1)
    return restored, lengths
|
248 |
+
|
249 |
+
|
250 |
+
# TODO: Move to flashy?
|
251 |
+
def copy_state(state: tp.Any, device: tp.Union[torch.device, str] = 'cpu',
               dtype: tp.Optional[torch.dtype] = None) -> tp.Any:
    """Recursively copy a state made of tensors, dicts and lists.

    Args:
        state (tp.Any): Tensor, dict or list (possibly nested) to copy.
        device (torch.device or str): Device on which the copied tensors are placed.
        dtype (torch.dtype, optional): Dtype for the copied floating point tensors.
            Other tensors, and all tensors when None, keep their own dtype.
    Returns:
        tp.Any: Same structure as `state`, with every tensor detached and copied.
            Any other value is returned as is.
    """
    if isinstance(state, torch.Tensor):
        # Only floating point tensors are converted, e.g. integer buffers keep their dtype.
        keep_dtype = dtype is None or not state.is_floating_point()
        target_dtype = state.dtype if keep_dtype else dtype
        return state.detach().to(device=device, dtype=target_dtype, copy=True)
    if isinstance(state, dict):
        return {k: copy_state(v, device, dtype) for k, v in state.items()}
    if isinstance(state, list):
        return [copy_state(v, device, dtype) for v in state]
    # Values of any other type (numbers, strings, tuples, None...) used to fall
    # through and be replaced by None, silently dropping them from the state.
    # They are now passed through unchanged (not copied).
    return state
|
261 |
+
|
262 |
+
|
263 |
+
# TODO: Move to flashy?
|
264 |
+
@contextmanager
def swap_state(model, state, **kwargs):
    """Context manager loading `state` into `model` and restoring the previous state on exit.

    Args:
        model: Object exposing `state_dict()` and `load_state_dict()`.
        state: State to load for the duration of the context.
        **kwargs: Extra keyword arguments for the `load_state_dict` call that loads
            `state`. They are not used when restoring the previous state.
    """
    backup = copy_state(model.state_dict())
    model.load_state_dict(state, **kwargs)
    try:
        yield
    finally:
        # Runs whether the body of the `with` block completed or raised.
        model.load_state_dict(backup)
|
272 |
+
|
273 |
+
|
274 |
+
@lru_cache(maxsize=None)
def warn_once(logger, msg):
    """Warn about a given message only once.

    Calls are memoized on the `(logger, msg)` pair, so both must be hashable and
    a repeated pair does not reach `logger.warning` again.
    """
    logger.warning(msg)
|
278 |
+
|
279 |
+
|
280 |
+
def is_jsonable(x: tp.Any):
    """Check if an object can be serialized into a json.

    Returns:
        bool: False when `json.dumps` raises a TypeError or an OverflowError,
            True when it succeeds. Other exceptions are not caught.
    """
    try:
        json.dumps(x)
    except (TypeError, OverflowError):
        return False
    return True
|
287 |
+
|
288 |
+
|
289 |
+
def load_clap_state_dict(clap_model, path: tp.Union[str, Path]):
    """Wrapper around state dict loading of CLAP model
    addressing compatibility issues between CLAP and AudioCraft
    HuggingFace transformer version.
    See: https://github.com/LAION-AI/CLAP/issues/118

    Args:
        clap_model: Object whose `model` attribute receives the state dict.
        path (str or Path): Checkpoint path given to the CLAP `load_state_dict` helper.
    """
    from clap_module.factory import load_state_dict  # type: ignore
    state = load_state_dict(path)
    # Drop this entry, if present, before loading (see the issue linked above).
    state.pop('text_branch.embeddings.position_ids', None)
    clap_model.model.load_state_dict(state)
|