adefossez commited on
Commit
3373345
1 Parent(s): fad2862

final changes

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. tests/models/test_musicgen.py +2 -2
app.py CHANGED
@@ -25,7 +25,7 @@ from audiocraft.models import MusicGen
25
 
26
  MODEL = None # Last used model
27
  IS_BATCHED = "facebook/MusicGen" in os.environ.get('SPACE_ID', '')
28
- MAX_BATCH_SIZE = 8
29
  BATCHED_DURATION = 15
30
  INTERRUPTING = False
31
  # We have to wrap subprocess call to clean a bit the log when using gr.make_waveform
 
25
 
26
  MODEL = None # Last used model
27
  IS_BATCHED = "facebook/MusicGen" in os.environ.get('SPACE_ID', '')
28
+ MAX_BATCH_SIZE = 12
29
  BATCHED_DURATION = 15
30
  INTERRUPTING = False
31
  # We have to wrap subprocess call to clean a bit the log when using gr.make_waveform
tests/models/test_musicgen.py CHANGED
@@ -13,7 +13,7 @@ from audiocraft.models import MusicGen
13
  class TestSEANetModel:
14
  def get_musicgen(self):
15
  mg = MusicGen.get_pretrained(name='debug', device='cpu')
16
- mg.set_generation_params(duration=2.0, stride_extend=2.)
17
  return mg
18
 
19
  def test_base(self):
@@ -52,7 +52,7 @@ class TestSEANetModel:
52
  def test_generate_long(self):
53
  mg = self.get_musicgen()
54
  mg.max_duration = 3.
55
- mg.set_generation_params(duration=4., stride_extend=2.)
56
  wav = mg.generate(
57
  ['youpi', 'lapin dort'])
58
  assert list(wav.shape) == [2, 1, 32000 * 4]
 
13
  class TestSEANetModel:
14
  def get_musicgen(self):
15
  mg = MusicGen.get_pretrained(name='debug', device='cpu')
16
+ mg.set_generation_params(duration=2.0, extend_stride=2.)
17
  return mg
18
 
19
  def test_base(self):
 
52
  def test_generate_long(self):
53
  mg = self.get_musicgen()
54
  mg.max_duration = 3.
55
+ mg.set_generation_params(duration=4., extend_stride=2.)
56
  wav = mg.generate(
57
  ['youpi', 'lapin dort'])
58
  assert list(wav.shape) == [2, 1, 32000 * 4]