TDN-M commited on
Commit
00954c4
·
verified ·
1 Parent(s): a02f58f

Update tts.py

Browse files
Files changed (1) hide show
  1. tts.py +65 -79
tts.py CHANGED
@@ -2,46 +2,35 @@ import os
2
  import re
3
  import torch
4
  import torchaudio
5
- from huggingface_hub import hf_hub_download
6
  from TTS.tts.configs.xtts_config import XttsConfig
7
  from TTS.tts.models.xtts import Xtts
8
  from vinorm import TTSnorm
9
- from torch.amp import autocast
10
 
11
  # Cấu hình đường dẫn và tải mô hình
12
  checkpoint_dir = "model/"
13
  repo_id = "capleaf/viXTTS"
14
  use_deepspeed = False
15
 
16
- # Kiểm tra GPU hỗ trợ FP16
17
- if torch.cuda.is_available():
18
- device = "cuda"
19
- if "A100" in torch.cuda.get_device_name(0):
20
- print("Đang sử dụng GPU A100 với hỗ trợ FP16.")
21
- use_fp16 = True
22
- else:
23
- print(f"Đang sử dụng GPU: {torch.cuda.get_device_name(0)}")
24
- use_fp16 = False
25
- else:
26
- device = "cpu"
27
- use_fp16 = False
28
 
29
  # Tạo thư mục nếu chưa tồn tại
30
  os.makedirs(checkpoint_dir, exist_ok=True)
31
 
32
  # Kiểm tra và tải các file cần thiết
33
  required_files = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"]
34
- for file in required_files:
35
- file_path = os.path.join(checkpoint_dir, file)
36
- if not os.path.exists(file_path):
37
- try:
38
- hf_hub_download(
39
- repo_id=repo_id if file != "speakers_xtts.pth" else "coqui/XTTS-v2",
40
- filename=file,
41
- local_dir=checkpoint_dir,
42
- )
43
- except Exception as e:
44
- raise RuntimeError(f"Không thể tải file {file} từ Hugging Face Hub: {str(e)}")
 
45
 
46
  # Tải cấu hình và mô hình
47
  xtts_config = os.path.join(checkpoint_dir, "config.json")
@@ -49,74 +38,71 @@ config = XttsConfig()
49
  config.load_json(xtts_config)
50
  MODEL = Xtts.init_from_config(config)
51
  MODEL.load_checkpoint(config, checkpoint_dir=checkpoint_dir, use_deepspeed=use_deepspeed)
 
 
52
  MODEL.to(device)
53
 
54
- # Danh sách ngôn ngữ được hỗ trợ
55
  supported_languages = ["vi", "en"]
56
 
57
  def normalize_vietnamese_text(text):
58
- try:
59
- text = (
60
- TTSnorm(text, unknown=False, lower=False, rule=True)
61
- .replace("..", ".")
62
- .replace("!.", "!")
63
- .replace("?.", "?")
64
- .replace(" .", ".")
65
- .replace(" ,", ",")
66
- .replace('"', "")
67
- .replace("'", "")
68
- .replace("AI", "Ây Ai")
69
- .replace("A.I", "Ây Ai")
70
- )
71
- return text
72
- except Exception as e:
73
- raise RuntimeError(f"Lỗi khi chuẩn hóa văn bản: {str(e)}")
74
-
75
- def generate_speech(
76
- text,
77
- language="vi",
78
- speaker_wav=None,
79
- normalize_text=True,
80
- repetition_penalty=5.0,
81
- temperature=0.75,
82
- ):
83
  if language not in supported_languages:
84
- raise ValueError(f"Ngôn ngữ {language} không được hỗ trợ. Các ngôn ngữ được hỗ trợ: {', '.join(supported_languages)}")
 
85
  if len(text) < 2:
86
- raise ValueError("Văn bản quá ngắn.")
87
- if speaker_wav and not os.path.isfile(speaker_wav):
88
- raise ValueError(f"File speaker_wav không tồn tại: {speaker_wav}")
89
 
90
  try:
 
91
  if normalize_text and language == "vi":
92
  text = normalize_vietnamese_text(text)
93
 
94
- with torch.no_grad():
95
- with autocast(device_type='cuda', enabled=use_fp16):
96
- gpt_cond_latent, speaker_embedding = MODEL.get_conditioning_latents(
97
- audio_path=speaker_wav,
98
- gpt_cond_len=30 if device == "cuda" else 15,
99
- gpt_cond_chunk_len=8 if device == "cuda" else 4,
100
- max_ref_length=60 if device == "cuda" else 30,
101
- )
102
- out = MODEL.inference(
103
- text,
104
- language,
105
- gpt_cond_latent,
106
- speaker_embedding,
107
- repetition_penalty=repetition_penalty,
108
- temperature=temperature,
109
- enable_text_splitting=True,
110
- )
111
-
112
- output_dir = "outputs/"
113
- os.makedirs(output_dir, exist_ok=True)
114
- output_file = os.path.join(output_dir, f"output_{os.urandom(4).hex()}.wav")
115
- torchaudio.save(output_file, torch.tensor(out["wav"]).unsqueeze(0).to("cpu"), 24000)
116
 
