# import base64
# import pathlib
# import tempfile
import gradio as gr
# recorder_js = pathlib.Path('recorder.js').read_text()
# main_js = pathlib.Path('main.js').read_text()
# record_button_js = pathlib.Path('record_button.js').read_text().replace('let recorder_js = null;', recorder_js).replace(
# 'let main_js = null;', main_js)
# def save_base64_video(base64_string):
# base64_video = base64_string
# video_data = base64.b64decode(base64_video)
# with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
# temp_filename = temp_file.name
# temp_file.write(video_data)
# print(f"Temporary MP4 file saved as: {temp_filename}")
# return temp_filename
# import os
# os.system('python -m unidic download')
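# Streaming speech-to-speech assistant: microphone audio is chunked by Gradio,
# segmented with Silero VAD, transcribed with Lightning Whisper MLX, answered by
# SmolLM-360M-Instruct via mlx_lm, and spoken back with MeloTTS.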
from transformers import pipeline
import numpy as np
from VAD.vad_iterator import VADIterator
import torch
import librosa
from mlx_lm import load, stream_generate, generate
from LLM.chat import Chat
from lightning_whisper_mlx import LightningWhisperMLX
from melo.api import TTS
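# Language model, chat history, and TTS setup. Chat(2) keeps a small rolling
# history buffer; the system prompt asks for short (<20 word) replies.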
LM_model, LM_tokenizer = load("mlx-community/SmolLM-360M-Instruct")
chat = Chat(2)
chat.init_chat({"role": "system", "content": "You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words."})
user_role = "user"
tts_model = TTS(language="EN_NEWEST", device="auto")
speaker_id = tts_model.hps.data.spk2id["EN-Newest"]
blocksize = 512
def int2float(sound):
    """
    Taken from https://github.com/snakers4/silero-vad
    """
    abs_max = np.abs(sound).max()
    sound = sound.astype("float32")
    if abs_max > 0:
        sound *= 1 / 32768
    sound = sound.squeeze()  # depends on the use case
    return sound
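# Module-level state shared with the streaming callback below.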
text_str = ""
audio_output = None
min_speech_ms = 500
max_speech_ms = float("inf")
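# Speech recognition and voice-activity detection: Lightning Whisper MLX handles
# transcription, Silero VAD (v4.0) flags end of speech on 16 kHz audio. The
# transformers `transcriber` pipeline below is loaded but not used by the callback.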
ASR_model = LightningWhisperMLX(model="distil-large-v3", batch_size=6, quant=None)
transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-base.en")
vad_model, _ = torch.hub.load("snakers4/silero-vad:v4.0", "silero_vad")
vad_iterator = VADIterator(
    vad_model,
    threshold=0.3,
    sampling_rate=16000,
    min_silence_duration_ms=250,
    speech_pad_ms=500,
)
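# Gradio streaming callback: called for every incoming audio chunk; returns the
# (unchanged) state, the latest assistant text, and (sample_rate, waveform) audio
# once a complete utterance has been transcribed and answered.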
def transcribe(stream, new_chunk):
    sr, y = new_chunk
    global text_str
    global chat
    global user_role
    global audio_output
    # Convert the incoming int16 chunk to float32 and resample to the 16 kHz
    # rate expected by Silero VAD.
    audio_int16 = np.frombuffer(y, dtype=np.int16)
    audio_float32 = int2float(audio_int16)
    audio_float32 = librosa.resample(audio_float32, orig_sr=sr, target_sr=16000)
    sr = 16000
    print(sr)
    print(audio_float32.shape)
    vad_output = vad_iterator(torch.from_numpy(audio_float32))
    if vad_output is not None and len(vad_output) != 0:
        print("VAD: end of speech detected")
        array = torch.cat(vad_output).cpu().numpy()
        duration_ms = len(array) / sr * 1000
        if not (duration_ms < min_speech_ms or duration_ms > max_speech_ms):
            # Transcribe the detected speech segment and add it to the chat history.
            prompt = ASR_model.transcribe(array)["text"].strip()
            chat.append({"role": user_role, "content": prompt})
            chat_messages = chat.to_list()
            prompt = LM_tokenizer.apply_chat_template(
                chat_messages, tokenize=False, add_generation_prompt=True
            )
            # Generate the assistant reply with the MLX language model.
            output = generate(
                LM_model,
                LM_tokenizer,
                prompt,
                max_tokens=128,
            )
            # import pdb;pdb.set_trace()
            generated_text = output.replace("<|end|>", "")
            torch.mps.empty_cache()
            chat.append({"role": "assistant", "content": generated_text})
            text_str = generated_text
            # import pdb;pdb.set_trace()
            # Synthesize the reply with MeloTTS and return it as (sample_rate, int16 array).
            audio_chunk = tts_model.tts_to_file(text_str, speaker_id, quiet=True)
            audio_chunk = (audio_chunk * 32768).astype(np.int16)
            audio_output = (44100, audio_chunk)
        # else:
        #     audio_output = None
    text_str1 = text_str
    return stream, text_str1, audio_output
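# Live interface: streaming microphone input in; state, assistant text, and
# autoplaying audio out.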
demo = gr.Interface(
    transcribe,
    ["state", gr.Audio(sources=["microphone"], streaming=True, waveform_options=gr.WaveformOptions(sample_rate=16000))],
    ["state", "text", gr.Audio(label="Output", autoplay=True)],
    live=True,
)
# with demo:
# start_button = gr.Button("Record Screen 🔴")
# video_component = gr.Video(interactive=True, show_share_button=True, include_audio=True)
# def toggle_button_label(returned_string):
# if returned_string.startswith("Record"):
# return gr.Button(value="Stop Recording ⚪"), None
# else:
# try:
# temp_filename = save_base64_video(returned_string)
# except Exception as e:
# return gr.Button(value="Record Screen 🔴"), gr.Warning(f'Failed to convert video to mp4:\n{e}')
# return gr.Button(value="Record Screen 🔴"), gr.Video(value=temp_filename, interactive=True,
# show_share_button=True)
# start_button.click(toggle_button_label, start_button, [start_button, video_component], js=record_button_js)
demo.launch()