gauravyad87 committed
Commit 4ce7dc8
1 Parent(s): f2016d3

Add updated Dockerfile and app.py

Files changed (6)
  1. Dockerfile +36 -0
  2. api.py +195 -0
  3. app.py +226 -0
  4. requirements.txt +25 -0
  5. resources/output.wav +0 -0
  6. se_extractor.py +139 -0
Dockerfile ADDED
@@ -0,0 +1,36 @@
+ FROM python:3.9-slim
+
+ # Set environment variables
+ ENV PYTHONDONTWRITEBYTECODE=1
+ ENV PYTHONUNBUFFERED=1
+ ENV PORT=8080
+
+ # Install system dependencies as root, before dropping privileges
+ RUN apt-get update && apt-get install -y \
+     build-essential \
+     libsndfile1 \
+     ffmpeg \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Create a non-root user and give it ownership of the app directory
+ RUN useradd -m -u 1000 user && mkdir -p /app && chown -R user:users /app
+ USER user
+ ENV PATH="/home/user/.local/bin:$PATH"
+
+ WORKDIR /app
+
+ # Copy requirements and install Python dependencies
+ COPY --chown=user:users requirements.txt /app/requirements.txt
+ RUN pip install --no-cache-dir --upgrade pip && pip install --no-cache-dir -r requirements.txt
+
+ # Copy the rest of the project files
+ COPY --chown=user:users . /app
+
+ # Ensure outputs and temp directories exist
+ RUN mkdir -p outputs temp
+
+ # Expose port 8080
+ EXPOSE 8080
+
+ # Run the application
+ CMD ["python", "app.py"]
api.py ADDED
@@ -0,0 +1,195 @@
+ import torch
+ import numpy as np
+ import re
+ import soundfile
+ import utils
+ import commons
+ import os
+ import librosa
+ from text import text_to_sequence
+ from mel_processing import spectrogram_torch
+ from models import SynthesizerTrn
+
+
+ class OpenVoiceBaseClass(object):
+     def __init__(self,
+                  config_path,
+                  device='cuda:0'):
+         if 'cuda' in device:
+             assert torch.cuda.is_available()
+
+         hps = utils.get_hparams_from_file(config_path)
+
+         model = SynthesizerTrn(
+             len(getattr(hps, 'symbols', [])),
+             hps.data.filter_length // 2 + 1,
+             n_speakers=hps.data.n_speakers,
+             **hps.model,
+         ).to(device)
+
+         model.eval()
+         self.model = model
+         self.hps = hps
+         self.device = device
+
+     def load_ckpt(self, ckpt_path):
+         checkpoint_dict = torch.load(ckpt_path, map_location=torch.device(self.device))
+         a, b = self.model.load_state_dict(checkpoint_dict['model'], strict=False)
+         print("Loaded checkpoint '{}'".format(ckpt_path))
+         print('missing/unexpected keys:', a, b)
+
+
+ class BaseSpeakerTTS(OpenVoiceBaseClass):
+     language_marks = {
+         "english": "EN",
+         "chinese": "ZH",
+     }
+
+     @staticmethod
+     def get_text(text, hps, is_symbol):
+         text_norm = text_to_sequence(text, hps.symbols, [] if is_symbol else hps.data.text_cleaners)
+         if hps.data.add_blank:
+             text_norm = commons.intersperse(text_norm, 0)
+         text_norm = torch.LongTensor(text_norm)
+         return text_norm
+
+     @staticmethod
+     def audio_numpy_concat(segment_data_list, sr, speed=1.):
+         audio_segments = []
+         for segment_data in segment_data_list:
+             audio_segments += segment_data.reshape(-1).tolist()
+             audio_segments += [0] * int((sr * 0.05)/speed)
+         audio_segments = np.array(audio_segments).astype(np.float32)
+         return audio_segments
+
+     @staticmethod
+     def split_sentences_into_pieces(text, language_str):
+         texts = utils.split_sentence(text, language_str=language_str)
+         print(" > Text splitted to sentences.")
+         print('\n'.join(texts))
+         print(" > ===========================")
+         return texts
+
+     def tts(self, text, output_path, speaker, language='English', speed=1.0):
+         mark = self.language_marks.get(language.lower(), None)
+         assert mark is not None, f"language {language} is not supported"
+
+         texts = self.split_sentences_into_pieces(text, mark)
+
+         audio_list = []
+         for t in texts:
+             t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t)
+             t = f'[{mark}]{t}[{mark}]'
+             stn_tst = self.get_text(t, self.hps, False)
+             device = self.device
+             speaker_id = self.hps.speakers[speaker]
+             with torch.no_grad():
+                 x_tst = stn_tst.unsqueeze(0).to(device)
+                 x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
+                 sid = torch.LongTensor([speaker_id]).to(device)
+                 audio = self.model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=0.667, noise_scale_w=0.6,
+                                          length_scale=1.0 / speed)[0][0, 0].data.cpu().float().numpy()
+             audio_list.append(audio)
+         audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)
+
+         if output_path is None:
+             return audio
+         else:
+             soundfile.write(output_path, audio, self.hps.data.sampling_rate)
+
+
+ class ToneColorConverter(OpenVoiceBaseClass):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+         self.watermark_model = None
+
+     def extract_se(self, ref_wav_list, se_save_path=None):
+         if isinstance(ref_wav_list, str):
+             ref_wav_list = [ref_wav_list]
+
+         device = self.device
+         hps = self.hps
+         gs = []
+
+         for fname in ref_wav_list:
+             audio_ref, sr = librosa.load(fname, sr=hps.data.sampling_rate)
+             y = torch.FloatTensor(audio_ref)
+             y = y.to(device)
+             y = y.unsqueeze(0)
+             y = spectrogram_torch(y, hps.data.filter_length,
+                                   hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,
+                                   center=False).to(device)
+             with torch.no_grad():
+                 g = self.model.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
+                 gs.append(g.detach())
+         gs = torch.stack(gs).mean(0)
+
+         if se_save_path is not None:
+             os.makedirs(os.path.dirname(se_save_path), exist_ok=True)
+             torch.save(gs.cpu(), se_save_path)
+
+         return gs
+
+     def convert(self, audio_src_path, src_se, tgt_se, output_path=None, tau=0.3, message="default"):
+         hps = self.hps
+         # load audio
+         audio, sample_rate = librosa.load(audio_src_path, sr=hps.data.sampling_rate)
+         audio = torch.tensor(audio).float()
+
+         with torch.no_grad():
+             y = torch.FloatTensor(audio).to(self.device)
+             y = y.unsqueeze(0)
+             spec = spectrogram_torch(y, hps.data.filter_length,
+                                      hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,
+                                      center=False).to(self.device)
+             spec_lengths = torch.LongTensor([spec.size(-1)]).to(self.device)
+             audio = self.model.voice_conversion(spec, spec_lengths, sid_src=src_se, sid_tgt=tgt_se, tau=tau)[0][
+                 0, 0].data.cpu().float().numpy()
+             audio = self.add_watermark(audio, message)
+             if output_path is None:
+                 return audio
+             else:
+                 soundfile.write(output_path, audio, hps.data.sampling_rate)
+
+     def add_watermark(self, audio, message):
+         if self.watermark_model is None:
+             return audio
+         device = self.device
+         bits = utils.string_to_bits(message).reshape(-1)
+         n_repeat = len(bits) // 32
+
+         K = 16000
+         coeff = 2
+         for n in range(n_repeat):
+             trunck = audio[(coeff * n) * K: (coeff * n + 1) * K]
+             if len(trunck) != K:
+                 print('Audio too short, fail to add watermark')
+                 break
+             message_npy = bits[n * 32: (n + 1) * 32]
+
+             with torch.no_grad():
+                 signal = torch.FloatTensor(trunck).to(device)[None]
+                 message_tensor = torch.FloatTensor(message_npy).to(device)[None]
+                 signal_wmd_tensor = self.watermark_model.encode(signal, message_tensor)
+                 signal_wmd_npy = signal_wmd_tensor.detach().cpu().squeeze()
+             audio[(coeff * n) * K: (coeff * n + 1) * K] = signal_wmd_npy
+         return audio
+
+     def detect_watermark(self, audio, n_repeat):
+         bits = []
+         K = 16000
+         coeff = 2
+         for n in range(n_repeat):
+             trunck = audio[(coeff * n) * K: (coeff * n + 1) * K]
+             if len(trunck) != K:
+                 print('Audio too short, fail to detect watermark')
+                 return 'Fail'
+             with torch.no_grad():
+                 signal = torch.FloatTensor(trunck).to(self.device).unsqueeze(0)
+                 message_decoded_npy = (self.watermark_model.decode(signal) >= 0.5).int().detach().cpu().numpy().squeeze()
+             bits.append(message_decoded_npy)
+         bits = np.stack(bits).reshape(-1, 8)
+         message = utils.bits_to_string(bits)
+         return message
+
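A minimal usage sketch of the two classes above (not part of this commit). It mirrors how app.py wires them together, assumes the checkpoints downloaded by app.py are present under checkpoints/, and uses illustrative output paths:

import os
import torch
import se_extractor
from api import BaseSpeakerTTS, ToneColorConverter

device = 'cuda' if torch.cuda.is_available() else 'cpu'
os.makedirs('outputs', exist_ok=True)

# Base speaker TTS: synthesize text in the base voice
tts = BaseSpeakerTTS('checkpoints/base_speakers/EN/config.json', device=device)
tts.load_ckpt('checkpoints/base_speakers/EN/checkpoint.pth')
tts.tts("Hello there, how are you today?", 'outputs/tmp.wav', speaker='default', language='English')

# Tone color converter: re-voice the base output toward a reference speaker embedding
converter = ToneColorConverter('checkpoints/converter/config.json', device=device)
converter.load_ckpt('checkpoints/converter/checkpoint.pth')

source_se = torch.load('checkpoints/base_speakers/EN/en_default_se.pth').to(device)
target_se, _ = se_extractor.get_se('resources/output.wav', converter, target_dir='processed', vad=True)

converter.convert(
    audio_src_path='outputs/tmp.wav',
    src_se=source_se,
    tgt_se=target_se,
    output_path='outputs/converted.wav',
)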
app.py ADDED
@@ -0,0 +1,226 @@
+ import os
+ import torch
+ import argparse
+ import gradio as gr
+ import openai
+ from zipfile import ZipFile
+ import requests
+ import se_extractor
+ from api import BaseSpeakerTTS, ToneColorConverter
+ import langid
+ import traceback
+ from dotenv import load_dotenv
+
+ # Load environment variables
+ load_dotenv()
+
+ # Function to download and extract checkpoints
+ def download_and_extract_checkpoints():
+     zip_url = "https://huggingface.co/camenduru/OpenVoice/resolve/main/checkpoints_1226.zip"
+     zip_path = "checkpoints.zip"
+
+     if not os.path.exists("checkpoints"):
+         print("Downloading checkpoints...")
+         response = requests.get(zip_url, stream=True)
+         with open(zip_path, "wb") as zip_file:
+             for chunk in response.iter_content(chunk_size=8192):
+                 if chunk:
+                     zip_file.write(chunk)
+         print("Extracting checkpoints...")
+         with ZipFile(zip_path, "r") as zip_ref:
+             zip_ref.extractall(".")
+         os.remove(zip_path)
+         print("Checkpoints are ready.")
+
+ # Call the function to ensure checkpoints are available
+ download_and_extract_checkpoints()
+
+ # Initialize OpenAI API key
+ openai.api_key = os.getenv("OPENAI_API_KEY")
+ if not openai.api_key:
+     raise ValueError("Please set the OPENAI_API_KEY environment variable.")
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--share", action='store_true', default=False, help="make link public")
+ args = parser.parse_args()
+
+ # Define paths to checkpoints
+ en_ckpt_base = 'checkpoints/base_speakers/EN'
+ zh_ckpt_base = 'checkpoints/base_speakers/ZH'
+ ckpt_converter = 'checkpoints/converter'
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ output_dir = 'outputs'
+ os.makedirs(output_dir, exist_ok=True)
+
+ # Load TTS models
+ en_base_speaker_tts = BaseSpeakerTTS(f'{en_ckpt_base}/config.json', device=device)
+ en_base_speaker_tts.load_ckpt(f'{en_ckpt_base}/checkpoint.pth')
+ zh_base_speaker_tts = BaseSpeakerTTS(f'{zh_ckpt_base}/config.json', device=device)
+ zh_base_speaker_tts.load_ckpt(f'{zh_ckpt_base}/checkpoint.pth')
+
+ tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device)
+ tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')
+
+ # Load speaker embeddings
+ en_source_default_se = torch.load(f'{en_ckpt_base}/en_default_se.pth').to(device)
+ en_source_style_se = torch.load(f'{en_ckpt_base}/en_style_se.pth').to(device)
+ zh_source_se = torch.load(f'{zh_ckpt_base}/zh_default_se.pth').to(device)
+
+ # Extract speaker embedding from the default Mickey Mouse audio
+ default_speaker_audio = "resources/output.wav"
+ try:
+     target_se, _ = se_extractor.get_se(
+         default_speaker_audio,
+         tone_color_converter,
+         target_dir='processed',
+         vad=True
+     )
+     print("Speaker embedding extracted successfully.")
+ except Exception as e:
+     raise RuntimeError(f"Failed to extract speaker embedding from {default_speaker_audio}: {str(e)}")
+
+ # Supported languages
+ supported_languages = ['zh', 'en']
+
+ def predict(audio_file_pth, agree):
+     text_hint = ''
+     synthesized_audio_path = None
+
+     # Agree with the terms
+     if not agree:
+         text_hint += '[ERROR] Please accept the Terms & Conditions!\n'
+         return (text_hint, None)
+
+     # Check if audio file is provided
+     if audio_file_pth is not None:
+         speaker_wav = audio_file_pth
+     else:
+         text_hint += "[ERROR] Please record your voice using the Microphone.\n"
+         return (text_hint, None)
+
+     # Transcribe audio to text using OpenAI Whisper
+     try:
+         with open(speaker_wav, 'rb') as audio_file:
+             transcription_response = openai.Audio.transcribe(
+                 model="whisper-1",
+                 file=audio_file,
+                 response_format='text'
+             )
+         input_text = transcription_response.strip()
+         print(f"Transcribed Text: {input_text}")
+     except Exception as e:
+         text_hint += f"[ERROR] Transcription failed: {str(e)}\n"
+         return (text_hint, None)
+
+     if len(input_text) == 0:
+         text_hint += "[ERROR] No speech detected in the audio.\n"
+         return (text_hint, None)
+
+     # Detect language
+     language_predicted = langid.classify(input_text)[0].strip()
+     print(f"Detected language: {language_predicted}")
+
+     if language_predicted not in supported_languages:
+         text_hint += f"[ERROR] The detected language '{language_predicted}' is not supported. Supported languages are: {supported_languages}\n"
+         return (text_hint, None)
+
+     # Select TTS model based on language
+     if language_predicted == "zh":
+         tts_model = zh_base_speaker_tts
+         language = 'Chinese'
+         speaker_style = 'default'
+     else:
+         tts_model = en_base_speaker_tts
+         language = 'English'
+         speaker_style = 'default'
+
+     # Generate response using OpenAI GPT-4
+     try:
+         response = openai.ChatCompletion.create(
+             model="gpt-4o-mini",
+             messages=[
+                 {"role": "system", "content": "You are Mickey Mouse, a friendly and cheerful character who responds to children's queries in a simple and engaging manner. Please keep your response up to 200 characters."},
+                 {"role": "user", "content": input_text}
+             ],
+             max_tokens=200,
+             temperature=0.7,
+         )
+         reply_text = response['choices'][0]['message']['content'].strip()
+         print(f"GPT-4 Reply: {reply_text}")
+     except Exception as e:
+         text_hint += f"[ERROR] Failed to get response from OpenAI GPT-4: {str(e)}\n"
+         return (text_hint, None)
+
+     # Synthesize reply text to audio
+     try:
+         src_path = os.path.join(output_dir, 'tmp_reply.wav')
+
+         tts_model.tts(reply_text, src_path, speaker=speaker_style, language=language)
+         print(f"Audio synthesized and saved to {src_path}")
+
+         save_path = os.path.join(output_dir, 'output_reply.wav')
+
+         tone_color_converter.convert(
+             audio_src_path=src_path,
+             src_se=en_source_default_se if language == 'English' else zh_source_se,
+             tgt_se=target_se,
+             output_path=save_path,
+             message="@MickeyMouse"
+         )
+         print(f"Tone color conversion completed and saved to {save_path}")
+
+         text_hint += "Response generated successfully.\n"
+         synthesized_audio_path = save_path
+
+     except Exception as e:
+         text_hint += f"[ERROR] Failed to synthesize audio: {str(e)}\n"
+         traceback.print_exc()
+         return (text_hint, None)
+
+     return (text_hint, synthesized_audio_path)
+
+ with gr.Blocks(analytics_enabled=False) as demo:
+     gr.Markdown("# Mickey Mouse Voice Assistant")
+
+     with gr.Row():
+         with gr.Column():
+             audio_input = gr.Audio(
+                 source="microphone",
+                 type="filepath",
+                 label="Record Your Voice",
+                 info="Click the microphone button to record your voice."
+             )
+             tos_checkbox = gr.Checkbox(
+                 label="Agree to Terms & Conditions",
+                 value=False,
+                 info="I agree to the terms of service."
+             )
+             submit_button = gr.Button("Send")
+
+         with gr.Column():
+             info_output = gr.Textbox(
+                 label="Info",
+                 interactive=False,
+                 lines=4,
+             )
+             audio_output = gr.Audio(
+                 label="Mickey's Response",
+                 interactive=False,
+                 autoplay=True,
+             )
+
+     submit_button.click(
+         predict,
+         inputs=[audio_input, tos_checkbox],
+         outputs=[info_output, audio_output]
+     )
+
+ # Launch the Gradio app
+ demo.queue()
+ demo.launch(
+     server_name="0.0.0.0",
+     server_port=int(os.environ.get("PORT", 8080)),
+     debug=True,
+     show_api=True,
+     share=False
+ )
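For a quick check without the browser UI, the predict function can be exercised directly. A sketch, assuming the module-level setup above (checkpoint download, model loading, speaker-embedding extraction) has already run in the same process and demo.launch() has not yet been called (it blocks); the input path is illustrative:

# Smoke test for predict(); returns (status text, path to the synthesized reply or None)
status_text, reply_wav = predict("resources/output.wav", agree=True)
print(status_text)                     # "[ERROR] ..." hints or "Response generated successfully."
print("Reply audio path:", reply_wav)  # outputs/output_reply.wav on success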
requirements.txt ADDED
@@ -0,0 +1,25 @@
+ librosa==0.9.1
+ faster-whisper==0.9.0
+ pydub==0.25.1
+ wavmark==0.0.2
+ numpy==1.22.0
+ eng_to_ipa==0.0.2
+ inflect==7.0.0
+ unidecode==1.3.7
+ whisper-timestamped==1.14.2
+ openai
+ python-dotenv
+ pypinyin==0.50.0
+ cn2an==0.5.22
+ jieba==0.42.1
+ gradio==3.50.2
+ ffmpeg-python
+ fastapi
+ uvicorn
+ torch
+ langid
+ requests
+ fastapi
+ uvicorn[standard]
+ webrtcvad
+
resources/output.wav ADDED
Binary file (508 kB).
 
se_extractor.py ADDED
@@ -0,0 +1,139 @@
+ import os
+ import glob
+ import torch
+ from glob import glob
+ import numpy as np
+ from pydub import AudioSegment
+ from faster_whisper import WhisperModel
+ from whisper_timestamped.transcribe import get_audio_tensor, get_vad_segments
+
+ model_size = "medium"
+ # Run on GPU with FP16
+ model = None
+ def split_audio_whisper(audio_path, target_dir='processed'):
+     global model
+     if model is None:
+         model = WhisperModel(model_size, device="cuda", compute_type="float16")
+     audio = AudioSegment.from_file(audio_path)
+     max_len = len(audio)
+
+     audio_name = os.path.basename(audio_path).rsplit('.', 1)[0]
+     target_folder = os.path.join(target_dir, audio_name)
+
+     segments, info = model.transcribe(audio_path, beam_size=5, word_timestamps=True)
+     segments = list(segments)
+
+     # create directory
+     os.makedirs(target_folder, exist_ok=True)
+     wavs_folder = os.path.join(target_folder, 'wavs')
+     os.makedirs(wavs_folder, exist_ok=True)
+
+     # segments
+     s_ind = 0
+     start_time = None
+
+     for k, w in enumerate(segments):
+         # process with the time
+         if k == 0:
+             start_time = max(0, w.start)
+
+         end_time = w.end
+
+         # calculate confidence
+         if len(w.words) > 0:
+             confidence = sum([s.probability for s in w.words]) / len(w.words)
+         else:
+             confidence = 0.
+         # clean text
+         text = w.text.replace('...', '')
+
+         # left 0.08s for each audios
+         audio_seg = audio[int( start_time * 1000) : min(max_len, int(end_time * 1000) + 80)]
+
+         # segment file name
+         fname = f"{audio_name}_seg{s_ind}.wav"
+
+         # filter out the segment shorter than 1.5s and longer than 20s
+         save = audio_seg.duration_seconds > 1.5 and \
+                audio_seg.duration_seconds < 20. and \
+                len(text) >= 2 and len(text) < 200
+
+         if save:
+             output_file = os.path.join(wavs_folder, fname)
+             audio_seg.export(output_file, format='wav')
+
+         if k < len(segments) - 1:
+             start_time = max(0, segments[k+1].start - 0.08)
+
+         s_ind = s_ind + 1
+     return wavs_folder
+
+
+ def split_audio_vad(audio_path, target_dir, split_seconds=10.0):
+     SAMPLE_RATE = 16000
+     audio_vad = get_audio_tensor(audio_path)
+     segments = get_vad_segments(
+         audio_vad,
+         output_sample=True,
+         min_speech_duration=0.1,
+         min_silence_duration=1,
+         method="silero",
+     )
+     segments = [(seg["start"], seg["end"]) for seg in segments]
+     segments = [(float(s) / SAMPLE_RATE, float(e) / SAMPLE_RATE) for s,e in segments]
+     print(segments)
+     audio_active = AudioSegment.silent(duration=0)
+     audio = AudioSegment.from_file(audio_path)
+
+     for start_time, end_time in segments:
+         audio_active += audio[int( start_time * 1000) : int(end_time * 1000)]
+
+     audio_dur = audio_active.duration_seconds
+     print(f'after vad: dur = {audio_dur}')
+     audio_name = os.path.basename(audio_path).rsplit('.', 1)[0]
+     target_folder = os.path.join(target_dir, audio_name)
+     wavs_folder = os.path.join(target_folder, 'wavs')
+     os.makedirs(wavs_folder, exist_ok=True)
+     start_time = 0.
+     count = 0
+     num_splits = int(np.round(audio_dur / split_seconds))
+     assert num_splits > 0, 'input audio is too short'
+     interval = audio_dur / num_splits
+
+     for i in range(num_splits):
+         end_time = min(start_time + interval, audio_dur)
+         if i == num_splits - 1:
+             end_time = audio_dur
+         output_file = f"{wavs_folder}/{audio_name}_seg{count}.wav"
+         audio_seg = audio_active[int(start_time * 1000): int(end_time * 1000)]
+         audio_seg.export(output_file, format='wav')
+         start_time = end_time
+         count += 1
+     return wavs_folder
+
+
+
+
+
+ def get_se(audio_path, vc_model, target_dir='processed', vad=True):
+     device = vc_model.device
+
+     audio_name = os.path.basename(audio_path).rsplit('.', 1)[0]
+     se_path = os.path.join(target_dir, audio_name, 'se.pth')
+
+     if os.path.isfile(se_path):
+         se = torch.load(se_path).to(device)
+         return se, audio_name
+     if os.path.isdir(audio_path):
+         wavs_folder = audio_path
+     elif vad:
+         wavs_folder = split_audio_vad(audio_path, target_dir)
+     else:
+         wavs_folder = split_audio_whisper(audio_path, target_dir)
+
+     audio_segs = glob(f'{wavs_folder}/*.wav')
+     if len(audio_segs) == 0:
+         raise NotImplementedError('No audio segments found!')
+
+     return vc_model.extract_se(audio_segs, se_save_path=se_path), audio_name
+
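A rough illustration of how get_se is used and cached (not part of this commit): the extracted embedding is saved to <target_dir>/<audio_name>/se.pth, so a second call for the same reference audio loads the cached tensor instead of re-splitting the file. Paths and the CPU device are illustrative and assume the converter checkpoint is available:

from api import ToneColorConverter
import se_extractor

converter = ToneColorConverter('checkpoints/converter/config.json', device='cpu')
converter.load_ckpt('checkpoints/converter/checkpoint.pth')

# First call: VAD-splits resources/output.wav into processed/output/wavs/*.wav,
# averages the reference-encoder output over the segments, and saves processed/output/se.pth.
se, name = se_extractor.get_se('resources/output.wav', converter, target_dir='processed', vad=True)
print(name, se.shape)

# Second call: the cached processed/output/se.pth is loaded instead of re-splitting the audio.
se_cached, _ = se_extractor.get_se('resources/output.wav', converter, target_dir='processed', vad=True)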