adefossez commited on
Commit
e00df76
·
1 Parent(s): 6ec60d5

initial implem

Browse files
Files changed (1) hide show
  1. audiocraft/models/musicgen.py +50 -8
audiocraft/models/musicgen.py CHANGED
@@ -36,10 +36,12 @@ class MusicGen:
36
  used to map audio to invertible discrete representations.
37
  lm (LMModel): Language model over discrete representations.
38
  """
39
- def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel):
 
40
  self.name = name
41
  self.compression_model = compression_model
42
  self.lm = lm
 
43
  self.device = next(iter(lm.parameters())).device
44
  self.generation_params: dict = {}
45
  self.set_generation_params(duration=15) # 15 seconds by default
@@ -113,11 +115,10 @@ class MusicGen:
113
  should we extend the audio each time. Larger values will mean less context is
114
  preserved, and shorter value will require extra computations.
115
  """
116
- # assert duration <= 30, "The MusicGen cannot generate more than 30 seconds"
117
- assert extend_stride <= 25, "Keep at least 5 seconds of overlap!"
118
  self.extend_stride = extend_stride
 
119
  self.generation_params = {
120
- 'max_gen_len': int(duration * self.frame_rate),
121
  'use_sampling': use_sampling,
122
  'temp': temperature,
123
  'top_k': top_k,
@@ -268,8 +269,12 @@ class MusicGen:
268
  Returns:
269
  torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
270
  """
 
 
 
 
271
  def _progress_callback(generated_tokens: int, tokens_to_generate: int):
272
- print(f'{generated_tokens: 6d} / {tokens_to_generate: 6d}', end='\r')
273
 
274
  if prompt_tokens is not None:
275
  assert self.generation_params['max_gen_len'] > prompt_tokens.shape[-1], \
@@ -279,9 +284,46 @@ class MusicGen:
279
  if progress:
280
  callback = _progress_callback
281
 
282
- # generate by sampling from LM
283
- with self.autocast:
284
- gen_tokens = self.lm.generate(prompt_tokens, attributes, callback=callback, **self.generation_params)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
285
 
286
  # generate audio
287
  assert gen_tokens.dim() == 3
 
36
  used to map audio to invertible discrete representations.
37
  lm (LMModel): Language model over discrete representations.
38
  """
39
+ def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
40
+ max_duration: float = 30):
41
  self.name = name
42
  self.compression_model = compression_model
43
  self.lm = lm
44
+ self.max_duration = max_duration
45
  self.device = next(iter(lm.parameters())).device
46
  self.generation_params: dict = {}
47
  self.set_generation_params(duration=15) # 15 seconds by default
 
115
  should we extend the audio each time. Larger values will mean less context is
116
  preserved, and shorter value will require extra computations.
117
  """
118
+ assert extend_stride <= self.max_duration - 5, "Keep at least 5 seconds of overlap!"
 
119
  self.extend_stride = extend_stride
120
+ self.duration = duration
121
  self.generation_params = {
 
122
  'use_sampling': use_sampling,
123
  'temp': temperature,
124
  'top_k': top_k,
 
269
  Returns:
270
  torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
271
  """
272
+ total_gen_len = int(self.duration * self.frame_rate)
273
+
274
+ current_gen_offset = 0
275
+
276
  def _progress_callback(generated_tokens: int, tokens_to_generate: int):
277
+ print(f'{current_gen_offset + generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
278
 
279
  if prompt_tokens is not None:
280
  assert self.generation_params['max_gen_len'] > prompt_tokens.shape[-1], \
 
284
  if progress:
285
  callback = _progress_callback
286
 
287
+ if self.duration <= self.max_duration:
288
+ # generate by sampling from LM, simple case.
289
+ with self.autocast:
290
+ gen_tokens = self.lm.generate(
291
+ prompt_tokens, attributes,
292
+ callback=callback, max_gen_len=total_gen_len, **self.generation_params)
293
+
294
+ else:
295
+ # now this gets a bit messier, we need to handle prompts,
296
+ # melody conditioning etc.
297
+ ref_wavs = [attr.wav['self_wav'] for attr in attributes]
298
+ all_tokens = []
299
+ if prompt_tokens is not None:
300
+ all_tokens.append(prompt_tokens)
301
+
302
+ for time_offset in range(0, self.duration, self.extend_stride):
303
+ chunk_duration = min(self.duration - time_offset, self.max_duration)
304
+ max_gen_len = int(chunk_duration * self.frame_rate)
305
+ for attr, ref_wav in zip(attributes, ref_wavs):
306
+ wav_length = ref_wav.length.item()
307
+ if wav_length == 0:
308
+ continue
309
+ # We will extend the wav periodically if it not long enough.
310
+ # we have to do it here before it is too late.
311
+ initial_position = int(time_offset * self.sample_rate)
312
+ wav_target_length = int(chunk_duration * self.sample_rate)
313
+ positions = torch.arange(initial_position,
314
+ initial_position + wav_target_length, device=self.device)
315
+ attr.wav['self_wav'] = ref_wav[:, positions % wav_length]
316
+ with self.autocast:
317
+ gen_tokens = self.lm.generate(
318
+ prompt_tokens, attributes,
319
+ callback=callback, max_gen_len=max_gen_len, **self.generation_params)
320
+ stride_tokens = int(self.frame_rate * self.extend_stride)
321
+ if prompt_tokens is None:
322
+ all_tokens.append(gen_tokens)
323
+ else:
324
+ all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
325
+ prompt_tokens = gen_tokens[:, :, stride_tokens]
326
+ gen_tokens = torch.cat(all_tokens, dim=-1)
327
 
328
  # generate audio
329
  assert gen_tokens.dim() == 3