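# Gradio Space: streaming voice-chat demo.
# Per microphone chunk, the pipeline is: webrtcvad speech detection -> OWSM-CTC ASR
# -> Llama-3.2-1B-Instruct reply (token-streamed) -> ESPnet VITS TTS, yielded back
# to the browser as text plus autoplayed audio.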
# import base64
# import pathlib
# import tempfile
import os

# Download the UniDic dictionary used by the TTS text frontend.
os.system("python -m unidic download")

import nltk

nltk.download('averaged_perceptron_tagger_eng')
nltk.download('punkt_tab')
from nltk import sent_tokenize
import gradio as gr
from espnet2.bin.tts_inference import Text2Speech
from espnet2.utils.types import str_or_none
lang = 'English'
tag = 'kan-bayashi/ljspeech_vits' #@param ["kan-bayashi/ljspeech_tacotron2", "kan-bayashi/ljspeech_fastspeech", "kan-bayashi/ljspeech_fastspeech2", "kan-bayashi/ljspeech_conformer_fastspeech2", "kan-bayashi/ljspeech_joint_finetune_conformer_fastspeech2_hifigan", "kan-bayashi/ljspeech_joint_train_conformer_fastspeech2_hifigan", "kan-bayashi/ljspeech_vits"] {type:"string"}
vocoder_tag = "none"
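# Load the VITS TTS model from the local ESPnet checkpoint (LJSpeech recipe).
# Most of the decoding options below only apply to other architectures, as noted inline.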
text2speech = Text2Speech.from_pretrained(
    train_config="tts_model/exp/tts_train_vits_raw_phn_tacotron_g2p_en_no_space/config.yaml",
    model_file="tts_model/exp/tts_train_vits_raw_phn_tacotron_g2p_en_no_space/train.total_count.ave_10best.pth",
    vocoder_tag=str_or_none(vocoder_tag),
    device="cuda",
    # Only for Tacotron 2 & Transformer
    threshold=0.5,
    # Only for Tacotron 2
    minlenratio=0.0,
    maxlenratio=10.0,
    use_att_constraint=False,
    backward_window=1,
    forward_window=3,
    # Only for FastSpeech & FastSpeech2 & VITS
    speed_control_alpha=1.0,
    # Only for VITS
    noise_scale=0.333,
    noise_scale_dur=0.333,
)
# recorder_js = pathlib.Path('recorder.js').read_text()
# main_js = pathlib.Path('main.js').read_text()
# record_button_js = pathlib.Path('record_button.js').read_text().replace('let recorder_js = null;', recorder_js).replace(
#     'let main_js = null;', main_js)

# def save_base64_video(base64_string):
#     base64_video = base64_string
#     video_data = base64.b64decode(base64_video)
#     with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
#         temp_filename = temp_file.name
#         temp_file.write(video_data)
#     print(f"Temporary MP4 file saved as: {temp_filename}")
#     return temp_filename

# import os
# os.system('python -m unidic download')
import numpy as np
from VAD.vad_iterator import VADIterator
import torch
import librosa

# from mlx_lm import load, stream_generate, generate
from LLM.chat import Chat

# from lightning_whisper_mlx import LightningWhisperMLX
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
    TextIteratorStreamer,
)
# from melo.api import TTS
# You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise and complete responses of less than 15 words.
# LM_model, LM_tokenizer = load("mlx-community/SmolLM-360M-Instruct")
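# Rolling conversation state; Chat(2) presumably caps the history at the system prompt
# plus the two most recent exchanges.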
chat = Chat(2)
chat.init_chat({"role": "system", "content": "You are a helpful and friendly AI assistant. The user is talking to you with their voice and you should respond in a conversational style. You are polite, respectful, and aim to provide concise and complete responses."})
user_role = "user"
# tts_model = TTS(language="EN_NEWEST", device="auto")
# speaker_id = tts_model.hps.data.spk2id["EN-Newest"]
blocksize = 512
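# Warm-up: run the TTS model once and build a dummy waveform for the ASR warm-up below,
# so the first real request does not pay the lazy-initialization cost.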
with torch.no_grad():
    wav = text2speech("Sid")["wav"]
    # tts_model.tts_to_file("text", speaker_id, quiet=True)
dummy_input = torch.randn(
    3000,
    dtype=torch.float16,
    device="cpu",
).cpu().numpy()
import soundfile as sf
import kaldiio
from espnet2.bin.s2t_inference_ctc import Speech2TextGreedySearch
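# ASR: OWSM-CTC v3.1 (1B) with greedy CTC decoding, fixed to English transcription.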
s2t = Speech2TextGreedySearch.from_pretrained(
    "pyf98/owsm_ctc_v3.1_1B",
    device="cuda",
    generate_interctc_outputs=False,
    lang_sym='<eng>',
    task_sym='<asr>',
)
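# Warm up the ASR model on a padded 30 s input and time it with CUDA events.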
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start_event.record()
speech = librosa.util.fix_length(dummy_input, size=(16000 * 30))
res = s2t(speech)
end_event.record()
torch.cuda.synchronize()
def int2float(sound):
    """
    Convert int16 PCM samples to float32 in [-1, 1].
    Taken from https://github.com/snakers4/silero-vad
    """
    abs_max = np.abs(sound).max()
    sound = sound.astype("float32")
    if abs_max > 0:
        sound *= 1 / 32768
    sound = sound.squeeze()  # depends on the use case
    return sound
text_str = ""
asr_output_str = ""
vad_output = None
audio_output = None
min_speech_ms = 500           # discard utterances shorter than this
max_speech_ms = float("inf")  # no upper bound on utterance length
# ASR_model = LightningWhisperMLX(model="distil-large-v3", batch_size=6, quant=None)
# ASR_processor = AutoProcessor.from_pretrained("distil-whisper/distil-large-v3")
# ASR_model = AutoModelForSpeechSeq2Seq.from_pretrained(
#     "distil-whisper/distil-large-v3",
#     torch_dtype="float16",
# ).to("cpu")
access_token = os.environ.get("HF_TOKEN")
LM_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct", token=access_token)
LM_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct", torch_dtype="float16", trust_remote_code=True, token=access_token
).to("cuda")
LM_pipe = pipeline(
    "text-generation", model=LM_model, tokenizer=LM_tokenizer, device="cuda"
)
streamer = TextIteratorStreamer(
    LM_tokenizer,
    skip_prompt=True,
    skip_special_tokens=True,
)
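# Warm up the LLM pipeline with a short dummy prompt and time it with CUDA events.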
dummy_input_text = "Write me a poem about Machine Learning."
dummy_chat = [{"role": "user", "content": dummy_input_text}]
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start_event.record()
LM_pipe(
    dummy_chat,
    max_new_tokens=32,
    min_new_tokens=0,
    temperature=0.0,
    do_sample=False,
    streamer=streamer,
    return_full_text=False,
)
for a in streamer:
    print(a)
end_event.record()
torch.cuda.synchronize()
# vad_model, _ = torch.hub.load("snakers4/silero-vad:v4.0", "silero_vad")
# vad_iterator = VADIterator(
#     vad_model,
#     threshold=0.3,
#     sampling_rate=16000,
#     min_silence_duration_ms=250,
#     speech_pad_ms=500,
# )
import webrtcvad
import time
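# Main streaming callback (a generator): for each microphone chunk, run frame-level VAD,
# and once speech ends, transcribe it, generate a reply with the LLM, and stream back
# sentence-by-sentence TTS audio.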
def transcribe(stream, new_chunk):
    sr, y = new_chunk
    global text_str
    global chat
    global user_role
    global audio_output
    global vad_output
    global asr_output_str
    if stream is None:
        # First chunk of a new session: reset the conversation and the output buffers.
        stream = True
        chat.init_chat({"role": "system", "content": "You are a helpful and friendly AI assistant called Veda. You should respond in a conversational style. You are polite, respectful, and aim to provide concise and complete responses of less than 20 words."})
        text_str = ""
        audio_output = None
    orig_sr = sr
    # Convert the int16 microphone samples to float32 and resample to 16 kHz for ASR.
    audio_int16 = np.frombuffer(y, dtype=np.int16)
    audio_float32 = int2float(audio_int16)
    audio_float32 = librosa.resample(audio_float32, orig_sr=sr, target_sr=16000)
    sr = 16000
    print(sr)
    print(audio_float32.shape)
    # vad_output = vad_iterator(torch.from_numpy(audio_float32))
    # Frame-level VAD on the raw chunk. This assumes the microphone stream arrives at a
    # webrtcvad-supported rate (e.g. 48 kHz, where 960 samples is a 20 ms frame).
    vad = webrtcvad.Vad()
    vad.set_mode(3)
    vad_count = 0
    for i in range(len(y) // 960):
        if vad.is_speech(y[i * 960:(i + 1) * 960].tobytes(), orig_sr):
            vad_count += 1
    print(vad_count)
    # More voiced frames are required to start an utterance (>12) than to extend one (>10).
    if vad_output is None and vad_count > 12:
        vad_curr = True
        vad_output = [torch.from_numpy(audio_float32)]
    elif vad_output is not None and vad_count > 10:
        vad_curr = True
        vad_output.append(torch.from_numpy(audio_float32))
    else:
        vad_curr = False
    if vad_output is not None and not vad_curr:
        print("VAD: end of speech detected")
        array = torch.cat(vad_output).cpu().numpy()
        duration_ms = len(array) / sr * 1000
        if not (duration_ms < min_speech_ms or duration_ms > max_speech_ms):
            # input_features = ASR_processor(
            #     array, sampling_rate=16000, return_tensors="pt"
            # ).input_features
            # print(input_features)
            # input_features = input_features.to("cpu", dtype=getattr(torch, "float16"))
            # pred_ids = ASR_model.generate(input_features, max_new_tokens=128, min_new_tokens=0, num_beams=1, return_timestamps=False, task="transcribe", language="en")
            # print(pred_ids)
            # prompt = ASR_processor.batch_decode(
            #     pred_ids, skip_special_tokens=True, decode_with_timestamps=False
            # )[0]
            print(len(array))
            # Pad (or trim) the utterance to the 30 s input the OWSM-CTC model expects.
            array = librosa.util.fix_length(array, size=(16000 * 30))
            print(len(array))
            start_time = time.time()
            # Drop the leading tag token from the decoded text.
            prompt = " ".join(s2t(array)[0][0].split()[1:])
            vad_output = None
            if len(prompt.strip().split()) < 2:
                # Too short to be a real utterance. Yield the current state rather than
                # returning values: inside a generator, a return value would be dropped.
                text_str1 = text_str
                yield stream, asr_output_str, text_str1, audio_output
                return
            # prompt = transcriber({"sampling_rate": sr, "raw": array})["text"]
            print(len(prompt.strip().split()))
            asr_output_str = prompt
            yield stream, asr_output_str, text_str, audio_output
            print("--- %s seconds ---" % (time.time() - start_time))
            # prompt = ASR_model.transcribe(array)["text"].strip()
            chat.append({"role": user_role, "content": prompt})
            chat_messages = chat.to_list()
            # Generate the reply; the streamer buffers the tokens, which are read out
            # sentence by sentence below.
            LM_pipe(
                chat_messages,
                max_new_tokens=256,
                min_new_tokens=0,
                temperature=0.0,
                do_sample=False,
                streamer=streamer,
                return_full_text=False,
            )
            output = ""
            curr_output = ""
            text_str = ""
            # Read the generated tokens and synthesize speech one sentence at a time,
            # yielding each audio chunk so playback can start before the reply is complete.
            for t in streamer:
                output += t
                curr_output += t
                sentences = sent_tokenize(curr_output)
                if len(sentences) > 1:
                    print("--- %s seconds ---" % (time.time() - start_time))
                    print(sentences[0])
                    with torch.no_grad():
                        audio_chunk = text2speech(sentences[0])["wav"].view(-1).cpu().numpy()
                        # audio_chunk = tts_model.tts_to_file(text_str, speaker_id, quiet=True)
                    audio_chunk = (audio_chunk * 32768).astype(np.int16)
                    print(text2speech.fs)
                    audio_output = (text2speech.fs, audio_chunk)
                    print("okk")
                    # print(audio_chunk)
                    # print(audio_chunk.shape)
                    text_str = text_str + sentences[0]
                    print("--- %s seconds ---" % (time.time() - start_time))
                    yield stream, asr_output_str, text_str, audio_output
                    # Sleep for roughly the clip duration so chunks do not overlap on playback.
                    time.sleep((len(audio_chunk) / text2speech.fs) - 0.2)
                    curr_output = t
print("--- %s seconds ---" % (time.time() - start_time)) | |
print(curr_output) | |
with torch.no_grad(): | |
audio_chunk = text2speech(curr_output)["wav"].view(-1).cpu().numpy() | |
# audio_chunk = tts_model.tts_to_file(text_str, speaker_id, quiet=True) | |
audio_chunk = (audio_chunk * 32768).astype(np.int16) | |
print(text2speech.fs) | |
audio_output=(text2speech.fs, audio_chunk) | |
print("okk") | |
# print(audio_chunk) | |
text_str=output | |
print(audio_chunk.shape) | |
print("--- %s seconds ---" % (time.time() - start_time)) | |
# yield (stream,output, audio_output) | |
# time.sleep((len(audio_chunk)/text2speech.fs)-0.2) | |
curr_output = "" | |
generated_text = output | |
# torch.mps.empty_cache() | |
chat.append({"role": "assistant", "content": generated_text}) | |
# text_str=generated_text | |
            # import pdb; pdb.set_trace()
            # with torch.no_grad():
            #     audio_chunk = text2speech(text_str)["wav"].view(-1).cpu().numpy()
            #     # audio_chunk = tts_model.tts_to_file(text_str, speaker_id, quiet=True)
            # audio_chunk = (audio_chunk * 32768).astype(np.int16)
            # print(text2speech.fs)
            # audio_output = (text2speech.fs, audio_chunk)
            # else:
            #     audio_output = None
    # Always emit the current state so the UI stays in sync even when no speech ended.
    text_str1 = text_str
    yield stream, asr_output_str, text_str1, audio_output
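# Live Gradio interface: the session state plus the streaming microphone feed go in;
# the updated state, ASR text, LLM text, and autoplayed TTS audio come back out.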
demo = gr.Interface(
    transcribe,
    ["state", gr.Audio(sources=["microphone"], streaming=True, waveform_options=gr.WaveformOptions(sample_rate=16000))],
    ["state", gr.Textbox(label="ASR output"), gr.Textbox(label="LLM output"), gr.Audio(label="TTS Output", autoplay=True, visible=True)],
    live=True,
)
# with demo:
#     start_button = gr.Button("Record Screen 🔴")
#     video_component = gr.Video(interactive=True, show_share_button=True, include_audio=True)
#     def toggle_button_label(returned_string):
#         if returned_string.startswith("Record"):
#             return gr.Button(value="Stop Recording ⚪"), None
#         else:
#             try:
#                 temp_filename = save_base64_video(returned_string)
#             except Exception as e:
#                 return gr.Button(value="Record Screen 🔴"), gr.Warning(f'Failed to convert video to mp4:\n{e}')
#             return gr.Button(value="Record Screen 🔴"), gr.Video(value=temp_filename, interactive=True,
#                                                                  show_share_button=True)
#     start_button.click(toggle_button_label, start_button, [start_button, video_component], js=record_button_js)
demo.launch(share=True)