117
- if device == "cuda":
118
- torch.cuda.empty_cache()
 
119
 
120
  return output_file
 
121
  except Exception as e:
122
  raise RuntimeError(f"Lỗi khi tạo giọng nói: {str(e)}")
 
2
  import re
3
  import torch
4
  import torchaudio
5
+ from huggingface_hub import snapshot_download, hf_hub_download
6
  from TTS.tts.configs.xtts_config import XttsConfig
7
  from TTS.tts.models.xtts import Xtts
8
  from vinorm import TTSnorm
 
9
 
10
  # Cấu hình đường dẫn và tải mô hình
11
  checkpoint_dir = "model/"
12
  repo_id = "capleaf/viXTTS"
13
  use_deepspeed = False
14
 
15
+ device = "cuda" if torch.cuda.is_available() and "T4" in torch.cuda.get_device_name(0) else "cpu"
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  # Tạo thư mục nếu chưa tồn tại
18
  os.makedirs(checkpoint_dir, exist_ok=True)
19
 
20
  # Kiểm tra và tải các file cần thiết
21
  required_files = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"]
22
+ files_in_dir = os.listdir(checkpoint_dir)
23
+ if not all(file in files_in_dir for file in required_files):
24
+ snapshot_download(
25
+ repo_id=repo_id,
26
+ repo_type="model",
27
+ local_dir=checkpoint_dir,
28
+ )
29
+ hf_hub_download(
30
+ repo_id="coqui/XTTS-v2",
31
+ filename="speakers_xtts.pth",
32
+ local_dir=checkpoint_dir,
33
+ )
34
 
35
  # Tải cấu hình và mô hình
36
  xtts_config = os.path.join(checkpoint_dir, "config.json")
 
38
  config.load_json(xtts_config)
39
  MODEL = Xtts.init_from_config(config)
40
  MODEL.load_checkpoint(config, checkpoint_dir=checkpoint_dir, use_deepspeed=use_deepspeed)
41
+
42
+ # Tải mô hình vào thiết bị phù hợp
43
  MODEL.to(device)
44
 
45
+ # Danh sách ngôn ngữ được hỗ trợ (chỉ tiếng Việt và tiếng Anh)
46
  supported_languages = ["vi", "en"]
47
 
48
  def normalize_vietnamese_text(text):
49
+ """
50
+ Chuẩn hóa văn bản tiếng Việt.
51
+ """
52
+ text = (
53
+ TTSnorm(text, unknown=False, lower=False, rule=True)
54
+ .replace("..", ".")
55
+ .replace("!.", "!")
56
+ .replace("?.", "?")
57
+ .replace(" .", ".")
58
+ .replace(" ,", ",")
59
+ .replace('"', "")
60
+ .replace("'", "")
61
+ .replace("AI", "Ây Ai")
62
+ .replace("A.I", "Ây Ai")
63
+ )
64
+ return text
65
+
66
+ def generate_speech(text, language="vi", speaker_wav=None, normalize_text=True):
67
+ """
68
+ Tạo giọng nói từ văn bản.
69
+ """
 
 
 
 
70
  if language not in supported_languages:
71
+ raise ValueError(f"Ngôn ngữ {language} không được hỗ trợ. Chỉ hỗ trợ tiếng Việt (vi) tiếng Anh (en).")
72
+
73
  if len(text) < 2:
74
+ raise ValueError("Văn bản quá ngắn. Vui lòng nhập văn bản dài hơn.")
 
 
75
 
76
  try:
77
+ # Chuẩn hóa văn bản nếu cần
78
  if normalize_text and language == "vi":
79
  text = normalize_vietnamese_text(text)
80
 
81
+ # Lấy latent và embedding từ file âm thanh mẫu
82
+ with torch.no_grad(): # Tắt tính gradient để tiết kiệm bộ nhớ
83
+ gpt_cond_latent, speaker_embedding = MODEL.get_conditioning_latents(
84
+ audio_path=speaker_wav,
85
+ gpt_cond_len=30 if device == "cuda" else 15, # Tăng độ dài khi dùng GPU
86
+ gpt_cond_chunk_len=8 if device == "cuda" else 4,
87
+ max_ref_length=60 if device == "cuda" else 30,
88
+ )
89
+
90
+ # Tạo giọng nói
91
+ out = MODEL.inference(
92
+ text,
93
+ language,
94
+ gpt_cond_latent,
95
+ speaker_embedding,
96
+ repetition_penalty=5.0,
97
+ temperature=0.75,
98
+ enable_text_splitting=True,
99
+ )
 
 
 
100
 
101
+ # Lưu file âm thanh
102
+ output_file = "output.wav"
103
+ torchaudio.save(output_file, torch.tensor(out["wav"]).unsqueeze(0).to("cpu"), 24000)
104
 
105
  return output_file
106
+
107
  except Exception as e:
108
  raise RuntimeError(f"Lỗi khi tạo giọng nói: {str(e)}")