skytnt commited on
Commit
8d6cb47
1 Parent(s): 1fd2f8b

fix duration

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -123,10 +123,10 @@ def send_msgs(msgs):
123
  def get_duration(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm,
124
  time_sig, key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr,
125
  remove_empty_channels, seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
 
126
  if "large" in model_name:
127
- return gen_events // 7 + 5
128
- else:
129
- return gen_events // 15 + 5
130
 
131
 
132
  @spaces.GPU(duration=get_duration)
 
123
  def get_duration(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm,
124
  time_sig, key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr,
125
  remove_empty_channels, seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
126
+ t = 1e-4*gen_events**2 + 15
127
  if "large" in model_name:
128
+ t *= 2
129
+ return t
 
130
 
131
 
132
  @spaces.GPU(duration=get_duration)