sasan commited on
Commit
b3d61a3
·
1 Parent(s): 78e760c

chore: Update TTS dependencies and add MeloTTS support

Browse files
Files changed (4) hide show
  1. kitt/core/__init__.py +4 -1
  2. kitt/core/tts.py +29 -9
  3. kitt/skills/weather.py +1 -1
  4. main.py +5 -4
kitt/core/__init__.py CHANGED
@@ -6,7 +6,7 @@ from typing import List
6
 
7
  import numpy as np
8
  import torch
9
- from TTS.api import TTS
10
 
11
  os.environ["COQUI_TOS_AGREED"] = "1"
12
 
@@ -17,6 +17,9 @@ Voice = namedtuple("voice", ["name", "neutral", "angry", "speed"])
17
  file_full_path = pathlib.Path(os.path.realpath(__file__)).parent
18
 
19
  voices = [
 
 
 
20
  Voice(
21
  "Attenborough",
22
  neutral=f"{file_full_path}/audio/attenborough/neutral.wav",
 
6
 
7
  import numpy as np
8
  import torch
9
+ # from TTS.api import TTS
10
 
11
  os.environ["COQUI_TOS_AGREED"] = "1"
12
 
 
17
  file_full_path = pathlib.Path(os.path.realpath(__file__)).parent
18
 
19
  voices = [
20
+ Voice(
21
+ "Fast", neutral=None, angry=None, speed=1.0,
22
+ ),
23
  Voice(
24
  "Attenborough",
25
  neutral=f"{file_full_path}/audio/attenborough/neutral.wav",
kitt/core/tts.py CHANGED
@@ -3,15 +3,21 @@ from replicate import Client
3
  from loguru import logger
4
  from kitt.skills.common import config
5
  import torch
6
- from parler_tts import ParlerTTSForConditionalGeneration
 
7
  from transformers import AutoTokenizer, set_seed
8
  import soundfile as sf
 
 
9
 
10
  replicate = Client(api_token=config.REPLICATE_API_KEY)
11
 
12
  Voice = namedtuple("voice", ["name", "neutral", "angry", "speed"])
13
 
14
  voices_replicate = [
 
 
 
15
  Voice(
16
  "Attenborough",
17
  neutral="https://zebel.ams3.digitaloceanspaces.com/xtts/short/attenborough-neutral.wav",
@@ -44,6 +50,7 @@ voices_replicate = [
44
  ),
45
  ]
46
 
 
47
  def voice_from_text(voice, voices):
48
  for v in voices:
49
  if voice == f"{v.name} - Neutral":
@@ -64,11 +71,7 @@ def speed_from_text(voice, voices):
64
  def run_tts_replicate(text: str, voice_character: str):
65
  voice = voice_from_text(voice_character, voices_replicate)
66
 
67
- input = {
68
- "text": text,
69
- "speaker": voice,
70
- "cleanup_voice": True
71
- }
72
 
73
  output = replicate.run(
74
  # "afiaka87/tortoise-tts:e9658de4b325863c4fcdc12d94bb7c9b54cbfe351b7ca1b36860008172b91c71",
@@ -82,12 +85,13 @@ def run_tts_replicate(text: str, voice_character: str):
82
  def get_fast_tts():
83
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
84
 
85
- model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-expresso").to(device)
 
 
86
  tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-expresso")
87
  return model, tokenizer, device
88
 
89
 
90
-
91
  fast_tts = get_fast_tts()
92
 
93
 
@@ -100,4 +104,20 @@ def run_tts_fast(text: str):
100
 
101
  generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
102
  audio_arr = generation.cpu().numpy().squeeze()
103
- return model.config.sampling_rate, audio_arr, dict(text=text, voice="Thomas")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from loguru import logger
4
  from kitt.skills.common import config
5
  import torch
6
+
7
+ # from parler_tts import ParlerTTSForConditionalGeneration
8
  from transformers import AutoTokenizer, set_seed
9
  import soundfile as sf
10
+ from melo.api import TTS as MeloTTS
11
+
12
 
13
  replicate = Client(api_token=config.REPLICATE_API_KEY)
14
 
15
  Voice = namedtuple("voice", ["name", "neutral", "angry", "speed"])
16
 
17
  voices_replicate = [
18
+ Voice(
19
+ "Fast", neutral=None, angry=None, speed=1.0,
20
+ ),
21
  Voice(
22
  "Attenborough",
23
  neutral="https://zebel.ams3.digitaloceanspaces.com/xtts/short/attenborough-neutral.wav",
 
50
  ),
51
  ]
52
 
53
+
54
  def voice_from_text(voice, voices):
55
  for v in voices:
56
  if voice == f"{v.name} - Neutral":
 
71
  def run_tts_replicate(text: str, voice_character: str):
72
  voice = voice_from_text(voice_character, voices_replicate)
73
 
74
+ input = {"text": text, "speaker": voice, "cleanup_voice": True}
 
 
 
 
75
 
76
  output = replicate.run(
77
  # "afiaka87/tortoise-tts:e9658de4b325863c4fcdc12d94bb7c9b54cbfe351b7ca1b36860008172b91c71",
 
85
  def get_fast_tts():
86
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
87
 
88
+ model = ParlerTTSForConditionalGeneration.from_pretrained(
89
+ "parler-tts/parler-tts-mini-expresso"
90
+ ).to(device)
91
  tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-expresso")
92
  return model, tokenizer, device
93
 
94
 
 
95
  fast_tts = get_fast_tts()
96
 
97
 
 
104
 
105
  generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
106
  audio_arr = generation.cpu().numpy().squeeze()
107
+ return (model.config.sampling_rate, audio_arr), dict(text=text, voice="Thomas")
108
+
109
+
110
+ def load_melo_tts():
111
+ device = "cuda" if torch.cuda.is_available() else "cpu"
112
+ model = MeloTTS(language="EN", device=device)
113
+ return model
114
+
115
+
116
+ melo_tts = load_melo_tts()
117
+
118
+
119
+ def run_melo_tts(text: str, voice: str):
120
+ speed = 1.0
121
+ speaker_ids = melo_tts.hps.data.spk2id
122
+ audio = melo_tts.tts_to_file(text, speaker_ids["EN-Default"], None, speed=speed)
123
+ return melo_tts.hps.data.sampling_rate, audio
kitt/skills/weather.py CHANGED
@@ -129,7 +129,7 @@ def get_forecast(city_name: str = "", when=0, **kwargs):
129
  number_str = f"in {when-1} days"
130
 
131
  # Generate a sentence for the day's forecast
132
- forecast_sentence = f"On {date} ({number_str}) in {city_name}, the weather will be {conditions} with a high of {max_temp_c}°C and a low of {min_temp_c}°C. There's a {chance_of_rain}% chance of rain. "
133
 
134
  # number = number + 1
135
  # Add the sentence to the result
 
129
  number_str = f"in {when-1} days"
130
 
131
  # Generate a sentence for the day's forecast
132
+ forecast_sentence = f"On {date} ({number_str}) in {city_name}, the weather will be {conditions} with a high of {max_temp_c}C and a low of {min_temp_c}C. There's a {chance_of_rain}% chance of rain. "
133
 
134
  # number = number + 1
135
  # Add the sentence to the result
main.py CHANGED
@@ -8,7 +8,7 @@ import typer
8
 
9
  from kitt.skills.common import config, vehicle
10
  from kitt.skills.routing import calculate_route
11
- from kitt.core.tts import run_tts_replicate, run_tts_fast
12
  import ollama
13
 
14
  from langchain.tools.base import StructuredTool
@@ -196,7 +196,7 @@ def run_nexusraven_model(query, voice_character, state):
196
 
197
  if type(output_text) == tuple:
198
  output_text = output_text[0]
199
- gr.Info(f"Output text: {output_text}, generating voice output...")
200
  return (
201
  output_text,
202
  tts_gradio(output_text, voice_character, speaker_embedding_cache)[0],
@@ -216,11 +216,12 @@ def run_llama3_model(query, voice_character, state):
216
  functions=functions,
217
  backend=state["llm_backend"],
218
  )
219
- gr.Info(f"Output text: {output_text}, generating voice output...")
220
  voice_out = None
221
  if state["tts_enabled"]:
222
  # voice_out = run_tts_replicate(output_text, voice_character)
223
- voice_out = run_tts_fast(output_text)[0]
 
224
  # voice_out = tts_gradio(output_text, voice_character, speaker_embedding_cache)[0]
225
  return (
226
  output_text,
 
8
 
9
  from kitt.skills.common import config, vehicle
10
  from kitt.skills.routing import calculate_route
11
+ from kitt.core.tts import run_tts_replicate, run_tts_fast, run_melo_tts
12
  import ollama
13
 
14
  from langchain.tools.base import StructuredTool
 
196
 
197
  if type(output_text) == tuple:
198
  output_text = output_text[0]
199
+ gr.Info(f"Output text: {output_text}\nGenerating voice output...")
200
  return (
201
  output_text,
202
  tts_gradio(output_text, voice_character, speaker_embedding_cache)[0],
 
216
  functions=functions,
217
  backend=state["llm_backend"],
218
  )
219
+ gr.Info(f"Output text: {output_text}\nGenerating voice output...")
220
  voice_out = None
221
  if state["tts_enabled"]:
222
  # voice_out = run_tts_replicate(output_text, voice_character)
223
+ # voice_out = run_tts_fast(output_text)[0]
224
+ voice_out = run_melo_tts(output_text, voice_character)
225
  # voice_out = tts_gradio(output_text, voice_character, speaker_embedding_cache)[0]
226
  return (
227
  output_text,