File size: 7,295 Bytes
56fd60f
 
 
 
0cbbac2
 
56fd60f
0cbbac2
56fd60f
 
 
 
 
 
 
 
0cbbac2
e2228b9
 
56fd60f
 
 
0cbbac2
 
 
 
56fd60f
0cbbac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56fd60f
 
 
 
0cbbac2
 
 
 
56fd60f
0cbbac2
 
 
 
 
 
 
 
 
 
 
 
 
56fd60f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0cbbac2
56fd60f
 
 
 
 
 
 
0cbbac2
 
56fd60f
 
 
0cbbac2
56fd60f
0cbbac2
 
56fd60f
 
0cbbac2
 
56fd60f
0cbbac2
 
 
56fd60f
334a9cd
56fd60f
 
 
 
 
 
 
 
9ffe87e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56fd60f
 
9ffe87e
0cbbac2
56fd60f
 
 
 
 
 
0cbbac2
56fd60f
 
0cbbac2
56fd60f
0cbbac2
56fd60f
0cbbac2
56fd60f
0cbbac2
334a9cd
 
 
 
56fd60f
 
 
 
0cbbac2
56fd60f
0cbbac2
56fd60f
 
0cbbac2
56fd60f
 
 
 
 
 
 
 
0cbbac2
56fd60f
0cbbac2
56fd60f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0cbbac2
 
56fd60f
 
e2228b9
334a9cd
56fd60f
 
 
334a9cd
 
 
56fd60f
9ffe87e
56fd60f
0cbbac2
56fd60f
 
 
0cbbac2
56fd60f
0cbbac2
56fd60f
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
import os
import time
import traceback
import torch
import numpy as np
import librosa
from fairseq import checkpoint_utils
from rmvpe import RMVPE
from config import Config
from lib.infer_pack.models import (
    SynthesizerTrnMs256NSFsid,
    SynthesizerTrnMs256NSFsid_nono,
    SynthesizerTrnMs768NSFsid,
    SynthesizerTrnMs768NSFsid_nono,
)
from vc_infer_pipeline import VC
import uuid
import tempfile  # Ensure this is imported
import asyncio  # Ensure this is imported

# Shared runtime configuration; the loaders below read ``config.device`` and
# ``config.is_half`` from it.
config = Config()

# Process-wide model singletons. Both start as None and are populated lazily
# by load_hubert() / load_rmvpe() on first use, then reused by later calls.
hubert_model = None
rmvpe_model = None
# Cache for RVC models: maps a model name to the tuple returned by
# model_data(), i.e. (tgt_sr, net_g, vc, version, index_file, if_f0).
model_cache = {}  # Cache for RVC models

def load_hubert():
    """Return the shared Hubert model, loading it on first use.

    The checkpoint ``hubert_base.pt`` is loaded through fairseq's
    ``checkpoint_utils``, moved to ``config.device``, cast to half or full
    precision according to ``config.is_half`` and put into eval mode. The
    result is stored in the module-level ``hubert_model`` so that later calls
    return the same object without reloading.

    Returns:
        The fully initialised Hubert model.
    """
    global hubert_model
    if hubert_model is not None:
        return hubert_model

    print("Loading Hubert model...")
    models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
        ["hubert_base.pt"],
        suffix="",
    )
    # Configure the model through a local name and publish it to the global
    # only once every step has succeeded. If moving or casting raises, the
    # global stays None and the next call retries instead of handing out a
    # half-configured model.
    model = models[0].to(config.device)
    model = model.half() if config.is_half else model.float()
    model.eval()
    hubert_model = model
    print("Hubert model loaded.")
    return hubert_model

def load_rmvpe():
    """Return the shared RMVPE model, creating it on first use.

    The instance is built from ``rmvpe.pt`` with ``config.is_half`` and
    ``config.device`` and stored in the module-level ``rmvpe_model``; later
    calls return that same instance.
    """
    global rmvpe_model
    # Fast path: already constructed by an earlier call.
    if rmvpe_model is not None:
        return rmvpe_model

    print("Loading RMVPE model...")
    rmvpe_model = RMVPE("rmvpe.pt", config.is_half, config.device)
    print("RMVPE model loaded.")
    return rmvpe_model

def get_unique_filename(extension):
    """Build a random file name of the form ``<uuid4>.<extension>``.

    ``extension`` is inserted as given after a single dot, so it should be
    passed without a leading dot (e.g. ``"mp3"``).
    """
    stem = uuid.uuid4()
    return "{}.{}".format(stem, extension)

