rrg92 committed
Commit
7c3a67d
1 Parent(s): 051922b
Files changed (5):
  1. Dockerfile +24 -0
  2. app.py +350 -0
  3. docker-compose.yml +40 -0
  4. requirements.txt +17 -0
  5. xtts.py +192 -0
Dockerfile ADDED
@@ -0,0 +1,24 @@
+ FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-devel
+ ARG DEBIAN_FRONTEND=noninteractive
+
+ RUN apt-get update && \
+     apt-get install --no-install-recommends -y sox libsox-fmt-all curl wget gcc git git-lfs build-essential libaio-dev libsndfile1 ssh ffmpeg && \
+     apt-get clean && apt-get -y autoremove
+
+ WORKDIR /app
+ COPY requirements.txt .
+ RUN python -m pip install --use-deprecated=legacy-resolver -r requirements.txt \
+     && python -m pip cache purge
+
+ RUN python -m unidic download
+ RUN mkdir -p /app/tts_models
+
+ COPY xtts.py .
+ COPY app.py .
+
+ # Uncomment and set this to 1 if you have an older GPU
+ #ENV NVIDIA_DISABLE_REQUIRE=1
+
+ ENV NUM_THREADS=2
+ EXPOSE 80
+ CMD ["python","app.py"]
app.py ADDED
@@ -0,0 +1,350 @@
+ import gradio as gr
+ import base64
+ import tempfile
+ import json
+ import os
+ from os.path import abspath
+ import zipfile
+ import random
+ import xtts
+
+
+ DO_CHECK = os.getenv('DO_CHECK', '1')
+ OUTPUT = "./demo_outputs"
+ cloned_speakers = {}
+
+ print("Preparing file structure...")
+ if not os.path.exists(OUTPUT):
+     os.mkdir(OUTPUT)
+     os.mkdir(os.path.join(OUTPUT, "cloned_speakers"))
+     os.mkdir(os.path.join(OUTPUT, "generated_audios"))
+ elif os.path.exists(os.path.join(OUTPUT, "cloned_speakers")):
+     print("Loading existing cloned speakers...")
+     for file in os.listdir(os.path.join(OUTPUT, "cloned_speakers")):
+         if file.endswith(".json"):
+             with open(os.path.join(OUTPUT, "cloned_speakers", file), "r") as fp:
+                 cloned_speakers[file[:-5]] = json.load(fp)
+     print("Available cloned speakers:", ", ".join(cloned_speakers.keys()))
+
+ AUDIOS_DIR = os.path.join("demo_outputs", "generated_audios")
+ ZIP_DIR = "zip_outputs"
+
+ print("Checking zip dir at", ZIP_DIR)
+ if not os.path.exists(ZIP_DIR):
+     os.mkdir(ZIP_DIR)
+
+
+ try:
+     print("Getting metadata from server ...")
+     LANGUAGES = xtts.get_languages()
+     print("Available languages:", ", ".join(LANGUAGES))
+     STUDIO_SPEAKERS = xtts.get_speakers()
+     print("Available studio speakers:", ", ".join(STUDIO_SPEAKERS.keys()))
+ except Exception:
+     raise Exception("Please make sure the server is running first.")
+
+
+ def ExtractVars(input_string):
+     # Split the string into lines
+     lines = input_string.split('\n')
+
+     # Recognized directives and their defaults
+     result_dict = {
+         'prefix': None,
+         'name': '',
+         'speaker': None,
+         'num': None,
+     }
+
+     # Lines that do not start with '!' (the actual text to speak)
+     filtered_lines = []
+
+     for line in lines:
+         # Lines starting with '!' are "key=value" directives
+         if line.strip().startswith('!'):
+             try:
+                 # Split on '=' and strip whitespace from key and value
+                 key, value = line.strip()[1:].split('=')
+                 key = key.strip()
+                 value = value.strip()
+                 result_dict[key] = value
+             except ValueError:
+                 # No '=' or improper format: skip the line
+                 continue
+         elif len(line.strip()) > 0:
+             filtered_lines.append(line)
+
+     # Join the filtered lines back into a single string
+     filtered_string = '\n'.join(filtered_lines)
+     return result_dict, filtered_string
+
+
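For reference, a minimal sketch of the inline directive syntax `ExtractVars` parses: lines starting with `!` set per-part options, and everything else is the text to speak (the sample values below are hypothetical):

```python
sample = """!prefix=chapter1
!speaker=Ana Florence
Hello there, this is the first part."""

opts, clean_text = ExtractVars(sample)
print(opts["prefix"])   # chapter1
print(opts["speaker"])  # Ana Florence
print(clean_text)       # Hello there, this is the first part.
```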
+ def FindSpeakerByName(name, speakerType):
+     srcItems = STUDIO_SPEAKERS if speakerType == "Studio" else cloned_speakers
+
+     for key, value in srcItems.items():
+         # Exact match first, then match on the first word of the name
+         if key == name:
+             return key, value
+         if key.split(" ")[0] == name:
+             return key, value
+
+     # No match found
+     return None, None
+
+
+ def clone_speaker(upload_file, clone_speaker_name, cloned_speaker_names):
+     embeddings = xtts.predict_speaker(open(upload_file, "rb"))
+     with open(os.path.join(OUTPUT, "cloned_speakers", clone_speaker_name + ".json"), "w") as fp:
+         json.dump(embeddings, fp)
+     cloned_speakers[clone_speaker_name] = embeddings
+     cloned_speaker_names.append(clone_speaker_name)
+     return upload_file, clone_speaker_name, cloned_speaker_names, gr.Dropdown(choices=cloned_speaker_names)
+
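For reference, a sketch of what `clone_speaker` leaves on disk (file and speaker names hypothetical); the JSON has the same structure as a studio-speaker entry:

```python
with open("demo_outputs/cloned_speakers/my_voice.json") as fp:
    emb = json.load(fp)
print(list(emb.keys()))  # ['gpt_cond_latent', 'speaker_embedding']
```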
+ def tts(text, speaker_type, speaker_name_studio, speaker_name_custom, lang,
+         temperature, speed, top_p, top_k, AllFileList, progress=gr.Progress()):
+     embeddings = STUDIO_SPEAKERS[speaker_name_studio] if speaker_type == 'Studio' else cloned_speakers[speaker_name_custom]
+
+     # Split the text into parts at "---" separators
+     lines = text.split("---")
+     totalLines = len(lines)
+     print("Total parts:", totalLines)
+
+     audioNum = 0
+
+     DefaultPrefix = next(tempfile._get_candidate_names())
+     CurrentPrefix = DefaultPrefix
+
+     AudioList = []
+     for line in progress.tqdm(lines, desc="Generating speech..."):
+         audioNum += 1
+
+         textVars, cleanLine = ExtractVars(line)
+
+         if textVars['prefix']:
+             CurrentPrefix = textVars['prefix']
+
+         audioName = textVars['name']
+         if audioName:
+             audioName = '_' + audioName
+
+         num = textVars['num']
+         if not num:
+             num = audioNum
+
+         path = CurrentPrefix + "_n_" + str(num) + audioName + ".wav"
+
+         print("Generating audio for line", num, 'sequence', audioNum)
+
+         speaker = textVars['speaker']
+         if not speaker:
+             speaker = speaker_name_studio if speaker_type == 'Studio' else speaker_name_custom
+
+         speakerName, embeddings = FindSpeakerByName(speaker, speaker_type)
+         if not speakerName:
+             raise ValueError("InvalidSpeaker: " + speaker)
+
+         ipts = xtts.TTSInputs(
+             speaker_embedding=embeddings["speaker_embedding"],
+             gpt_cond_latent=embeddings["gpt_cond_latent"],
+             text=cleanLine,
+             language=lang,
+             temperature=temperature,
+             speed=speed,
+             top_k=top_k,
+             top_p=top_p
+         )
+
+         generated_audio = xtts.predict_speech(ipts)
+
+         print("Audio generated. Saving to", path)
+         generated_audio_path = os.path.join(AUDIOS_DIR, path)
+         with open(generated_audio_path, "wb") as fp:
+             fp.write(base64.b64decode(generated_audio))
+         AudioList.append(fp.name)
+
+     AllFileList.clear()
+     AllFileList.extend(AudioList)
+
+     return gr.Dropdown(
+         label="Generated Audios",
+         choices=list(AudioList),
+         value=AudioList[0]
+     )
+
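A hypothetical value for the text box, showing the `---` part separator and per-part directives (the speaker name is illustrative; any available studio or cloned speaker works):

```python
text = (
    "!prefix=intro\n"
    "Welcome to the show.\n"
    "---\n"
    "!speaker=Claribel Dervla\n"
    "!num=5\n"
    "And now, the second part."
)
# tts() would write intro_n_1.wav and intro_n_5.wav
# under demo_outputs/generated_audios.
```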
+ def get_file_content(f):
+     if len(f) > 0:
+         return f[0]
+     return None
+
+
+ def UpdateFileList(DirListState):
+     DirListState.clear()
+     DirListState.extend(os.listdir(AUDIOS_DIR))
+
+ def audio_list_update(d):
+     fullPath = abspath(d)
+     return fullPath
+
+ def ZipAndDownload(files):
+     allFiles = files
+
+     DefaultPrefix = next(tempfile._get_candidate_names())
+     zipFile = abspath(os.path.join(ZIP_DIR, DefaultPrefix + ".zip"))
+
+     with zipfile.ZipFile(zipFile, 'w') as zipMe:
+         for file in allFiles:
+             print("Zipping", file)
+             zipMe.write(abspath(file), os.path.basename(file), compress_type=zipfile.ZIP_DEFLATED)
+
+     print("Done", zipFile)
+
+     return '<a href="/file=' + zipFile + '">If the download does not start, click here</a>'
+
+
+ js = """
+ function DetectDownloadLink(){
+     console.log('Configuring auto-download observer...');
+     let hiddenLink = document.getElementById("DownloadLink");
+     let onChange = function(mutations){
+         for (const mutation of mutations) {
+             if (mutation.type !== 'childList')
+                 continue;
+
+             // Navigate to any anchor added under the hidden link,
+             // which triggers the download automatically.
+             for (const addedNode of mutation.addedNodes) {
+                 if (addedNode.nodeName === 'A') {
+                     location.href = addedNode.href;
+                 }
+             }
+         }
+     }
+
+     let config = { attributes: true, childList: true, subtree: true, attributeFilter: ["href"] }
+     let obs = new MutationObserver(onChange);
+     obs.observe(hiddenLink, config);
+ }
+ """
+
+ with gr.Blocks(js=js) as demo:
+     defaultSpeaker = "Dionisio Schuyler"
+     cloned_speaker_names = gr.State(list(cloned_speakers.keys()))
+     AllFileList = gr.State(list([]))
+
+
+     with gr.Tab("TTS"):
+         with gr.Column() as row4:
+             with gr.Row() as col4:
+                 speaker_name_studio = gr.Dropdown(
+                     label="Studio speaker",
+                     choices=STUDIO_SPEAKERS.keys(),
+                     value=defaultSpeaker if defaultSpeaker in STUDIO_SPEAKERS.keys() else None,
+                 )
+                 speaker_name_custom = gr.Dropdown(
+                     label="Cloned speaker",
+                     choices=cloned_speaker_names.value,
+                     value=cloned_speaker_names.value[0] if len(cloned_speaker_names.value) != 0 else None,
+                 )
+                 speaker_type = gr.Dropdown(label="Speaker type", choices=["Studio", "Cloned"], value="Studio")
+         with gr.Column() as rowAdvanced:
+             with gr.Row() as rowAdvanced:
+                 temperature = gr.Slider(0.00, 1.00, 0.5, step=0.05, label="Temperature", info="Choose between 0 and 1")
+                 top_p = gr.Slider(0.00, 1.00, 0.8, step=0.05, label="Top P", info="Choose between 0 and 1")
+                 top_k = gr.Number(label="Top K", value=50)
+                 speed = gr.Slider(0.00, 1000.00, 1.0, step=0.1, label="Speed", info="Speed (0 to 1000)")
+         with gr.Column() as col2:
+             lang = gr.Dropdown(label="Language", choices=LANGUAGES, value="pt")
+             text = gr.Textbox(label="text", lines=4, value="A quick brown fox jumps over the lazy dog.")
+             tts_button = gr.Button(value="TTS")
+         with gr.Column() as col3:
+             # FileList = gr.FileExplorer(
+             #     glob="*.wav",
+             #     # value=["themes/utils"],
+             #     ignore_glob="**/__init__.py",
+             #     root_dir=AUDIOS_DIR,
+             #     interactive = True,
+             #     value=DirectoryList.value
+             # )
+
+             AudioList = gr.Dropdown(
+                 label="Generated Audios",
+                 choices=[],  # placeholder; filled in after generation
+                 interactive=True
+             )
+
+             generated_audio = gr.Audio(label="Audio Play", autoplay=True)
+             AudioList.change(fn=audio_list_update, inputs=[AudioList], outputs=[generated_audio])
+
+             dummyHtml = gr.HTML(elem_id="DownloadLink", render=False)
+             downloadAll = gr.DownloadButton("Download All Files")
+             downloadAll.click(ZipAndDownload, inputs=[AllFileList], outputs=[dummyHtml])
+             dummyHtml.render()
+
+
+     with gr.Tab("Clone a new speaker"):
+         with gr.Column() as col1:
+             upload_file = gr.Audio(label="Upload reference audio", type="filepath")
+             clone_speaker_name = gr.Textbox(label="Speaker name", value="default_speaker")
+             clone_button = gr.Button(value="Clone speaker")
+
+     clone_button.click(
+         fn=clone_speaker,
+         inputs=[upload_file, clone_speaker_name, cloned_speaker_names],
+         outputs=[upload_file, clone_speaker_name, cloned_speaker_names, speaker_name_custom],
+     )
+
+     tts_button.click(
+         fn=tts,
+         inputs=[text, speaker_type, speaker_name_studio, speaker_name_custom, lang,
+                 temperature, speed, top_p, top_k, AllFileList],
+         outputs=[AudioList],
+     )
+
+ if __name__ == "__main__" and DO_CHECK == "1":
+     print("Warming up server... Checking server health...")
+
+     speakerName, embs = random.choice(list(STUDIO_SPEAKERS.items()))
+
+     print("Testing with", speakerName)
+
+     ipts = xtts.TTSInputs(
+         speaker_embedding=embs["speaker_embedding"],
+         gpt_cond_latent=embs["gpt_cond_latent"],
+         text="This is a warmup request.",
+         language="en",
+         temperature=0.5,
+         speed=1.0,
+         top_k=50,
+         top_p=0.8
+     )
+
+     resp = xtts.predict_speech(ipts)
+
+     print("TEST OK")
+
+
+ if __name__ == "__main__":
+     print("STARTING...")
+     demo.launch(
+         share=False,
+         debug=False,
+         server_port=80,
+         server_name="0.0.0.0",
+         allowed_paths=[ZIP_DIR]
+     )
docker-compose.yml ADDED
@@ -0,0 +1,40 @@
+ name: webui-docker
+
+ volumes:
+   server-model-root:
+
+ services:
+
+   xtts:
+     build:
+       context: .
+       dockerfile: Dockerfile
+     environment:
+       COQUI_TOS_AGREED: 1
+       CUSTOM_MODEL_PATH: /root/.local/share/tts/tts_models--multilingual--multi-dataset--xtts_v2
+     ports:
+       - 3000:80
+     expose:
+       - 80
+     volumes:
+       - type: volume
+         source: server-model-root
+         target: /root/.local/share/tts/tts_models--multilingual--multi-dataset--xtts_v2
+     stdin_open: true # docker run -i
+     tty: true        # docker run -t
+     deploy:
+       resources:
+         reservations:
+           devices:
+             - driver: nvidia
+               count: all
+               capabilities: [gpu]
+     healthcheck:
+       test: wget --no-verbose --tries=1 http://localhost || exit 1
+       interval: 5s
+       timeout: 30s
+       retries: 3
+       start_period: 5m
+
requirements.txt ADDED
@@ -0,0 +1,17 @@
+ torch
+ torchvision
+ torchaudio
+ gradio
+ numpy
+ TTS @ git+https://github.com/coqui-ai/TTS@fa28f99f1508b5b5366539b2149963edcb80ba62
+ uvicorn[standard]==0.23.2
+ deepspeed
+ pydantic
+ python-multipart==0.0.6
+ typing-extensions>=4.8.0
+ cutlet
+ mecab-python3==1.0.6
+ unidic-lite==1.0.8
+ unidic==1.1.0
+
+
xtts.py ADDED
@@ -0,0 +1,192 @@
+ import base64
+ import io
+ import os
+ import tempfile
+ import wave
+ import torch
+ import numpy as np
+ from typing import List
+ from pydantic import BaseModel
+
+ from TTS.tts.configs.xtts_config import XttsConfig
+ from TTS.tts.models.xtts import Xtts
+ from TTS.utils.generic_utils import get_user_data_dir
+ from TTS.utils.manage import ModelManager
+
+ torch.set_num_threads(int(os.environ.get("NUM_THREADS", os.cpu_count())))
+ device = torch.device("cuda" if os.environ.get("USE_CPU", "0") == "0" else "cpu")
+ if not torch.cuda.is_available() and device.type == "cuda":
+     raise RuntimeError("CUDA device unavailable, please use Dockerfile.cpu instead.")
+
+ custom_model_path = os.environ.get("CUSTOM_MODEL_PATH", "/app/tts_models")
+
+ if os.path.exists(custom_model_path) and os.path.isfile(custom_model_path + "/config.json"):
+     model_path = custom_model_path
+     print("Loading custom model from", model_path, flush=True)
+ else:
+     print("Loading default model", flush=True)
+     model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
+     print("Downloading XTTS Model:", model_name, flush=True)
+     ModelManager().download_model(model_name)
+     model_path = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
+     print("XTTS Model downloaded", flush=True)
+
+ print("Loading XTTS", flush=True)
+ config = XttsConfig()
+ config.load_json(os.path.join(model_path, "config.json"))
+ model = Xtts.init_from_config(config)
+ model.load_checkpoint(config, checkpoint_dir=model_path, eval=True, use_deepspeed=device.type == "cuda")
+ model.to(device)
+ print("XTTS Loaded.", flush=True)
+
+ print("Running XTTS Server ...", flush=True)
+
+
+
+ # @app.post("/clone_speaker")
+ def predict_speaker(wav_file):
+     """Compute conditioning inputs from a reference audio file."""
+     temp_audio_name = next(tempfile._get_candidate_names())
+     try:
+         # Write the upload to disk and close it before reading it back,
+         # so the model sees the fully flushed file.
+         with open(temp_audio_name, "wb") as temp:
+             temp.write(wav_file.read())
+         with torch.inference_mode():
+             gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
+                 temp_audio_name
+             )
+     finally:
+         # Remove the temporary reference file
+         if os.path.exists(temp_audio_name):
+             os.remove(temp_audio_name)
+     return {
+         "gpt_cond_latent": gpt_cond_latent.cpu().squeeze().half().tolist(),
+         "speaker_embedding": speaker_embedding.cpu().squeeze().half().tolist(),
+     }
+
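A minimal sketch of calling the cloning helper directly (the file path is hypothetical):

```python
with open("reference.wav", "rb") as f:
    embeddings = predict_speaker(f)
print(type(embeddings["speaker_embedding"]))  # <class 'list'>
```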
+
+ def postprocess(wav):
+     """Post-process the output waveform into 16-bit PCM."""
+     if isinstance(wav, list):
+         wav = torch.cat(wav, dim=0)
+     wav = wav.clone().detach().cpu().numpy()
+     wav = wav[None, : int(wav.shape[0])]
+     wav = np.clip(wav, -1, 1)
+     wav = (wav * 32767).astype(np.int16)
+     return wav
+
+
+ def encode_audio_common(
+     frame_input, encode_base64=True, sample_rate=24000, sample_width=2, channels=1
+ ):
+     """Return audio as a WAV, base64 encoded by default."""
+     wav_buf = io.BytesIO()
+     with wave.open(wav_buf, "wb") as vfout:
+         vfout.setnchannels(channels)
+         vfout.setsampwidth(sample_width)
+         vfout.setframerate(sample_rate)
+         vfout.writeframes(frame_input)
+
+     wav_buf.seek(0)
+     if encode_base64:
+         b64_encoded = base64.b64encode(wav_buf.getbuffer()).decode("utf-8")
+         return b64_encoded
+     else:
+         return wav_buf.read()
+
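A quick round-trip sketch (the silence buffer and output path are illustrative): clients decode the base64 string back into a playable 24 kHz mono 16-bit WAV:

```python
b64_wav = encode_audio_common(b"\x00\x00" * 24000)  # one second of silence
with open("out.wav", "wb") as f:
    f.write(base64.b64decode(b64_wav))
```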
+
+ class StreamingInputs(BaseModel):
+     speaker_embedding: List[float]
+     gpt_cond_latent: List[List[float]]
+     text: str
+     language: str
+     add_wav_header: bool = True
+     stream_chunk_size: str = "20"
+
+ # Streaming endpoint kept from the original FastAPI server, currently disabled:
+ #def predict_streaming_generator(parsed_input: dict = Body(...)):
+ #    speaker_embedding = torch.tensor(parsed_input.speaker_embedding).unsqueeze(0).unsqueeze(-1)
+ #    gpt_cond_latent = torch.tensor(parsed_input.gpt_cond_latent).reshape((-1, 1024)).unsqueeze(0)
+ #    text = parsed_input.text
+ #    language = parsed_input.language
+ #
+ #    stream_chunk_size = int(parsed_input.stream_chunk_size)
+ #    add_wav_header = parsed_input.add_wav_header
+ #
+ #    chunks = model.inference_stream(
+ #        text,
+ #        language,
+ #        gpt_cond_latent,
+ #        speaker_embedding,
+ #        stream_chunk_size=stream_chunk_size,
+ #        enable_text_splitting=True
+ #    )
+ #
+ #    for i, chunk in enumerate(chunks):
+ #        chunk = postprocess(chunk)
+ #        if i == 0 and add_wav_header:
+ #            yield encode_audio_common(b"", encode_base64=False)
+ #            yield chunk.tobytes()
+ #        else:
+ #            yield chunk.tobytes()
+ #
+ ## @app.post("/tts_stream")
+ #def predict_streaming_endpoint(parsed_input: StreamingInputs):
+ #    return StreamingResponse(
+ #        predict_streaming_generator(parsed_input),
+ #        media_type="audio/wav",
+ #    )
+
+ class TTSInputs(BaseModel):
+     speaker_embedding: List[float]
+     gpt_cond_latent: List[List[float]]
+     text: str
+     language: str
+     temperature: float
+     speed: float
+     top_k: int
+     top_p: float
+
+ # @app.post("/tts")
+ def predict_speech(parsed_input: TTSInputs):
+     speaker_embedding = torch.tensor(parsed_input.speaker_embedding).unsqueeze(0).unsqueeze(-1)
+     gpt_cond_latent = torch.tensor(parsed_input.gpt_cond_latent).reshape((-1, 1024)).unsqueeze(0)
+     text = parsed_input.text
+     language = parsed_input.language
+     temperature = parsed_input.temperature
+     speed = parsed_input.speed
+     top_k = parsed_input.top_k
+     top_p = parsed_input.top_p
+     length_penalty = 1.0
+     repetition_penalty = 2.0
+
+     # Pass sampling options by keyword so they cannot be bound to the
+     # wrong positional parameter of Xtts.inference.
+     out = model.inference(
+         text,
+         language,
+         gpt_cond_latent,
+         speaker_embedding,
+         temperature=temperature,
+         length_penalty=length_penalty,
+         repetition_penalty=repetition_penalty,
+         top_k=top_k,
+         top_p=top_p,
+         speed=speed,
+     )
+
+     wav = postprocess(torch.tensor(out["wav"]))
+
+     return encode_audio_common(wav.tobytes())
+
+
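For orientation, a sketch of the tensor shapes this reconstruction produces, assuming the XTTS v2 defaults of 512-dim speaker embeddings and 1024-dim GPT conditioning latents (adjust if your checkpoint differs):

```python
import torch

speaker_embedding = torch.tensor([0.0] * 512).unsqueeze(0).unsqueeze(-1)
print(speaker_embedding.shape)  # torch.Size([1, 512, 1])

gpt_cond_latent = torch.tensor([[0.0] * 1024] * 32).reshape((-1, 1024)).unsqueeze(0)
print(gpt_cond_latent.shape)    # torch.Size([1, 32, 1024])
```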
+ # @app.get("/studio_speakers")
+ def get_speakers():
+     if hasattr(model, "speaker_manager") and hasattr(model.speaker_manager, "speakers"):
+         return {
+             speaker: {
+                 "speaker_embedding": model.speaker_manager.speakers[speaker]["speaker_embedding"].cpu().squeeze().half().tolist(),
+                 "gpt_cond_latent": model.speaker_manager.speakers[speaker]["gpt_cond_latent"].cpu().squeeze().half().tolist(),
+             }
+             for speaker in model.speaker_manager.speakers.keys()
+         }
+     else:
+         return {}
+
+ # @app.get("/languages")
+ def get_languages():
+     return config.languages