import gradio as gr
import numpy as np
import json
import librosa
import os
import soundfile as sf
import tempfile
import uuid
import transformers
import torch
import time
import spaces

from nemo.collections.asr.models import ASRModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

### Variables ###

# Read the Hugging Face token from the environment (needed to download the gated Llama 3 weights)
HF_TOKEN = os.environ.get("HF_TOKEN", None)

SAMPLE_RATE = 16000  # Hz
MAX_AUDIO_SECONDS = 40  # won't try to transcribe if longer than this

DESCRIPTION = '''

<h1>MyAlexa: Voice Chat Assistant</h1>

<p>MyAlexa is a demo of a voice chat assistant with chat logs that accepts audio input and outputs an AI response.</p>

<p>This space uses NVIDIA Canary 1B for automatic speech recognition (ASR), Meta Llama 3 8B Instruct as the large language model (LLM), and VITS-ljs by Kakao Enterprise for text-to-speech (TTS).</p>

<p>This demo accepts audio inputs no more than 40 seconds long. Transcription and responses are limited to the English language.</p>

<p>The LLM's max_new_tokens, temperature, and top_p are set to 512, 0.6, and 0.9, respectively.</p>

'''

PLACEHOLDER = """
What's on your mind?
"""
""" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") ### ASR model ### canary_model = ASRModel.from_pretrained("nvidia/canary-1b").to(device) canary_model.eval() # make sure beam size always 1 for consistency canary_model.change_decoding_strategy(None) decoding_cfg = canary_model.cfg.decoding decoding_cfg.beam.beam_size = 1 canary_model.change_decoding_strategy(decoding_cfg) ### LLM model ### # Load the tokenizer and model llm_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct") llama3_model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", device_map="auto") # to("cuda:0") if llm_tokenizer.pad_token is None: llm_tokenizer.pad_token = llm_tokenizer.eos_token terminators = [ llm_tokenizer.eos_token_id, llm_tokenizer.convert_tokens_to_ids("<|eot_id|>") ] ### TTS model ### pipe = pipeline("text-to-speech", model="kakao-enterprise/vits-ljs", device=device) ### Start of functions ### def convert_audio(audio_filepath, tmpdir, utt_id): """ Convert all files to monochannel 16 kHz wav files. Do not convert and raise error if audio is too long. Returns output filename and duration. """ data, sr = librosa.load(audio_filepath, sr=None, mono=True) duration = librosa.get_duration(y=data, sr=sr) if duration > MAX_AUDIO_SECONDS: raise gr.Error( f"This demo can transcribe up to {MAX_AUDIO_SECONDS} seconds of audio. " "If you wish, you may trim the audio using the Audio viewer in Step 1 " "(click on the scissors icon to start trimming audio)." ) if sr != SAMPLE_RATE: data = librosa.resample(data, orig_sr=sr, target_sr=SAMPLE_RATE) out_filename = os.path.join(tmpdir, utt_id + '.wav') # save output audio sf.write(out_filename, data, SAMPLE_RATE) return out_filename, duration def transcribe(audio_filepath): """ Transcribes a converted audio file using the asr model. Set to the english language with punctuations. Returns the transcribed text as a string. """ if audio_filepath is None: raise gr.Error("Please provide some input audio: either upload an audio file or use the microphone") utt_id = uuid.uuid4() with tempfile.TemporaryDirectory() as tmpdir: converted_audio_filepath, duration = convert_audio(audio_filepath, tmpdir, str(utt_id)) # make manifest file and save manifest_data = { "audio_filepath": converted_audio_filepath, "source_lang": "en", "target_lang": "en", "taskname": "asr", "pnc": "yes", "answer": "predict", "duration": str(duration), } manifest_filepath = os.path.join(tmpdir, f'{utt_id}.json') with open(manifest_filepath, 'w') as fout: line = json.dumps(manifest_data) fout.write(line + '\n') # call transcribe, passing in manifest filepath output_text = canary_model.transcribe(manifest_filepath)[0] return output_text def add_message(history, message): """ Adds the input message in the chatbot. Returns the updated chatbot. """ history.append((message, None)) return history def bot(history, message): """ Gets the bot's response and adds it in the chatbot. Returns the appended chatbot. """ response = bot_response(message, history) lines = response.split("\n") complete_lines = '\n'.join(lines[2:]) answer = "" for character in complete_lines: answer += character new_tuple = list(history[-1]) new_tuple[1] = answer history[-1] = tuple(new_tuple) time.sleep(0.01) yield history @spaces.GPU() def bot_response(message, history): """ Generates a streaming response using the llm model. Set max_new_tokens = 512, temperature=0.6, and top_p=0.9 Returns the generated response in string format. 
""" conversation = [] for user, assistant in history: conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}]) conversation.append({"role": "user", "content": message}) input_ids = llm_tokenizer.apply_chat_template(conversation, return_tensors="pt").to(llama3_model.device) outputs = llama3_model.generate( input_ids, max_new_tokens = 512, eos_token_id = terminators, do_sample=True, temperature=0.6, top_p=0.9, pad_token_id=llm_tokenizer.pad_token_id, ) out = outputs[0][input_ids.shape[-1]:] return llm_tokenizer.decode(out, skip_special_tokens=True) @spaces.GPU() def voice_player(history): """ Plays the generated response using the tts model. Returns the audio player with the generated response. """ _, text = history[-1] text = text.replace("*", "") # Temp. fix: For the tts to not read the asterisk of bold text voice = pipe(text) voice = gr.Audio(value = ( voice["sampling_rate"], voice["audio"].squeeze()), type="numpy", autoplay=True, label="MyAlexa Response", show_label=True, visible=True) return voice ### End of functions ### ### Interface using Blocks### with gr.Blocks( title="MyAlexa", css=""" textarea { font-size: 18px;} """, theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg) # make text slightly bigger (default is text_md ) ) as demo: gr.HTML(DESCRIPTION) chatbot = gr.Chatbot( [], elem_id="chatbot", bubble_full_width=False, placeholder=PLACEHOLDER, label='MyAlexa' ) with gr.Row(): with gr.Column(): gr.HTML( "

### End of functions ###

### Interface using Blocks ###
with gr.Blocks(
    title="MyAlexa",
    css="""
    textarea { font-size: 18px; }
    """,
    theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)  # make text slightly bigger (default is text_md)
) as demo:

    gr.HTML(DESCRIPTION)

    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        bubble_full_width=False,
        placeholder=PLACEHOLDER,
        label='MyAlexa'
    )

    with gr.Row():
        with gr.Column():
            gr.HTML(
                "<p>Step 1: Upload an audio file or record with your microphone.</p>"
            )

            audio_file = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath"
            )

        with gr.Column():
            gr.HTML("<p>Step 2: Submit your recorded or uploaded audio as input and wait for MyAlexa's response.</p>")
") submit_button = gr.Button( value="Submit audio", variant="primary" ) chat_input = gr.Textbox( # Shows the transcribed text label="Transcribed text:", interactive=False, placeholder="Transcribed text will appear here.", elem_id="chat_input", visible=False # set to True to see processing time of asr transcription ) gr.HTML("

[Optional]: Replay MyAlexa's voice response.

") out_audio = gr.Audio( # Shows an audio player for the generated response value = None, label="Response Audio Player", show_label=True, visible=False # set to True to see processing time of the first tts audio generation ) chat_msg = chat_input.change(add_message, [chatbot, chat_input], [chatbot], api_name="add_message_in_chatbot") bot_msg = chat_msg.then(bot, [chatbot, chat_input], chatbot, api_name="bot_response_in_chatbot") voice_msg = bot_msg.then(voice_player, chatbot, out_audio, api_name="bot_response_voice_player") submit_button.click( fn=transcribe, inputs = [audio_file], outputs = [chat_input] ) ### Queue and launch the demo ### demo.queue() if __name__ == "__main__": demo.launch()