def get_model_names():
    """Return the names of the sub-directories found directly in ``weights``.

    Plain files inside ``weights`` are skipped. The order is whatever
    ``os.listdir`` yields.
    """
    model_root = "weights"  # Directory scanned for per-model sub-directories
    names = []
    for entry in os.listdir(model_root):
        if os.path.isdir(f"{model_root}/{entry}"):
            names.append(entry)
    return names

def model_data(model_name):
    """Load one RVC model from ``weights/<model_name>`` or return it from cache.

    The first ``.pth`` file listed in the model directory is loaded, the
    synthesizer class is chosen from the checkpoint's ``version`` and ``f0``
    entries, and the first ``.index`` file (if any) is picked up. The result
    is stored in ``model_cache`` under ``model_name``.

    Args:
        model_name: Name of the sub-directory of ``weights`` holding the model.

    Returns:
        The tuple ``(tgt_sr, net_g, vc, version, index_file, if_f0)``.
        ``index_file`` is ``""`` when the directory has no ``.index`` file.

    Raises:
        FileNotFoundError: If the model directory contains no ``.pth`` file.
        ValueError: If the checkpoint's version is neither "v1" nor "v2".
    """
    global model_cache
    # Cache hit: hand back the previously assembled tuple.
    try:
        return model_cache[model_name]
    except KeyError:
        pass

    model_root = "weights"
    model_dir = f"{model_root}/{model_name}"

    checkpoints = [name for name in os.listdir(model_dir) if name.endswith(".pth")]
    if not checkpoints:
        raise FileNotFoundError(f"No .pth file found for model '{model_name}'")
    pth_path = f"{model_dir}/{checkpoints[0]}"
    print(f"Loading model from {pth_path}")

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    cpt = torch.load(pth_path, map_location="cpu")
    tgt_sr = cpt["config"][-1]
    cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]  # n_spk
    if_f0 = cpt.get("f0", 1)
    version = cpt.get("version", "v1")

    if version not in ("v1", "v2"):
        raise ValueError("Unknown version")

    # Pick the pair of synthesizer classes for this checkpoint version.
    if version == "v1":
        pitch_cls = SynthesizerTrnMs256NSFsid
        no_pitch_cls = SynthesizerTrnMs256NSFsid_nono
    else:
        pitch_cls = SynthesizerTrnMs768NSFsid
        no_pitch_cls = SynthesizerTrnMs768NSFsid_nono

    # Only the pitch-aware (f0 == 1) variants take the is_half argument.
    if if_f0 == 1:
        net_g = pitch_cls(*cpt["config"], is_half=config.is_half)
    else:
        net_g = no_pitch_cls(*cpt["config"])

    # The enc_q attribute is removed before the weights are loaded;
    # strict=False lets load_state_dict tolerate key mismatches.
    del net_g.enc_q
    net_g.load_state_dict(cpt["weight"], strict=False)
    net_g.eval().to(config.device)
    net_g = net_g.half() if config.is_half else net_g.float()
    print(f"Model '{model_name}' loaded.")

    vc = VC(tgt_sr, config)

    indexes = [name for name in os.listdir(model_dir) if name.endswith(".index")]
    if indexes:
        index_file = f"{model_dir}/{indexes[0]}"
        print(f"Index file found: {index_file}")
    else:
        index_file = ""
        print("No index file found.")

    # Remember the assembled bundle so the next request skips loading.
    bundle = (tgt_sr, net_g, vc, version, index_file, if_f0)
    model_cache[model_name] = bundle
    return bundle

def _remove_file_quietly(path):
    """Best-effort deletion of ``path``; a missing file or ``None`` is ignored."""
    if path is None:
        return
    try:
        os.remove(path)
    except OSError:
        # Cleanup must never mask the real outcome of the request.
        pass


