Spaces:
Sleeping
Sleeping
# import base64 | |
# import pathlib | |
# import tempfile | |
import gradio as gr | |
# recorder_js = pathlib.Path('recorder.js').read_text() | |
# main_js = pathlib.Path('main.js').read_text() | |
# record_button_js = pathlib.Path('record_button.js').read_text().replace('let recorder_js = null;', recorder_js).replace( | |
# 'let main_js = null;', main_js) | |
# def save_base64_video(base64_string): | |
# base64_video = base64_string | |
# video_data = base64.b64decode(base64_video) | |
# with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file: | |
# temp_filename = temp_file.name | |
# temp_file.write(video_data) | |
# print(f"Temporary MP4 file saved as: {temp_filename}") | |
# return temp_filename | |
# import os | |
# os.system('python -m unidic download') | |
from transformers import pipeline | |
import numpy as np | |
from VAD.vad_iterator import VADIterator | |
import torch | |
import librosa | |
from mlx_lm import load, stream_generate, generate | |
from LLM.chat import Chat | |
from lightning_whisper_mlx import LightningWhisperMLX | |
from melo.api import TTS | |
LM_model, LM_tokenizer = load("mlx-community/SmolLM-360M-Instruct") | |
chat = Chat(2) | |
chat.init_chat({"role": "system", "content": "You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words."}) | |
user_role = "user" | |
tts_model = TTS(language="EN_NEWEST", device="auto") | |
speaker_id = tts_model.hps.data.spk2id["EN-Newest"] | |
blocksize = 512 | |
def int2float(sound): | |
""" | |
Taken from https://github.com/snakers4/silero-vad | |
""" | |
abs_max = np.abs(sound).max() | |
sound = sound.astype("float32") | |
if abs_max > 0: | |
sound *= 1 / 32768 | |
sound = sound.squeeze() # depends on the use case | |
return sound | |
text_str="" | |
audio_output = None | |
min_speech_ms=500 | |
max_speech_ms=float("inf") | |
ASR_model = LightningWhisperMLX(model="distil-large-v3", batch_size=6, quant=None) | |
transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-base.en") | |
vad_model, _ = torch.hub.load("snakers4/silero-vad:v4.0", "silero_vad") | |
vad_iterator = VADIterator( | |
vad_model, | |
threshold=0.3, | |
sampling_rate=16000, | |
min_silence_duration_ms=250, | |
speech_pad_ms=500, | |
) | |
def transcribe(stream, new_chunk): | |
sr, y = new_chunk | |
global text_str | |
global chat | |
global user_role | |
global audio_output | |
audio_int16 = np.frombuffer(y, dtype=np.int16) | |
audio_float32 = int2float(audio_int16) | |
audio_float32=librosa.resample(audio_float32, orig_sr=sr, target_sr=16000) | |
sr=16000 | |
print(sr) | |
print(audio_float32.shape) | |
vad_output = vad_iterator(torch.from_numpy(audio_float32)) | |
if vad_output is not None and len(vad_output) != 0: | |
print("VAD: end of speech detected") | |
array = torch.cat(vad_output).cpu().numpy() | |
duration_ms = len(array) / sr * 1000 | |
if (not(duration_ms < min_speech_ms or duration_ms > max_speech_ms)): | |
prompt=ASR_model.transcribe(array)["text"].strip() | |
chat.append({"role": user_role, "content": prompt}) | |
chat_messages = chat.to_list() | |
prompt = LM_tokenizer.apply_chat_template( | |
chat_messages, tokenize=False, add_generation_prompt=True | |
) | |
output = generate( | |
LM_model, | |
LM_tokenizer, | |
prompt, | |
max_tokens=128, | |
) | |
# import pdb;pdb.set_trace() | |
generated_text = output.replace("<|end|>", "") | |
torch.mps.empty_cache() | |
chat.append({"role": "assistant", "content": generated_text}) | |
text_str=generated_text | |
# import pdb;pdb.set_trace() | |
audio_chunk = tts_model.tts_to_file(text_str, speaker_id, quiet=True) | |
audio_chunk = (audio_chunk * 32768).astype(np.int16) | |
audio_output=(44100, audio_chunk) | |
# else: | |
# audio_output=None | |
text_str1=text_str | |
return stream, text_str1, audio_output | |
demo = gr.Interface( | |
transcribe, | |
["state", gr.Audio(sources=["microphone"], streaming=True, waveform_options=gr.WaveformOptions(sample_rate=16000))], | |
["state", "text", gr.Audio(label="Output", autoplay=True)], | |
live=True, | |
) | |
# with demo: | |
# start_button = gr.Button("Record Screen 🔴") | |
# video_component = gr.Video(interactive=True, show_share_button=True, include_audio=True) | |
# def toggle_button_label(returned_string): | |
# if returned_string.startswith("Record"): | |
# return gr.Button(value="Stop Recording ⚪"), None | |
# else: | |
# try: | |
# temp_filename = save_base64_video(returned_string) | |
# except Exception as e: | |
# return gr.Button(value="Record Screen 🔴"), gr.Warning(f'Failed to convert video to mp4:\n{e}') | |
# return gr.Button(value="Record Screen 🔴"), gr.Video(value=temp_filename, interactive=True, | |
# show_share_button=True) | |
# start_button.click(toggle_button_label, start_button, [start_button, video_component], js=record_button_js) | |
demo.launch() | |