radames commited on
Commit
62c459e
1 Parent(s): 82e1128

fix limits

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -58,9 +58,6 @@ def predict(
58
 
59
  if melody_input:
60
  melody, sr = torchaudio.load(melody_input)
61
- melody_duration = melody.shape[-1] / sr
62
- if melody_duration < duration:
63
- raise gr.Error("The duration must be greater than the melody duration!")
64
  # sr, melody = melody_input[0], torch.from_numpy(melody_input[1]).to(MODEL.device).float().t().unsqueeze(0)
65
  if melody.dim() == 2:
66
  melody = melody[None]
@@ -69,6 +66,9 @@ def predict(
69
  melody_wavform = melody[
70
  ..., int(sr * continuation_start) : int(sr * continuation_end)
71
  ]
 
 
 
72
  output = MODEL.generate_continuation(
73
  prompt=melody_wavform,
74
  prompt_sample_rate=sr,
 
58
 
59
  if melody_input:
60
  melody, sr = torchaudio.load(melody_input)
 
 
 
61
  # sr, melody = melody_input[0], torch.from_numpy(melody_input[1]).to(MODEL.device).float().t().unsqueeze(0)
62
  if melody.dim() == 2:
63
  melody = melody[None]
 
66
  melody_wavform = melody[
67
  ..., int(sr * continuation_start) : int(sr * continuation_end)
68
  ]
69
+ melody_duration = melody_wavform.shape[-1] / sr
70
+ if duration + melody_duration > MODEL.lm.cfg.dataset.segment_duration:
71
+ raise gr.Error("Duration + continuation duration must be <= 30 seconds")
72
  output = MODEL.generate_continuation(
73
  prompt=melody_wavform,
74
  prompt_sample_rate=sr,