unpairedelectron07 committed e5de9ff
1 Parent(s): ee232aa
Upload manager.py

audiocraft/utils/samples/manager.py  (ADDED)  @@ -0,0 +1,386 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
API that can manage the storage and retrieval of generated samples produced by experiments.

It offers the following benefits:
* Samples are stored in a consistent way across epochs
* Metadata about the samples can be stored and retrieved
* Can retrieve audio
* Identifiers are reliable and deterministic for prompted and conditioned samples
* Can request the samples for multiple XPs, grouped by sample identifier
* For no-input samples (no prompt and no conditions), samples across XPs are matched
  by sorting their identifiers
"""

from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict, dataclass
from functools import lru_cache
import hashlib
import json
import logging
from pathlib import Path
import re
import typing as tp
import unicodedata
import uuid

import dora
import torch

from ...data.audio import audio_read, audio_write


logger = logging.getLogger(__name__)


@dataclass
class ReferenceSample:
    id: str
    path: str
    duration: float


@dataclass
class Sample:
    id: str
    path: str
    epoch: int
    duration: float
    conditioning: tp.Optional[tp.Dict[str, tp.Any]]
    prompt: tp.Optional[ReferenceSample]
    reference: tp.Optional[ReferenceSample]
    generation_args: tp.Optional[tp.Dict[str, tp.Any]]

    def __hash__(self):
        return hash(self.id)

    def audio(self) -> tp.Tuple[torch.Tensor, int]:
        return audio_read(self.path)

    def audio_prompt(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]:
        return audio_read(self.prompt.path) if self.prompt is not None else None

    def audio_reference(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]:
        return audio_read(self.reference.path) if self.reference is not None else None
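For orientation, a minimal sketch of how a loaded Sample is typically consumed. The indexing and the print call are purely illustrative, and `manager` stands for an instance of the SampleManager class defined just below.

# Illustrative only: inspect one sample previously loaded by a SampleManager (defined below).
sample = manager.samples[0]
wav, sample_rate = sample.audio()      # decodes the stored audio via audio_read
if sample.prompt is not None:          # prompt and reference metadata are optional
    prompt_wav, _ = sample.audio_prompt()
print(sample.id, sample.epoch, sample.duration)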
class SampleManager:
    """Audio samples IO handling within a given dora xp.

    The sample manager handles the dumping and loading logic for generated and
    reference samples across epochs for a given xp, providing a simple API to
    store, retrieve and compare audio samples.

    Args:
        xp (dora.XP): Dora experiment object. The XP contains information on the XP folder
            where all outputs are stored and the configuration of the experiment,
            which is useful to retrieve audio-related parameters.
        map_reference_to_sample_id (bool): Whether to use the sample_id for all reference samples
            instead of generating a dedicated hash id. This is useful to allow easier comparison
            with ground truth samples from the files directly without having to read the JSON metadata
            to do the mapping (at the cost of potentially dumping duplicate prompts/references
            depending on the task).
    """
    def __init__(self, xp: dora.XP, map_reference_to_sample_id: bool = False):
        self.xp = xp
        self.base_folder: Path = xp.folder / xp.cfg.generate.path
        self.reference_folder = self.base_folder / 'reference'
        self.map_reference_to_sample_id = map_reference_to_sample_id
        self.samples: tp.List[Sample] = []
        self._load_samples()

    @property
    def latest_epoch(self):
        """Latest epoch across all samples."""
        return max(self.samples, key=lambda x: x.epoch).epoch if self.samples else 0

    def _load_samples(self):
        """Scan the sample folder and load existing samples."""
        jsons = self.base_folder.glob('**/*.json')
        with ThreadPoolExecutor(6) as pool:
            self.samples = list(pool.map(self._load_sample, jsons))

    @staticmethod
    @lru_cache(2**26)
    def _load_sample(json_file: Path) -> Sample:
        with open(json_file, 'r') as f:
            data: tp.Dict[str, tp.Any] = json.load(f)
        # fetch prompt data
        prompt_data = data.get('prompt')
        prompt = ReferenceSample(id=prompt_data['id'], path=prompt_data['path'],
                                 duration=prompt_data['duration']) if prompt_data else None
        # fetch reference data
        reference_data = data.get('reference')
        reference = ReferenceSample(id=reference_data['id'], path=reference_data['path'],
                                    duration=reference_data['duration']) if reference_data else None
        # build sample object
        return Sample(id=data['id'], path=data['path'], epoch=data['epoch'], duration=data['duration'],
                      prompt=prompt, conditioning=data.get('conditioning'), reference=reference,
                      generation_args=data.get('generation_args'))
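    # The JSON files parsed above are the ones written next to each sample by add_sample below,
    # via json.dump(asdict(sample), ...). A hypothetical example of their content (all ids, paths
    # and numbers are made up):
    #   {
    #     "id": "0a1b2c3d_prompted_description=happy-rock",
    #     "path": ".../samples/50/0a1b2c3d_prompted_description=happy-rock.wav",
    #     "epoch": 50, "duration": 10.0,
    #     "conditioning": {"description": "happy rock"},
    #     "prompt": {"id": "...", "path": ".../50/prompt/....wav", "duration": 3.0},
    #     "reference": {"id": "...", "path": ".../reference/....wav", "duration": 10.0},
    #     "generation_args": null
    #   }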
    def _init_hash(self):
        return hashlib.sha1()

    def _get_tensor_id(self, tensor: torch.Tensor) -> str:
        hash_id = self._init_hash()
        hash_id.update(tensor.numpy().data)
        return hash_id.hexdigest()

    def _get_sample_id(self, index: int, prompt_wav: tp.Optional[torch.Tensor],
                       conditions: tp.Optional[tp.Dict[str, str]]) -> str:
        """Computes an id for a sample given its input data.
        This id is deterministic if a prompt and/or conditions are provided, by using a sha1 hash on the input.
        Otherwise, a random id of the form "noinput_{uuid4().hex}" is returned.

        Args:
            index (int): Batch index, helpful to differentiate samples from the same batch.
            prompt_wav (torch.Tensor): Prompt used during generation.
            conditions (dict[str, str]): Conditioning used during generation.
        """
        # For totally unconditioned generations we will just use a random UUID.
        # The function get_samples_for_xps will do a simple ordered match with a custom key.
        if prompt_wav is None and not conditions:
            return f"noinput_{uuid.uuid4().hex}"

        # Human readable portion
        hr_label = ""
        # Create a deterministic id using hashing
        hash_id = self._init_hash()
        hash_id.update(f"{index}".encode())
        if prompt_wav is not None:
            hash_id.update(prompt_wav.numpy().data)
            hr_label += "_prompted"
        else:
            hr_label += "_unprompted"
        if conditions:
            encoded_json = json.dumps(conditions, sort_keys=True).encode()
            hash_id.update(encoded_json)
            cond_str = "-".join([f"{key}={slugify(value)}"
                                 for key, value in sorted(conditions.items())])
            cond_str = cond_str[:100]  # some raw text might be too long to be a valid filename
            cond_str = cond_str if len(cond_str) > 0 else "unconditioned"
            hr_label += f"_{cond_str}"
        else:
            hr_label += "_unconditioned"

        return hash_id.hexdigest() + hr_label

    def _store_audio(self, wav: torch.Tensor, stem_path: Path, overwrite: bool = False) -> Path:
        """Stores the audio with the given stem path using the XP's configuration.

        Args:
            wav (torch.Tensor): Audio to store.
            stem_path (Path): Path in sample output directory with file stem to use.
            overwrite (bool): When False (default), skips storing an existing audio file.
        Returns:
            Path: The path at which the audio is stored.
        """
        existing_paths = [
            path for path in stem_path.parent.glob(stem_path.stem + '.*')
            if path.suffix != '.json'
        ]
        exists = len(existing_paths) > 0
        if exists and overwrite:
            logger.warning(f"Overwriting existing audio file with stem path {stem_path}")
        elif exists:
            return existing_paths[0]

        audio_path = audio_write(stem_path, wav, **self.xp.cfg.generate.audio)
        return audio_path

    def add_sample(self, sample_wav: torch.Tensor, epoch: int, index: int = 0,
                   conditions: tp.Optional[tp.Dict[str, str]] = None, prompt_wav: tp.Optional[torch.Tensor] = None,
                   ground_truth_wav: tp.Optional[torch.Tensor] = None,
                   generation_args: tp.Optional[tp.Dict[str, tp.Any]] = None) -> Sample:
        """Adds a single sample.
        The sample is stored in the XP's sample output directory, under a corresponding epoch folder.
        Each sample is assigned an id which is computed using the input data. In addition to the
        sample itself, a json file containing associated metadata is stored next to it.

        Args:
            sample_wav (torch.Tensor): Sample audio to store. Tensor of shape [channels, shape].
            epoch (int): Current training epoch.
            index (int): Helpful to differentiate samples from the same batch.
            conditions (dict[str, str], optional): Conditioning used during generation.
            prompt_wav (torch.Tensor, optional): Prompt used during generation. Tensor of shape [channels, shape].
            ground_truth_wav (torch.Tensor, optional): Reference audio where the prompt was extracted from.
                Tensor of shape [channels, shape].
            generation_args (dict[str, any], optional): Dictionary of other arguments used during generation.
        Returns:
            Sample: The saved sample.
        """
        sample_id = self._get_sample_id(index, prompt_wav, conditions)
        reuse_id = self.map_reference_to_sample_id
        prompt, ground_truth = None, None
        if prompt_wav is not None:
            prompt_id = sample_id if reuse_id else self._get_tensor_id(prompt_wav.sum(0, keepdim=True))
            prompt_duration = prompt_wav.shape[-1] / self.xp.cfg.sample_rate
            prompt_path = self._store_audio(prompt_wav, self.base_folder / str(epoch) / 'prompt' / prompt_id)
            prompt = ReferenceSample(prompt_id, str(prompt_path), prompt_duration)
        if ground_truth_wav is not None:
            ground_truth_id = sample_id if reuse_id else self._get_tensor_id(ground_truth_wav.sum(0, keepdim=True))
            ground_truth_duration = ground_truth_wav.shape[-1] / self.xp.cfg.sample_rate
            ground_truth_path = self._store_audio(ground_truth_wav, self.base_folder / 'reference' / ground_truth_id)
            ground_truth = ReferenceSample(ground_truth_id, str(ground_truth_path), ground_truth_duration)
        sample_path = self._store_audio(sample_wav, self.base_folder / str(epoch) / sample_id, overwrite=True)
        duration = sample_wav.shape[-1] / self.xp.cfg.sample_rate
        sample = Sample(sample_id, str(sample_path), epoch, duration, conditions, prompt, ground_truth, generation_args)
        self.samples.append(sample)
        with open(sample_path.with_suffix('.json'), 'w') as f:
            json.dump(asdict(sample), f, indent=2)
        return sample

    def add_samples(self, samples_wavs: torch.Tensor, epoch: int,
                    conditioning: tp.Optional[tp.List[tp.Dict[str, tp.Any]]] = None,
                    prompt_wavs: tp.Optional[torch.Tensor] = None,
                    ground_truth_wavs: tp.Optional[torch.Tensor] = None,
                    generation_args: tp.Optional[tp.Dict[str, tp.Any]] = None) -> tp.List[Sample]:
        """Adds a batch of samples.
        The samples are stored in the XP's sample output directory, under a corresponding
        epoch folder. Each sample is assigned an id which is computed using the input data and their batch index.
        In addition to the sample itself, a json file containing associated metadata is stored next to it.

        Args:
            samples_wavs (torch.Tensor): Batch of audio wavs to store. Tensor of shape [batch_size, channels, shape].
            epoch (int): Current training epoch.
            conditioning (list of dict[str, str], optional): List of conditions used during generation,
                one per sample in the batch.
            prompt_wavs (torch.Tensor, optional): Prompts used during generation. Tensor of shape
                [batch_size, channels, shape].
            ground_truth_wavs (torch.Tensor, optional): Reference audio where prompts were extracted from.
                Tensor of shape [batch_size, channels, shape].
            generation_args (dict[str, Any], optional): Dictionary of other arguments used during generation.
        Returns:
            samples (list of Sample): The saved audio samples with prompts, ground truth and metadata.
        """
        samples = []
        for idx, wav in enumerate(samples_wavs):
            prompt_wav = prompt_wavs[idx] if prompt_wavs is not None else None
            gt_wav = ground_truth_wavs[idx] if ground_truth_wavs is not None else None
            conditions = conditioning[idx] if conditioning is not None else None
            samples.append(self.add_sample(wav, epoch, idx, conditions, prompt_wav, gt_wav, generation_args))
        return samples
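    # Illustrative call from a training/generation loop (tensors and the conditioning key are hypothetical):
    #   manager.add_samples(gen_wavs, epoch=current_epoch,
    #                       conditioning=[{"description": "happy rock"}, {"description": "sad jazz"}],
    #                       prompt_wavs=prompts, ground_truth_wavs=references)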
    def get_samples(self, epoch: int = -1, max_epoch: int = -1, exclude_prompted: bool = False,
                    exclude_unprompted: bool = False, exclude_conditioned: bool = False,
                    exclude_unconditioned: bool = False) -> tp.Set[Sample]:
        """Returns a set of samples for this XP. Optionally, you can filter which samples to obtain.
        Please note that existing samples are loaded during the manager's initialization, and samples added
        through this manager are also tracked. Any other external changes are not tracked automatically,
        so creating a new manager is the only way to detect them.

        Args:
            epoch (int): If provided, only return samples corresponding to this epoch.
            max_epoch (int): If provided, only return samples corresponding to the latest epoch that is <= max_epoch.
            exclude_prompted (bool): If True, does not include samples that used a prompt.
            exclude_unprompted (bool): If True, does not include samples that did not use a prompt.
            exclude_conditioned (bool): If True, excludes samples that used conditioning.
            exclude_unconditioned (bool): If True, excludes samples that did not use conditioning.
        Returns:
            Samples (set of Sample): The retrieved samples matching the provided filters.
        """
        if max_epoch >= 0:
            samples_epoch = max(sample.epoch for sample in self.samples if sample.epoch <= max_epoch)
        else:
            samples_epoch = self.latest_epoch if epoch < 0 else epoch
        samples = {
            sample
            for sample in self.samples
            if (
                (sample.epoch == samples_epoch) and
                (not exclude_prompted or sample.prompt is None) and
                (not exclude_unprompted or sample.prompt is not None) and
                (not exclude_conditioned or not sample.conditioning) and
                (not exclude_unconditioned or sample.conditioning)
            )
        }
        return samples
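A minimal end-to-end sketch of the manager, assuming it runs inside a dora XP whose configuration defines generate.path, generate.audio and sample_rate (the fields the manager reads above). How the XP object is obtained, and the tensors and conditioning key used here, are assumptions for illustration only.

# Sketch only: dump a batch of generations and query them back (all values are hypothetical).
import dora    # redundant with the imports at the top of this module, kept so the sketch is self-contained
import torch

xp = dora.get_xp()                                  # assumption: executed inside a dora run
manager = SampleManager(xp)

gen_wavs = torch.randn(4, 1, 32000)                 # [batch, channels, frames] placeholder audio
conds = [{"description": f"clip {i}"} for i in range(4)]
manager.add_samples(gen_wavs, epoch=10, conditioning=conds)

# Latest-epoch samples that used conditioning but no audio prompt.
subset = manager.get_samples(exclude_prompted=True, exclude_unconditioned=True)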
def slugify(value: tp.Any, allow_unicode: bool = False):
    """Process string for safer file naming.

    Taken from https://github.com/django/django/blob/master/django/utils/text.py

    Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
    dashes to single dashes. Remove characters that aren't alphanumerics,
    underscores, or hyphens. Convert to lowercase. Also strip leading and
    trailing whitespace, dashes, and underscores.
    """
    value = str(value)
    if allow_unicode:
        value = unicodedata.normalize("NFKC", value)
    else:
        value = (
            unicodedata.normalize("NFKD", value)
            .encode("ascii", "ignore")
            .decode("ascii")
        )
    value = re.sub(r"[^\w\s-]", "", value.lower())
    return re.sub(r"[-\s]+", "-", value).strip("-_")
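A quick illustration of the transformation on typical conditioning strings, not an exhaustive specification of edge cases:

# Expected behaviour of slugify (illustrative inputs):
assert slugify("Rock & Roll, 90s!") == "rock-roll-90s"
assert slugify("  déjà vu  ") == "deja-vu"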
def _match_stable_samples(samples_per_xp: tp.List[tp.Set[Sample]]) -> tp.Dict[str, tp.List[Sample]]:
    # Create a dictionary of stable id -> sample per XP
    stable_samples_per_xp = [{
        sample.id: sample for sample in samples
        if sample.prompt is not None or sample.conditioning
    } for samples in samples_per_xp]
    # Set of all stable ids
    stable_ids = {id for samples in stable_samples_per_xp for id in samples.keys()}
    # Dictionary of stable id -> list of samples. If an XP does not have it, assign None
    stable_samples = {id: [xp.get(id) for xp in stable_samples_per_xp] for id in stable_ids}
    # Filter out ids that contain None values (we only want matched samples after all)
    # cast is necessary to avoid mypy linter errors.
    return {id: tp.cast(tp.List[Sample], samples) for id, samples in stable_samples.items() if None not in samples}


def _match_unstable_samples(samples_per_xp: tp.List[tp.Set[Sample]]) -> tp.Dict[str, tp.List[Sample]]:
    # For unstable ids, we use a sorted list since we'll match them in order
    unstable_samples_per_xp = [[
        sample for sample in sorted(samples, key=lambda x: x.id)
        if sample.prompt is None and not sample.conditioning
    ] for samples in samples_per_xp]
    # Trim samples per xp so all samples can have a match
    min_len = min([len(samples) for samples in unstable_samples_per_xp])
    unstable_samples_per_xp = [samples[:min_len] for samples in unstable_samples_per_xp]
    # Dictionary of index -> list of matched samples
    return {
        f'noinput_{i}': [samples[i] for samples in unstable_samples_per_xp] for i in range(min_len)
    }


def get_samples_for_xps(xps: tp.List[dora.XP], **kwargs) -> tp.Dict[str, tp.List[Sample]]:
    """Gets a dictionary of matched samples across the given XPs.
    Each dictionary entry maps a sample id to a list of samples for that id. The number of samples per id
    will always match the number of XPs provided and will correspond to each XP in the same order given.
    In other words, only samples that can be matched across all provided XPs will be returned
    in order to satisfy this rule.

    There are two types of ids that can be returned: stable and unstable.
    * Stable IDs are deterministic ids that were computed by the SampleManager given a sample's inputs
      (prompts/conditioning). This is why we can match them across XPs.
    * Unstable IDs are of the form "noinput_{idx}" and are generated on-the-fly, in order to map samples
      that used non-deterministic, random ids. This is the case for samples that did not use prompts or
      conditioning for their generation. This function will sort these samples by their id and match them
      by their index.

    Args:
        xps: a list of XPs to match samples from.
        epoch (int): If provided, only return samples corresponding to this epoch.
        max_epoch (int): If provided, only return samples corresponding to the latest epoch that is <= max_epoch.
        exclude_prompted (bool): If True, does not include samples that used a prompt.
        exclude_unprompted (bool): If True, does not include samples that did not use a prompt.
        exclude_conditioned (bool): If True, excludes samples that used conditioning.
        exclude_unconditioned (bool): If True, excludes samples that did not use conditioning.
    """
    managers = [SampleManager(xp) for xp in xps]
    samples_per_xp = [manager.get_samples(**kwargs) for manager in managers]
    stable_samples = _match_stable_samples(samples_per_xp)
    unstable_samples = _match_unstable_samples(samples_per_xp)
    return dict(stable_samples, **unstable_samples)
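Finally, a hedged sketch of cross-XP comparison. How the XP objects are obtained (for example from dora's API or an audiocraft helper) is an assumption, and the keyword arguments simply forward to SampleManager.get_samples as shown above.

# Sketch: compare matched generations from two experiments (xp_a and xp_b are assumed
# to be dora.XP objects obtained elsewhere).
matched = get_samples_for_xps([xp_a, xp_b], exclude_unprompted=True)
for sample_id, (sample_a, sample_b) in matched.items():
    wav_a, sr = sample_a.audio()
    wav_b, _ = sample_b.audio()
    # ... compute any comparison metric between wav_a and wav_b here ...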