chore: Refactor TTS functionality and dependencies
Browse files
- kitt/core/tts.py +0 -28
- main.py +1 -4
kitt/core/tts.py
CHANGED
@@ -1,12 +1,9 @@
|
|
1 |
import copy
|
2 |
from collections import namedtuple
|
3 |
|
4 |
-
import soundfile as sf
|
5 |
import torch
|
6 |
from loguru import logger
|
7 |
-
from parler_tts import ParlerTTSForConditionalGeneration
|
8 |
from replicate import Client
|
9 |
-
from transformers import AutoTokenizer
|
10 |
|
11 |
from kitt.skills.common import config
|
12 |
|
@@ -94,31 +91,6 @@ def run_tts_replicate(text: str, voice_character: str):
|
|
94 |
return output
|
95 |
|
96 |
|
97 |
-
def get_fast_tts():
|
98 |
-
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
99 |
-
|
100 |
-
model = ParlerTTSForConditionalGeneration.from_pretrained(
|
101 |
-
"parler-tts/parler-tts-mini-expresso"
|
102 |
-
).to(device)
|
103 |
-
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-expresso")
|
104 |
-
return model, tokenizer, device
|
105 |
-
|
106 |
-
|
107 |
-
fast_tts = get_fast_tts()
|
108 |
-
|
109 |
-
|
110 |
-
def run_tts_fast(text: str):
|
111 |
-
model, tokenizer, device = fast_tts
|
112 |
-
description = "Thomas speaks moderately slowly in a sad tone with emphasis and high quality audio."
|
113 |
-
|
114 |
-
input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
|
115 |
-
prompt_input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
|
116 |
-
|
117 |
-
generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
|
118 |
-
audio_arr = generation.cpu().numpy().squeeze()
|
119 |
-
return (model.config.sampling_rate, audio_arr), dict(text=text, voice="Thomas")
|
120 |
-
|
121 |
-
|
122 |
def load_melo_tts():
|
123 |
from melo.api import TTS as MeloTTS
|
124 |
|
|
|
1 |
import copy
|
2 |
from collections import namedtuple
|
3 |
|
|
|
4 |
import torch
|
5 |
from loguru import logger
|
|
|
6 |
from replicate import Client
|
|
|
7 |
|
8 |
from kitt.skills.common import config
|
9 |
|
|
|
91 |
return output
|
92 |
|
93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
def load_melo_tts():
|
95 |
from melo.api import TTS as MeloTTS
|
96 |
|
main.py
CHANGED
@@ -9,7 +9,7 @@ from kitt.core import utils as kitt_utils
|
|
9 |
from kitt.core import voice_options
|
10 |
from kitt.core.model import generate_function_call as process_query
|
11 |
from kitt.core.stt import save_and_transcribe_audio
|
12 |
-
from kitt.core.tts import prep_for_tts, run_melo_tts, run_tts_fast, run_tts_replicate
|
13 |
from kitt.skills import (
|
14 |
code_interpreter,
|
15 |
date_time_info,
|
@@ -118,9 +118,6 @@ def run_llama3_model(query, voice_character, state):
|
|
118 |
voice_out = tts_gradio(
|
119 |
output_text_tts, voice_character, speaker_embedding_cache
|
120 |
)[0]
|
121 |
-
#
|
122 |
-
# voice_out = run_tts_fast(output_text)[0]
|
123 |
-
#
|
124 |
return (
|
125 |
output_text,
|
126 |
voice_out,
|
|
|
9 |
from kitt.core import voice_options
|
10 |
from kitt.core.model import generate_function_call as process_query
|
11 |
from kitt.core.stt import save_and_transcribe_audio
|
12 |
+
from kitt.core.tts import prep_for_tts, run_melo_tts, run_tts_replicate
|
13 |
from kitt.skills import (
|
14 |
code_interpreter,
|
15 |
date_time_info,
|
|
|
118 |
voice_out = tts_gradio(
|
119 |
output_text_tts, voice_character, speaker_embedding_cache
|
120 |
)[0]
|
|
|
|
|
|
|
121 |
return (
|
122 |
output_text,
|
123 |
voice_out,
|