Staticaliza commited on
Commit
58fe0e8
1 Parent(s): ee89a2f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -0
app.py CHANGED
@@ -30,12 +30,14 @@ ode_method = "euler"
30
  def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
31
  ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
32
  vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
 
33
  model = CFM(
34
  transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
35
  mel_spec_kwargs=dict(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length),
36
  odeint_kwargs=dict(method=ode_method),
37
  vocab_char_map=vocab_char_map,
38
  ).to(DEVICE)
 
39
  model = load_checkpoint(model, ckpt_path, DEVICE, use_ema = True)
40
  return model
41
 
 
30
  def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
31
  ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
32
  vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
33
+
34
  model = CFM(
35
  transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
36
  mel_spec_kwargs=dict(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length),
37
  odeint_kwargs=dict(method=ode_method),
38
  vocab_char_map=vocab_char_map,
39
  ).to(DEVICE)
40
+
41
  model = load_checkpoint(model, ckpt_path, DEVICE, use_ema = True)
42
  return model
43