# import base64
# import pathlib
# import tempfile
import os
os.system("python -m unidic download")
import nltk
nltk.download('averaged_perceptron_tagger_eng')
nltk.download('punkt_tab')
from nltk import sent_tokenize
import gradio as gr
from espnet2.bin.tts_inference import Text2Speech
from espnet2.utils.types import str_or_none
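# Gradio voice-assistant demo: microphone audio is screened with webrtcvad,
# transcribed with an OWSM-CTC model, answered by Llama-3.2-1B-Instruct, and
# spoken back with an ESPnet VITS TTS model, streamed turn by turn.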
lang = 'English'
tag = 'kan-bayashi/ljspeech_vits'  # ESPnet model tag (kept for reference; the checkpoint below is loaded from local files)
vocoder_tag = "none"
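# Load the ESPnet VITS TTS model from the local checkpoint on the GPU.
# (The Tacotron/FastSpeech-specific options below are ignored by VITS.)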
text2speech = Text2Speech.from_pretrained(
    train_config="tts_model/exp/tts_train_vits_raw_phn_tacotron_g2p_en_no_space/config.yaml",
    model_file="tts_model/exp/tts_train_vits_raw_phn_tacotron_g2p_en_no_space/train.total_count.ave_10best.pth",
    vocoder_tag=str_or_none(vocoder_tag),
    device="cuda",
    # Only for Tacotron 2 & Transformer
    threshold=0.5,
    # Only for Tacotron 2
    minlenratio=0.0,
    maxlenratio=10.0,
    use_att_constraint=False,
    backward_window=1,
    forward_window=3,
    # Only for FastSpeech & FastSpeech2 & VITS
    speed_control_alpha=1.0,
    # Only for VITS
    noise_scale=0.333,
    noise_scale_dur=0.333,
)

# recorder_js = pathlib.Path('recorder.js').read_text()
# main_js = pathlib.Path('main.js').read_text()
# record_button_js = pathlib.Path('record_button.js').read_text().replace('let recorder_js = null;', recorder_js).replace(
#     'let main_js = null;', main_js)


# def save_base64_video(base64_string):
#     base64_video = base64_string
#     video_data = base64.b64decode(base64_video)
#     with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
#         temp_filename = temp_file.name
#         temp_file.write(video_data)
#     print(f"Temporary MP4 file saved as: {temp_filename}")
#     return temp_filename
# import os

# os.system('python -m unidic download')
import numpy as np
from VAD.vad_iterator import VADIterator
import torch
import librosa
# from mlx_lm import load, stream_generate, generate
from LLM.chat import Chat
# from lightning_whisper_mlx import LightningWhisperMLX
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
    TextIteratorStreamer,
)
# from melo.api import TTS

# LM_model, LM_tokenizer = load("mlx-community/SmolLM-360M-Instruct")
chat = Chat(2)
chat.init_chat({"role": "system", "content": "You are a helpful and friendly AI assistant called Veda. You are polite, respectful, and aim to provide concise and complete responses of less than 15 words."})
user_role = "user"

# tts_model = TTS(language="EN_NEWEST", device="auto")
# speaker_id = tts_model.hps.data.spk2id["EN-Newest"]
blocksize = 512
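# Warm up the TTS model once so the first real synthesis call is not slowed by lazy initialization.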
with torch.no_grad():
    wav = text2speech("Sid")["wav"]
# tts_model.tts_to_file("text", speaker_id, quiet=True)
dummy_input = torch.randn(3000, dtype=torch.float16, device="cpu").numpy()
import soundfile as sf
import kaldiio
from espnet2.bin.s2t_inference_ctc import Speech2TextGreedySearch


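# Load the OWSM-CTC ASR model (greedy search) on the GPU.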
s2t = Speech2TextGreedySearch.from_pretrained(
    "pyf98/owsm_ctc_v3.1_1B",
    device="cuda",
    generate_interctc_outputs=False,
    lang_sym='<eng>',
    task_sym='<asr>',
)

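# Warm up the ASR model on 30 s of padded dummy audio and time it with CUDA events.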
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start_event.record()
speech = librosa.util.fix_length(dummy_input, size=(16000 * 30))
res = s2t(speech)
end_event.record()
torch.cuda.synchronize()

def int2float(sound):
    """
    Taken from https://github.com/snakers4/silero-vad
    """

    abs_max = np.abs(sound).max()
    sound = sound.astype("float32")
    if abs_max > 0:
        sound *= 1 / 32768
    sound = sound.squeeze()  # depends on the use case
    return sound

text_str=""
vad_output=None
audio_output = None
min_speech_ms=500
max_speech_ms=float("inf")
# ASR_model = LightningWhisperMLX(model="distil-large-v3", batch_size=6, quant=None)
# ASR_processor = AutoProcessor.from_pretrained("distil-whisper/distil-large-v3")
# ASR_model = AutoModelForSpeechSeq2Seq.from_pretrained(
#     "distil-whisper/distil-large-v3",
#     torch_dtype="float16",
# ).to("cpu")
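# Load Llama-3.2-1B-Instruct (a gated model, hence HF_TOKEN) and build a
# text-generation pipeline with a token streamer for incremental output.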
access_token = os.environ.get("HF_TOKEN")
LM_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct", token=access_token)
LM_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct", torch_dtype="float16", trust_remote_code=True, token=access_token
).to("cuda")
LM_pipe = pipeline(
    "text-generation", model=LM_model, tokenizer=LM_tokenizer, device="cuda"
)
streamer = TextIteratorStreamer(
            LM_tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )
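# Warm up the LLM with a short dummy generation and time it with CUDA events.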
dummy_input_text = "Write me a poem about Machine Learning."
dummy_chat = [{"role": "user", "content": dummy_input_text}]
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start_event.record()
LM_pipe(
    dummy_chat,
    max_new_tokens=32,
    min_new_tokens=0,
    temperature=0.0,
    do_sample=False,
    streamer=streamer,
    return_full_text=False,
)
for a in streamer:
    print(a)
end_event.record()
torch.cuda.synchronize()
# vad_model, _ = torch.hub.load("snakers4/silero-vad:v4.0", "silero_vad")
# vad_iterator = VADIterator(
#     vad_model,
#     threshold=0.3,
#     sampling_rate=16000,
#     min_silence_duration_ms=250,
#     speech_pad_ms=500,
# )
import webrtcvad

import time
def transcribe(stream, new_chunk):
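    """Streaming Gradio callback.

    Each incoming microphone chunk is screened with webrtcvad; once speech ends,
    the buffered utterance is transcribed with OWSM-CTC, answered by
    Llama-3.2-1B-Instruct, and spoken back sentence by sentence with the ESPnet
    VITS model. Partial results are yielded as (state, text, audio) tuples.
    """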
    sr, y = new_chunk
    global text_str
    global chat
    global user_role
    global audio_output
    global vad_output
    if stream is None:
        stream=True
        chat.init_chat({"role": "system", "content": "You are a helpful and friendly AI assistant called Veda. You should respond in a conversational style. You are polite, respectful, and aim to provide concise and complete responses of less than 20 words."})
        text_str=""
        audio_output = None
    orig_sr=sr
    audio_int16 = np.frombuffer(y, dtype=np.int16)
    audio_float32 = int2float(audio_int16)
    audio_float32=librosa.resample(audio_float32, orig_sr=sr, target_sr=16000)
    sr=16000
    print(sr)
    print(audio_float32.shape)
    # vad_output = vad_iterator(torch.from_numpy(audio_float32))
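    # Frame-level VAD with webrtcvad (mode 3, most aggressive): count voiced
    # 960-sample frames in this chunk (20 ms per frame if the stream is 48 kHz).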
    vad_count=0
    for i in range(int(len(y)/960)):
        vad = webrtcvad.Vad()
        vad.set_mode(3)
        if (vad.is_speech(y[i*960:(i+1)*960].tobytes(), orig_sr)):
            vad_count+=1
    print(vad_count)
    if (vad_output is None and vad_count > 12) or (vad_output is not None and vad_count > 10):
        # Start a new utterance buffer on a strong onset (>12 voiced frames),
        # or keep appending while speech continues (>10 voiced frames).
        vad_curr = True
        if vad_output is None:
            vad_output = [torch.from_numpy(audio_float32)]
        else:
            vad_output.append(torch.from_numpy(audio_float32))
    else:
        vad_curr = False
    
    if vad_output is not None and vad_curr==False:
        print("VAD: end of speech detected")
        array = torch.cat(vad_output).cpu().numpy()
        duration_ms = len(array) / sr * 1000
        if (not(duration_ms < min_speech_ms or duration_ms > max_speech_ms)):
            # input_features = ASR_processor(
            #     array, sampling_rate=16000, return_tensors="pt"
            # ).input_features
            # print(input_features)
            # input_features = input_features.to("cpu", dtype=getattr(torch, "float16"))
            # pred_ids = ASR_model.generate(input_features, max_new_tokens=128, min_new_tokens=0, num_beams=1, return_timestamps=False,task="transcribe",language="en")
            # print(pred_ids)
            # prompt = ASR_processor.batch_decode(
            #     pred_ids, skip_special_tokens=True, decode_with_timestamps=False
            # )[0]
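            # Pad/trim the utterance to the 30 s window expected by OWSM-CTC, run
            # greedy decoding, and drop the first whitespace-separated token of the
            # hypothesis before using it as the prompt.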
            print(len(array))
            array = librosa.util.fix_length(array, size=(16000 * 30))
            print(len(array))
            start_time = time.time()
            prompt=" ".join(s2t(array)[0][0].split()[1:])
            vad_output = None
            if len(prompt.strip().split()) < 2:
                # Ignore near-empty transcripts: emit the current state and end this
                # turn. (A bare return is used because this function is a generator;
                # a return value would be dropped by Gradio.)
                text_str1 = text_str
                yield stream, text_str1, audio_output
                return
            
            # prompt=transcriber({"sampling_rate": sr, "raw": array})["text"]
            print(len(prompt.strip().split()))
            print(prompt)
            print("--- %s seconds ---" % (time.time() - start_time))
            # prompt=ASR_model.transcribe(array)["text"].strip()
            chat.append({"role": user_role, "content": prompt})
            chat_messages = chat.to_list()
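            # Generate the assistant reply; tokens arrive incrementally via `streamer`.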
            
            LM_pipe(
                chat_messages,
                max_new_tokens=256,
                min_new_tokens=0,
                temperature=0.0,
                do_sample=False,
                streamer=streamer,
                return_full_text=False,
            )
            output=""
            curr_output = ""
            text_str = ""
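            # Consume the token stream: each time a full sentence is available,
            # synthesize it, yield the partial text and audio, and sleep roughly
            # for the audio duration so playback keeps up with generation.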
            for t in streamer:
                output += t
                curr_output += t
                sentences=sent_tokenize(curr_output)
                if len(sentences)>1:
                    print("--- %s seconds ---" % (time.time() - start_time))
                    print(sentences[0])
                    with torch.no_grad():
                        audio_chunk = text2speech(sentences[0])["wav"].view(-1).cpu().numpy()
                    # audio_chunk = tts_model.tts_to_file(text_str, speaker_id, quiet=True)
                    audio_chunk = (audio_chunk * 32768).astype(np.int16)
                    print(text2speech.fs)
                    audio_output=(text2speech.fs, audio_chunk)
                    print("okk")
                    # print(audio_chunk)
                    # print(audio_chunk.shape)
                    text_str=text_str+sentences[0]
                    print("--- %s seconds ---" % (time.time() - start_time))
                    yield (stream,text_str, audio_output)
                    time.sleep(max(0.0, len(audio_chunk) / text2speech.fs - 0.2))  # never pass a negative duration
                    curr_output = t             
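            # Flush whatever text remains after the last complete sentence and synthesize it.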
            print("--- %s seconds ---" % (time.time() - start_time))
            print(curr_output)
            with torch.no_grad():
                audio_chunk = text2speech(curr_output)["wav"].view(-1).cpu().numpy()
                # audio_chunk = tts_model.tts_to_file(text_str, speaker_id, quiet=True)
                audio_chunk = (audio_chunk * 32768).astype(np.int16)
                print(text2speech.fs)
                audio_output=(text2speech.fs, audio_chunk)
                print("okk")
                # print(audio_chunk)
                print(audio_chunk.shape)
                print("--- %s seconds ---" % (time.time() - start_time))
                yield (stream,output, audio_output)
                time.sleep(max(0.0, len(audio_chunk) / text2speech.fs - 0.2))  # never pass a negative duration
                curr_output = ""
            generated_text = output
            
            # torch.mps.empty_cache()
    
            chat.append({"role": "assistant", "content": generated_text})
            text_str=generated_text
            # import pdb;pdb.set_trace()
            # with torch.no_grad():
            #     audio_chunk = text2speech(text_str)["wav"].view(-1).cpu().numpy()
            # # audio_chunk = tts_model.tts_to_file(text_str, speaker_id, quiet=True)
            # audio_chunk = (audio_chunk * 32768).astype(np.int16)
            # print(text2speech.fs)
            # audio_output=(text2speech.fs, audio_chunk)
            
    # else:
    #     audio_output=None
    text_str1=text_str
    
    yield (stream,text_str1, audio_output)

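# Gradio wiring: session state plus a streaming microphone input; outputs are the
# updated state, the running response text, and autoplaying TTS audio.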
demo = gr.Interface(
    transcribe,
    ["state", gr.Audio(sources=["microphone"], streaming=True, waveform_options=gr.WaveformOptions(sample_rate=16000))],
    ["state", "text", gr.Audio(label="Output", autoplay=True,visible=True,)],
    live=True,
)
# with demo:
#     start_button = gr.Button("Record Screen 🔴")
#     video_component = gr.Video(interactive=True, show_share_button=True, include_audio=True)


#     def toggle_button_label(returned_string):
#         if returned_string.startswith("Record"):
#             return gr.Button(value="Stop Recording ⚪"), None
#         else:
#             try:
#                 temp_filename = save_base64_video(returned_string)
#             except Exception as e:
#                 return gr.Button(value="Record Screen 🔴"), gr.Warning(f'Failed to convert video to mp4:\n{e}')
#             return gr.Button(value="Record Screen 🔴"), gr.Video(value=temp_filename, interactive=True,
#                                                                 show_share_button=True)
#     start_button.click(toggle_button_label, start_button, [start_button, video_component], js=record_button_js)
demo.launch(share=True)