sasan committed on
Commit
68ce4a1
·
1 Parent(s): b945617

chore: Refactor TTS functionality and dependencies

Browse files
Files changed (2) hide show
  1. kitt/core/tts.py +0 -28
  2. main.py +1 -4
kitt/core/tts.py CHANGED
@@ -1,12 +1,9 @@
1
  import copy
2
  from collections import namedtuple
3
 
4
- import soundfile as sf
5
  import torch
6
  from loguru import logger
7
- from parler_tts import ParlerTTSForConditionalGeneration
8
  from replicate import Client
9
- from transformers import AutoTokenizer
10
 
11
  from kitt.skills.common import config
12
 
@@ -94,31 +91,6 @@ def run_tts_replicate(text: str, voice_character: str):
94
  return output
95
 
96
 
97
- def get_fast_tts():
98
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
99
-
100
- model = ParlerTTSForConditionalGeneration.from_pretrained(
101
- "parler-tts/parler-tts-mini-expresso"
102
- ).to(device)
103
- tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-expresso")
104
- return model, tokenizer, device
105
-
106
-
107
- fast_tts = get_fast_tts()
108
-
109
-
110
- def run_tts_fast(text: str):
111
- model, tokenizer, device = fast_tts
112
- description = "Thomas speaks moderately slowly in a sad tone with emphasis and high quality audio."
113
-
114
- input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
115
- prompt_input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
116
-
117
- generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
118
- audio_arr = generation.cpu().numpy().squeeze()
119
- return (model.config.sampling_rate, audio_arr), dict(text=text, voice="Thomas")
120
-
121
-
122
  def load_melo_tts():
123
  from melo.api import TTS as MeloTTS
124
 
 
1
  import copy
2
  from collections import namedtuple
3
 
 
4
  import torch
5
  from loguru import logger
 
6
  from replicate import Client
 
7
 
8
  from kitt.skills.common import config
9
 
 
91
  return output
92
 
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  def load_melo_tts():
95
  from melo.api import TTS as MeloTTS
96
 
main.py CHANGED
@@ -9,7 +9,7 @@ from kitt.core import utils as kitt_utils
9
  from kitt.core import voice_options
10
  from kitt.core.model import generate_function_call as process_query
11
  from kitt.core.stt import save_and_transcribe_audio
12
- from kitt.core.tts import prep_for_tts, run_melo_tts, run_tts_fast, run_tts_replicate
13
  from kitt.skills import (
14
  code_interpreter,
15
  date_time_info,
@@ -118,9 +118,6 @@ def run_llama3_model(query, voice_character, state):
118
  voice_out = tts_gradio(
119
  output_text_tts, voice_character, speaker_embedding_cache
120
  )[0]
121
- #
122
- # voice_out = run_tts_fast(output_text)[0]
123
- #
124
  return (
125
  output_text,
126
  voice_out,
 
9
  from kitt.core import voice_options
10
  from kitt.core.model import generate_function_call as process_query
11
  from kitt.core.stt import save_and_transcribe_audio
12
+ from kitt.core.tts import prep_for_tts, run_melo_tts, run_tts_replicate
13
  from kitt.skills import (
14
  code_interpreter,
15
  date_time_info,
 
118
  voice_out = tts_gradio(
119
  output_text_tts, voice_character, speaker_embedding_cache
120
  )[0]
 
 
 
121
  return (
122
  output_text,
123
  voice_out,