# import base64
# import pathlib
# import tempfile
import gradio as gr

# recorder_js = pathlib.Path('recorder.js').read_text()
# main_js = pathlib.Path('main.js').read_text()
# record_button_js = pathlib.Path('record_button.js').read_text().replace('let recorder_js = null;', recorder_js).replace(
#     'let main_js = null;', main_js)


# def save_base64_video(base64_string):
#     base64_video = base64_string
#     video_data = base64.b64decode(base64_video)
#     with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
#         temp_filename = temp_file.name
#         temp_file.write(video_data)
#     print(f"Temporary MP4 file saved as: {temp_filename}")
#     return temp_filename
# import os

# os.system('python -m unidic download')
from transformers import pipeline
import numpy as np
from VAD.vad_iterator import VADIterator
import torch
import librosa
from mlx_lm import load, stream_generate, generate
from LLM.chat import Chat
from lightning_whisper_mlx import LightningWhisperMLX
from melo.api import TTS
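
# Pipeline overview (summary of the code below): Gradio streams raw microphone
# chunks into `transcribe`, Silero VAD buffers them until it detects the end of
# an utterance, LightningWhisperMLX transcribes the utterance, SmolLM-360M-Instruct
# (via mlx_lm) generates a short reply, and MeloTTS renders that reply to audio
# that Gradio plays back automatically.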

LM_model, LM_tokenizer = load("mlx-community/SmolLM-360M-Instruct")
chat = Chat(2)
chat.init_chat({"role": "system", "content": "You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words."})
user_role = "user"
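
# NOTE (assumption based on how `chat` is used below): Chat(2) appears to keep
# only a short rolling window of recent turns on top of the system prompt,
# which keeps the context small for the 360M-parameter model.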

tts_model = TTS(language="EN_NEWEST", device="auto")
speaker_id = tts_model.hps.data.spk2id["EN-Newest"]
blocksize = 512
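# NOTE: `blocksize` is defined but never used below; with streaming=True,
# Gradio controls the size of each microphone chunk it delivers.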

def int2float(sound):
    """
    Taken from https://github.com/snakers4/silero-vad
    """

    abs_max = np.abs(sound).max()
    sound = sound.astype("float32")
    if abs_max > 0:
        sound *= 1 / 32768
    sound = sound.squeeze()  # depends on the use case
    return sound
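
# The fixed 1/32768 scale above assumes 16-bit PCM input, which matches the
# int16 microphone samples handled in `transcribe` below.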

# Shared state for the streaming callback.
text_str = ""
audio_output = None
min_speech_ms = 500
max_speech_ms = float("inf")
ASR_model = LightningWhisperMLX(model="distil-large-v3", batch_size=6, quant=None)
# NOTE: this transformers pipeline is loaded but never used; ASR below goes
# through LightningWhisperMLX instead.
transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-base.en")
vad_model, _ = torch.hub.load("snakers4/silero-vad:v4.0", "silero_vad")
vad_iterator = VADIterator(
    vad_model,
    threshold=0.3,
    sampling_rate=16000,
    min_silence_duration_ms=250,
    speech_pad_ms=500,
)
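# Judging from its use in `transcribe`, the custom VADIterator returns None while
# speech is still in progress and a list of audio tensors (the buffered utterance)
# once the end of speech is detected.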


def transcribe(stream, new_chunk):
    """Handle one streamed microphone chunk: VAD -> ASR -> LLM -> TTS."""
    global text_str
    global chat
    global user_role
    global audio_output

    sr, y = new_chunk
    # Convert the raw int16 samples to float32 and resample to the 16 kHz rate
    # expected by Silero VAD and Whisper.
    audio_int16 = np.frombuffer(y, dtype=np.int16)
    audio_float32 = int2float(audio_int16)
    audio_float32 = librosa.resample(audio_float32, orig_sr=sr, target_sr=16000)
    sr = 16000
    vad_output = vad_iterator(torch.from_numpy(audio_float32))

    if vad_output is not None and len(vad_output) != 0:
        print("VAD: end of speech detected")
        array = torch.cat(vad_output).cpu().numpy()
        duration_ms = len(array) / sr * 1000
        if min_speech_ms <= duration_ms <= max_speech_ms:
            # Transcribe the buffered utterance.
            prompt = ASR_model.transcribe(array)["text"].strip()
            chat.append({"role": user_role, "content": prompt})
            chat_messages = chat.to_list()
            prompt = LM_tokenizer.apply_chat_template(
                chat_messages, tokenize=False, add_generation_prompt=True
            )
            # Generate the assistant reply with the MLX language model.
            output = generate(
                LM_model,
                LM_tokenizer,
                prompt,
                max_tokens=128,
            )
            generated_text = output.replace("<|end|>", "")
            torch.mps.empty_cache()

            chat.append({"role": "assistant", "content": generated_text})
            text_str = generated_text
            # Synthesize the reply and hand it back to Gradio as 16-bit PCM
            # at MeloTTS's native 44.1 kHz rate.
            audio_chunk = tts_model.tts_to_file(text_str, speaker_id, quiet=True)
            audio_chunk = (audio_chunk * 32768).astype(np.int16)
            audio_output = (44100, audio_chunk)

    return stream, text_str, audio_output
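
# Example of a single manual call (illustrative sketch only; not executed by the app):
#
#     silent_chunk = (44100, np.zeros(44100, dtype=np.int16))
#     _, text, audio = transcribe(None, silent_chunk)
#
# With pure silence the VAD should not fire, so `text` stays "" and `audio` stays None.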

demo = gr.Interface(
    transcribe,
    ["state", gr.Audio(sources=["microphone"], streaming=True, waveform_options=gr.WaveformOptions(sample_rate=16000))],
    ["state", "text", gr.Audio(label="Output", autoplay=True)],
    live=True,
)
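# The "state" input/output pair carries the `stream` value between successive
# calls, live=True makes Gradio invoke `transcribe` on each streamed microphone
# chunk, and autoplay=True plays each synthesized reply as soon as it is returned.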
# with demo:
#     start_button = gr.Button("Record Screen 🔴")
#     video_component = gr.Video(interactive=True, show_share_button=True, include_audio=True)


#     def toggle_button_label(returned_string):
#         if returned_string.startswith("Record"):
#             return gr.Button(value="Stop Recording ⚪"), None
#         else:
#             try:
#                 temp_filename = save_base64_video(returned_string)
#             except Exception as e:
#                 return gr.Button(value="Record Screen 🔴"), gr.Warning(f'Failed to convert video to mp4:\n{e}')
#             return gr.Button(value="Record Screen 🔴"), gr.Video(value=temp_filename, interactive=True,
#                                                                 show_share_button=True)
#     start_button.click(toggle_button_label, start_button, [start_button, video_component], js=record_button_js)
demo.launch()