ThreadAbort committed on
Commit
ca02c10
·
1 Parent(s): 29d19bd

repo. change

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -49,8 +49,8 @@ speed = 1.0
49
  # fix_duration = 27 # None or float (duration in seconds)
50
  fix_duration = None
51
 
52
- def load_model(exp_name, model_cls, model_cfg, ckpt_step):
53
- checkpoint = torch.load(str(cached_path(f"hf://SWivid/{exp_name}/model_{ckpt_step}.pt")), map_location=device)
54
  vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
55
  model = CFM(
56
  transformer=model_cls(
@@ -79,13 +79,13 @@ def load_model(exp_name, model_cls, model_cfg, ckpt_step):
79
  F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
80
  E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
81
 
82
- F5TTS_ema_model, F5TTS_base_model = load_model("F5TTS", DiT, F5TTS_model_cfg, 1200000)
83
- E2TTS_ema_model, E2TTS_base_model = load_model("E2TTS", UNetT, E2TTS_model_cfg, 1200000)
84
 
85
  @spaces.GPU
86
  def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, progress = gr.Progress()):
87
  print(gen_text)
88
- if model.predict(gen_text)['toxicity'] > 0.8:
89
  print("Flagged for toxicity:", gen_text)
90
  raise gr.Error("Your text was flagged for toxicity, please try again with a different text.")
91
  gr.Info("Converting audio...")
 
49
  # fix_duration = 27 # None or float (duration in seconds)
50
  fix_duration = None
51
 
52
+ def load_model(exp_name, model_cls, model_cfg, ckpt_step,repoid):
53
+ checkpoint = torch.load(str(cached_path(f"hf://SWivid/{repoid}/{exp_name}/model_{ckpt_step}.pt")), map_location=device)
54
  vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
55
  model = CFM(
56
  transformer=model_cls(
 
79
  F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
80
  E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
81
 
82
+ F5TTS_ema_model, F5TTS_base_model = load_model("F5TTS_Base", DiT, F5TTS_model_cfg, 1200000, "F5-TTS")
83
+ E2TTS_ema_model, E2TTS_base_model = load_model("E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000, "E2-TTS")
84
 
85
  @spaces.GPU
86
  def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, progress = gr.Progress()):
87
  print(gen_text)
88
+ if model.predict(gen_text)['toxicity'] > 0.8:
89
  print("Flagged for toxicity:", gen_text)
90
  raise gr.Error("Your text was flagged for toxicity, please try again with a different text.")
91
  gr.Info("Converting audio...")