async def tts(
    model_name,
    tts_text,
    tts_voice,
    index_rate,
    use_uploaded_voice,
    uploaded_voice,
):
    """Produce source speech and convert it with the selected RVC model.

    The source audio is either the caller-supplied ``uploaded_voice`` bytes
    (when ``use_uploaded_voice`` is truthy) or speech generated by EdgeTTS
    from ``tts_text`` with ``tts_voice``. It is loaded at 16 kHz mono and run
    through the voice-conversion pipeline of the model named ``model_name``.

    Args:
        model_name: Name passed to ``model_data`` to select the RVC model.
        tts_text: Text to synthesize; only used on the EdgeTTS path.
        tts_voice: EdgeTTS voice identifier; only used on the EdgeTTS path.
        index_rate: Forwarded unchanged to ``vc.pipeline``.
        use_uploaded_voice: When truthy, convert ``uploaded_voice`` instead
            of synthesizing speech.
        uploaded_voice: Raw bytes of the uploaded audio; written verbatim to
            a temporary ``.wav`` file that is removed before returning.

    Returns:
        A 3-tuple. On success ``(info, edge_output_filename, (tgt_sr,
        audio_opt))`` where ``info`` is a timing summary string. On any
        failure ``({"error": message}, None, None)``.
    """
    # Files created by this call. The uploaded copy is always deleted; the
    # EdgeTTS output is deleted only when the call does not succeed, because
    # in that case the caller never learns its name.
    uploaded_file_path = None
    edge_output_filename = None
    succeeded = False
    try:
        # Reject an unusable request before paying for model loading.
        if use_uploaded_voice and uploaded_voice is None:
            return {"error": "No voice file uploaded."}, None, None

        # Load models if not already loaded
        hubert = load_hubert()
        rmvpe = load_rmvpe()

        # Default values for parameters used in EdgeTTS
        f0_up_key = 0  # Default pitch adjustment
        f0_method = "rmvpe"  # Default pitch extraction method
        protect = 0.33  # Default protect value
        filter_radius = 3
        resample_sr = 0
        rms_mix_rate = 0.25
        edge_time = 0  # Stays 0 when EdgeTTS is not used

        edge_output_filename = get_unique_filename("mp3")

        if use_uploaded_voice:
            # delete=False because the file is reopened by path below; it is
            # removed explicitly in the ``finally`` clause.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
                # Record the path before writing so a failed write is still
                # cleaned up.
                uploaded_file_path = tmp_file.name
                tmp_file.write(uploaded_voice)

            audio, _ = librosa.load(uploaded_file_path, sr=16000, mono=True)
            input_audio_path = uploaded_file_path
        else:
            # EdgeTTS processing
            import edge_tts
            t0 = time.time()
            speed = 0  # Default speech speed
            speed_str = f"+{speed}%" if speed >= 0 else f"{speed}%"
            communicate = edge_tts.Communicate(
                tts_text, tts_voice, rate=speed_str
            )
            try:
                await asyncio.wait_for(communicate.save(edge_output_filename), timeout=30)
            except asyncio.TimeoutError:
                # Any partially written output is removed in ``finally``.
                return {"error": "EdgeTTS operation timed out"}, None, None
            t1 = time.time()
            edge_time = t1 - t0

            audio, _ = librosa.load(edge_output_filename, sr=16000, mono=True)
            input_audio_path = edge_output_filename

        # Load the specified RVC model
        tgt_sr, net_g, vc, version, index_file, if_f0 = model_data(model_name)

        # Set RMVPE model for pitch extraction
        if f0_method == "rmvpe":
            vc.model_rmvpe = rmvpe

        # Perform voice conversion pipeline
        times = [0, 0, 0]
        audio_opt = vc.pipeline(
            hubert,
            net_g,
            0,  # Speaker ID
            audio,
            input_audio_path,
            times,
            f0_up_key,
            f0_method,
            index_file,
            index_rate,
            if_f0,
            filter_radius,
            tgt_sr,
            resample_sr,
            rms_mix_rate,
            version,
            protect,
            None,
        )

        if tgt_sr != resample_sr and resample_sr >= 16000:
            tgt_sr = resample_sr

        info = f"Success. Time: tts: {edge_time:.2f}s, npy: {times[0]:.2f}s, f0: {times[1]:.2f}s, infer: {times[2]:.2f}s"
        print(info)
        succeeded = True
        # NOTE(review): on the uploaded-voice path no file is ever written to
        # edge_output_filename, so the returned name points at nothing.
        # Kept for interface compatibility; confirm what callers expect.
        return (
            info,
            edge_output_filename,
            (tgt_sr, audio_opt),
        )

    except asyncio.CancelledError:
        # NOTE(review): returning here swallows the cancellation, so the
        # awaiting task is not marked as cancelled. Kept because callers may
        # rely on receiving the error tuple.
        print("TTS operation was cancelled")
        return {"error": "Operation cancelled"}, None, None
    except EOFError:
        info = "Output not valid. This may occur when input text and speaker do not match."
        print(info)
        return {"error": info}, None, None
    except Exception as e:
        traceback_info = traceback.format_exc()
        print(traceback_info)
        return {"error": str(e)}, None, None
    finally:
        _remove_file_quietly(uploaded_file_path)
        if not succeeded:
            _remove_file_quietly(edge_output_filename)

# Voice mapping dictionary: human-readable voice label -> voice identifier.
# NOTE(review): the values look like EdgeTTS voice names, presumably meant to
# be passed as ``tts_voice`` to tts() — confirm against the caller.
voice_mapping = {
    "Mongolian Male": "mn-MN-BataaNeural",
    "Mongolian Female": "mn-MN-YesuiNeural"
}