mhrahmani commited on
Commit
6ff371b
·
1 Parent(s): e4b3e78

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -10
app.py CHANGED
@@ -13,22 +13,27 @@ MODEL_INFO = [
13
  # MODEL_NAMES = [info[0] for info in MODEL_INFO]
14
 
15
  MODEL_NAMES = [
16
- "vits checkoint 57000",
17
  # Add other model names similarly...
18
  ]
19
 
20
-
21
  MAX_TXT_LEN = 400
22
  TOKEN = os.getenv('HUGGING_FACE_HUB_TOKEN')
23
 
24
- # Download models
25
- for model_name, model_file, config_file, repo_name in MODEL_INFO:
26
- os.makedirs(model_name, exist_ok=True)
27
- print(f"|> Downloading: {model_name}")
28
 
29
- # Use hf_hub_download to download models from private Hugging Face repositories
30
- hf_hub_download(repo_id=repo_name, filename=model_file, cache_dir=model_name, use_auth_token=TOKEN)
31
- hf_hub_download(repo_id=repo_name, filename=config_file, cache_dir=model_name, use_auth_token=TOKEN)
 
 
 
 
 
 
32
 
33
 
34
  def synthesize(text: str, model_name: str) -> str:
@@ -37,7 +42,7 @@ def synthesize(text: str, model_name: str) -> str:
37
  text = text[:MAX_TXT_LEN]
38
  print(f"Input text was cut off as it exceeded the {MAX_TXT_LEN} character limit.")
39
 
40
- synthesizer = Synthesizer(f"{model_name}/{model_file}", f"{model_name}/{config_file}")
41
  if synthesizer is None:
42
  raise NameError("Model not found")
43
 
 
13
  # MODEL_NAMES = [info[0] for info in MODEL_INFO]
14
 
15
  MODEL_NAMES = [
16
+ "vits checkpoint 57000",
17
  # Add other model names similarly...
18
  ]
19
 
 
20
  MAX_TXT_LEN = 400
21
  TOKEN = os.getenv('HUGGING_FACE_HUB_TOKEN')
22
 
23
+ # # Download models
24
+ # for model_name, model_file, config_file, repo_name in MODEL_INFO:
25
+ # os.makedirs(model_name, exist_ok=True)
26
+ # print(f"|> Downloading: {model_name}")
27
 
28
+ # # Use hf_hub_download to download models from private Hugging Face repositories
29
+ # hf_hub_download(repo_id=repo_name, filename=model_file, use_auth_token=TOKEN)
30
+ # hf_hub_download(repo_id=repo_name, filename=config_file, use_auth_token=TOKEN)
31
+
32
+ repo_name = "mhrahmani/persian-tts-vits-0"
33
+ filename = "checkpoint_57000.pth"
34
+
35
+ model_file = hf_hub_download(repo_name, filename, use_auth_token=TOKEN)
36
+ config_file = hf_hub_download(repo_name, "config.json", use_auth_token=TOKEN)
37
 
38
 
39
  def synthesize(text: str, model_name: str) -> str:
 
42
  text = text[:MAX_TXT_LEN]
43
  print(f"Input text was cut off as it exceeded the {MAX_TXT_LEN} character limit.")
44
 
45
+ synthesizer = Synthesizer(model_file, config_file)
46
  if synthesizer is None:
47
  raise NameError("Model not found")
48