skytnt commited on
Commit
cb9def6
1 Parent(s): 44b2b89

try onnx again

Browse files
Files changed (2) hide show
  1. README.md +1 -1
  2. app_onnx.py +6 -9
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: red
5
  colorTo: indigo
6
  sdk: gradio
7
  sdk_version: 4.43.0
8
- app_file: app.py
9
  pinned: true
10
  license: apache-2.0
11
  ---
 
5
  colorTo: indigo
6
  sdk: gradio
7
  sdk_version: 4.43.0
8
+ app_file: app_onnx.py
9
  pinned: true
10
  license: apache-2.0
11
  ---
app_onnx.py CHANGED
@@ -170,9 +170,10 @@ def run(model_name, tab, mid_seq, continuation_state, continuation_select, instr
170
  key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels,
171
  seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
172
  model = models[model_name]
173
- model[0].set_providers(['CUDAExecutionProvider', 'CPUExecutionProvider'])
174
- model[1].set_providers(['CUDAExecutionProvider', 'CPUExecutionProvider'])
175
  tokenizer = model[2]
 
176
  bpm = int(bpm)
177
  if time_sig == "auto":
178
  time_sig = None
@@ -426,22 +427,18 @@ if __name__ == "__main__":
426
  ]
427
  }
428
  models = {}
429
- providers = ['CPUExecutionProvider']
430
 
431
  for name, (repo_id, path, config, loras) in models_info.items():
432
  model_base_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_base.onnx")
433
  model_token_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_token.onnx")
434
- model_base = rt.InferenceSession(model_base_path, providers=providers)
435
- model_token = rt.InferenceSession(model_token_path, providers=providers)
436
  tokenizer = get_tokenizer(config)
437
- models[name] = [model_base, model_token, tokenizer]
438
  for lora_name, lora_repo in loras.items():
439
  model_base_path = hf_hub_download_retry(repo_id=lora_repo, filename=f"onnx/model_base.onnx")
440
  model_token_path = hf_hub_download_retry(repo_id=lora_repo, filename=f"onnx/model_token.onnx")
441
- model_base = rt.InferenceSession(model_base_path, providers=providers)
442
- model_token = rt.InferenceSession(model_token_path, providers=providers)
443
  tokenizer = get_tokenizer(config)
444
- models[f"{name} with {lora_name} lora"] = [model_base, model_token, tokenizer]
445
 
446
  load_javascript()
447
  app = gr.Blocks()
 
170
  key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels,
171
  seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
172
  model = models[model_name]
173
+ model_base = rt.InferenceSession(model[0], providers=providers)
174
+ model_token = rt.InferenceSession(model[1], providers=providers)
175
  tokenizer = model[2]
176
+ model = [model_base, model_token, tokenizer]
177
  bpm = int(bpm)
178
  if time_sig == "auto":
179
  time_sig = None
 
427
  ]
428
  }
429
  models = {}
430
+ providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
431
 
432
  for name, (repo_id, path, config, loras) in models_info.items():
433
  model_base_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_base.onnx")
434
  model_token_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_token.onnx")
 
 
435
  tokenizer = get_tokenizer(config)
436
+ models[name] = [model_base_path, model_token_path, tokenizer]
437
  for lora_name, lora_repo in loras.items():
438
  model_base_path = hf_hub_download_retry(repo_id=lora_repo, filename=f"onnx/model_base.onnx")
439
  model_token_path = hf_hub_download_retry(repo_id=lora_repo, filename=f"onnx/model_token.onnx")
 
 
440
  tokenizer = get_tokenizer(config)
441
+ models[f"{name} with {lora_name} lora"] = [model_base_path, model_token_path, tokenizer]
442
 
443
  load_javascript()
444
  app = gr.Blocks()