adefossez commited on
Commit
16a7142
·
1 Parent(s): 6d70065
Files changed (1) hide show
  1. audiocraft/models/musicgen.py +7 -2
audiocraft/models/musicgen.py CHANGED
@@ -96,7 +96,7 @@ class MusicGen:
96
  def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
97
  top_p: float = 0.0, temperature: float = 1.0,
98
  duration: float = 30.0, cfg_coef: float = 3.0,
99
- two_step_cfg: bool = False):
100
  """Set the generation parameters for MusicGen.
101
 
102
  Args:
@@ -109,8 +109,13 @@ class MusicGen:
109
  two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
110
  instead of batching together the two. This has some impact on how things
111
  are padded but seems to have little impact in practice.
 
 
 
112
  """
113
- assert duration <= 30, "The MusicGen cannot generate more than 30 seconds"
 
 
114
  self.generation_params = {
115
  'max_gen_len': int(duration * self.frame_rate),
116
  'use_sampling': use_sampling,
 
96
  def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
97
  top_p: float = 0.0, temperature: float = 1.0,
98
  duration: float = 30.0, cfg_coef: float = 3.0,
99
+ two_step_cfg: bool = False, extend_stride: float = 15):
100
  """Set the generation parameters for MusicGen.
101
 
102
  Args:
 
109
  two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
110
  instead of batching together the two. This has some impact on how things
111
  are padded but seems to have little impact in practice.
112
+ extend_stride: when doing extended generation (i.e. more than 30 seconds), by how much
113
+ should we extend the audio each time. Larger values will mean less context is
114
+ preserved, and shorter value will require extra computations.
115
  """
116
+ # assert duration <= 30, "The MusicGen cannot generate more than 30 seconds"
117
+ assert extend_stride <= 25, "Keep at least 5 seconds of overlap!"
118
+ self.extend_stride = extend_stride
119
  self.generation_params = {
120
  'max_gen_len': int(duration * self.frame_rate),
121
  'use_sampling': use_sampling,