chore: Update TTS dependencies and add MeloTTS support
Browse files- kitt/core/__init__.py +4 -1
- kitt/core/tts.py +29 -9
- kitt/skills/weather.py +1 -1
- main.py +5 -4
kitt/core/__init__.py
CHANGED
@@ -6,7 +6,7 @@ from typing import List
|
|
6 |
|
7 |
import numpy as np
|
8 |
import torch
|
9 |
-
from TTS.api import TTS
|
10 |
|
11 |
os.environ["COQUI_TOS_AGREED"] = "1"
|
12 |
|
@@ -17,6 +17,9 @@ Voice = namedtuple("voice", ["name", "neutral", "angry", "speed"])
|
|
17 |
file_full_path = pathlib.Path(os.path.realpath(__file__)).parent
|
18 |
|
19 |
voices = [
|
|
|
|
|
|
|
20 |
Voice(
|
21 |
"Attenborough",
|
22 |
neutral=f"{file_full_path}/audio/attenborough/neutral.wav",
|
|
|
6 |
|
7 |
import numpy as np
|
8 |
import torch
|
9 |
+
# from TTS.api import TTS
|
10 |
|
11 |
os.environ["COQUI_TOS_AGREED"] = "1"
|
12 |
|
|
|
17 |
file_full_path = pathlib.Path(os.path.realpath(__file__)).parent
|
18 |
|
19 |
voices = [
|
20 |
+
Voice(
|
21 |
+
"Fast", neutral=None, angry=None, speed=1.0,
|
22 |
+
),
|
23 |
Voice(
|
24 |
"Attenborough",
|
25 |
neutral=f"{file_full_path}/audio/attenborough/neutral.wav",
|
kitt/core/tts.py
CHANGED
@@ -3,15 +3,21 @@ from replicate import Client
|
|
3 |
from loguru import logger
|
4 |
from kitt.skills.common import config
|
5 |
import torch
|
6 |
-
|
|
|
7 |
from transformers import AutoTokenizer, set_seed
|
8 |
import soundfile as sf
|
|
|
|
|
9 |
|
10 |
replicate = Client(api_token=config.REPLICATE_API_KEY)
|
11 |
|
12 |
Voice = namedtuple("voice", ["name", "neutral", "angry", "speed"])
|
13 |
|
14 |
voices_replicate = [
|
|
|
|
|
|
|
15 |
Voice(
|
16 |
"Attenborough",
|
17 |
neutral="https://zebel.ams3.digitaloceanspaces.com/xtts/short/attenborough-neutral.wav",
|
@@ -44,6 +50,7 @@ voices_replicate = [
|
|
44 |
),
|
45 |
]
|
46 |
|
|
|
47 |
def voice_from_text(voice, voices):
|
48 |
for v in voices:
|
49 |
if voice == f"{v.name} - Neutral":
|
@@ -64,11 +71,7 @@ def speed_from_text(voice, voices):
|
|
64 |
def run_tts_replicate(text: str, voice_character: str):
|
65 |
voice = voice_from_text(voice_character, voices_replicate)
|
66 |
|
67 |
-
input = {
|
68 |
-
"text": text,
|
69 |
-
"speaker": voice,
|
70 |
-
"cleanup_voice": True
|
71 |
-
}
|
72 |
|
73 |
output = replicate.run(
|
74 |
# "afiaka87/tortoise-tts:e9658de4b325863c4fcdc12d94bb7c9b54cbfe351b7ca1b36860008172b91c71",
|
@@ -82,12 +85,13 @@ def run_tts_replicate(text: str, voice_character: str):
|
|
82 |
def get_fast_tts():
|
83 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
84 |
|
85 |
-
model = ParlerTTSForConditionalGeneration.from_pretrained(
|
|
|
|
|
86 |
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-expresso")
|
87 |
return model, tokenizer, device
|
88 |
|
89 |
|
90 |
-
|
91 |
fast_tts = get_fast_tts()
|
92 |
|
93 |
|
@@ -100,4 +104,20 @@ def run_tts_fast(text: str):
|
|
100 |
|
101 |
generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
|
102 |
audio_arr = generation.cpu().numpy().squeeze()
|
103 |
-
return model.config.sampling_rate, audio_arr, dict(text=text, voice="Thomas")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
from loguru import logger
|
4 |
from kitt.skills.common import config
|
5 |
import torch
|
6 |
+
|
7 |
+
# from parler_tts import ParlerTTSForConditionalGeneration
|
8 |
from transformers import AutoTokenizer, set_seed
|
9 |
import soundfile as sf
|
10 |
+
from melo.api import TTS as MeloTTS
|
11 |
+
|
12 |
|
13 |
replicate = Client(api_token=config.REPLICATE_API_KEY)
|
14 |
|
15 |
Voice = namedtuple("voice", ["name", "neutral", "angry", "speed"])
|
16 |
|
17 |
voices_replicate = [
|
18 |
+
Voice(
|
19 |
+
"Fast", neutral=None, angry=None, speed=1.0,
|
20 |
+
),
|
21 |
Voice(
|
22 |
"Attenborough",
|
23 |
neutral="https://zebel.ams3.digitaloceanspaces.com/xtts/short/attenborough-neutral.wav",
|
|
|
50 |
),
|
51 |
]
|
52 |
|
53 |
+
|
54 |
def voice_from_text(voice, voices):
|
55 |
for v in voices:
|
56 |
if voice == f"{v.name} - Neutral":
|
|
|
71 |
def run_tts_replicate(text: str, voice_character: str):
|
72 |
voice = voice_from_text(voice_character, voices_replicate)
|
73 |
|
74 |
+
input = {"text": text, "speaker": voice, "cleanup_voice": True}
|
|
|
|
|
|
|
|
|
75 |
|
76 |
output = replicate.run(
|
77 |
# "afiaka87/tortoise-tts:e9658de4b325863c4fcdc12d94bb7c9b54cbfe351b7ca1b36860008172b91c71",
|
|
|
85 |
def get_fast_tts():
|
86 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
87 |
|
88 |
+
model = ParlerTTSForConditionalGeneration.from_pretrained(
|
89 |
+
"parler-tts/parler-tts-mini-expresso"
|
90 |
+
).to(device)
|
91 |
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-expresso")
|
92 |
return model, tokenizer, device
|
93 |
|
94 |
|
|
|
95 |
fast_tts = get_fast_tts()
|
96 |
|
97 |
|
|
|
104 |
|
105 |
generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
|
106 |
audio_arr = generation.cpu().numpy().squeeze()
|
107 |
+
return (model.config.sampling_rate, audio_arr), dict(text=text, voice="Thomas")
|
108 |
+
|
109 |
+
|
110 |
+
def load_melo_tts():
|
111 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
112 |
+
model = MeloTTS(language="EN", device=device)
|
113 |
+
return model
|
114 |
+
|
115 |
+
|
116 |
+
melo_tts = load_melo_tts()
|
117 |
+
|
118 |
+
|
119 |
+
def run_melo_tts(text: str, voice: str):
|
120 |
+
speed = 1.0
|
121 |
+
speaker_ids = melo_tts.hps.data.spk2id
|
122 |
+
audio = melo_tts.tts_to_file(text, speaker_ids["EN-Default"], None, speed=speed)
|
123 |
+
return melo_tts.hps.data.sampling_rate, audio
|
kitt/skills/weather.py
CHANGED
@@ -129,7 +129,7 @@ def get_forecast(city_name: str = "", when=0, **kwargs):
|
|
129 |
number_str = f"in {when-1} days"
|
130 |
|
131 |
# Generate a sentence for the day's forecast
|
132 |
-
forecast_sentence = f"On {date} ({number_str}) in {city_name}, the weather will be {conditions} with a high of {max_temp_c}
|
133 |
|
134 |
# number = number + 1
|
135 |
# Add the sentence to the result
|
|
|
129 |
number_str = f"in {when-1} days"
|
130 |
|
131 |
# Generate a sentence for the day's forecast
|
132 |
+
forecast_sentence = f"On {date} ({number_str}) in {city_name}, the weather will be {conditions} with a high of {max_temp_c}C and a low of {min_temp_c}C. There's a {chance_of_rain}% chance of rain. "
|
133 |
|
134 |
# number = number + 1
|
135 |
# Add the sentence to the result
|
main.py
CHANGED
@@ -8,7 +8,7 @@ import typer
|
|
8 |
|
9 |
from kitt.skills.common import config, vehicle
|
10 |
from kitt.skills.routing import calculate_route
|
11 |
-
from kitt.core.tts import run_tts_replicate, run_tts_fast
|
12 |
import ollama
|
13 |
|
14 |
from langchain.tools.base import StructuredTool
|
@@ -196,7 +196,7 @@ def run_nexusraven_model(query, voice_character, state):
|
|
196 |
|
197 |
if type(output_text) == tuple:
|
198 |
output_text = output_text[0]
|
199 |
-
gr.Info(f"Output text: {output_text}
|
200 |
return (
|
201 |
output_text,
|
202 |
tts_gradio(output_text, voice_character, speaker_embedding_cache)[0],
|
@@ -216,11 +216,12 @@ def run_llama3_model(query, voice_character, state):
|
|
216 |
functions=functions,
|
217 |
backend=state["llm_backend"],
|
218 |
)
|
219 |
-
gr.Info(f"Output text: {output_text}
|
220 |
voice_out = None
|
221 |
if state["tts_enabled"]:
|
222 |
# voice_out = run_tts_replicate(output_text, voice_character)
|
223 |
-
voice_out = run_tts_fast(output_text)[0]
|
|
|
224 |
# voice_out = tts_gradio(output_text, voice_character, speaker_embedding_cache)[0]
|
225 |
return (
|
226 |
output_text,
|
|
|
8 |
|
9 |
from kitt.skills.common import config, vehicle
|
10 |
from kitt.skills.routing import calculate_route
|
11 |
+
from kitt.core.tts import run_tts_replicate, run_tts_fast, run_melo_tts
|
12 |
import ollama
|
13 |
|
14 |
from langchain.tools.base import StructuredTool
|
|
|
196 |
|
197 |
if type(output_text) == tuple:
|
198 |
output_text = output_text[0]
|
199 |
+
gr.Info(f"Output text: {output_text}\nGenerating voice output...")
|
200 |
return (
|
201 |
output_text,
|
202 |
tts_gradio(output_text, voice_character, speaker_embedding_cache)[0],
|
|
|
216 |
functions=functions,
|
217 |
backend=state["llm_backend"],
|
218 |
)
|
219 |
+
gr.Info(f"Output text: {output_text}\nGenerating voice output...")
|
220 |
voice_out = None
|
221 |
if state["tts_enabled"]:
|
222 |
# voice_out = run_tts_replicate(output_text, voice_character)
|
223 |
+
# voice_out = run_tts_fast(output_text)[0]
|
224 |
+
voice_out = run_melo_tts(output_text, voice_character)
|
225 |
# voice_out = tts_gradio(output_text, voice_character, speaker_embedding_cache)[0]
|
226 |
return (
|
227 |
output_text,
|