import torch
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import soundfile as sf
import numpy as np
from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
from fairseq.models.text_to_speech.hub_interface import TTSHubInterface
import IPython.display as ipd  # Only needed when previewing audio in a notebook; unused by the Gradio app

# --- Whisper (ASR) Setup ---
ASR_MODEL_NAME = "openai/whisper-large-v2"
asr_device = "cuda" if torch.cuda.is_available() else "cpu"
asr_pipe = pipeline(
    task="automatic-speech-recognition",
    model=ASR_MODEL_NAME,
    chunk_length_s=30,
    device=asr_device,
)
all_special_ids = asr_pipe.tokenizer.all_special_ids
transcribe_token_id = all_special_ids[-5]
translate_token_id = all_special_ids[-6]

# --- FastSpeech2 (TTS) Setup - Using fairseq ---
TTS_MODEL_NAME = "facebook/fastspeech2-en-ljspeech"
tts_device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the fairseq model, config, and task.
tts_models, tts_cfg, tts_task = load_model_ensemble_and_task_from_hf_hub(
    TTS_MODEL_NAME,
    arg_overrides={"vocoder": "hifigan", "fp16": False},
)
tts_model = tts_models[0]
TTSHubInterface.update_cfg_with_data_cfg(tts_cfg, tts_task.data_cfg)
# build_generator expects a list of models.
tts_generator = tts_task.build_generator([tts_model], tts_cfg)

# Move the fairseq model to the correct device.
tts_model.to(tts_device)
tts_model.eval()  # Put the model in evaluation mode

# --- Vicuna (LLM) Setup ---
VICUNA_MODEL_NAME = "lmsys/vicuna-7b-v1.5"  # Or your preferred Vicuna
vicuna_device = "cuda" if torch.cuda.is_available() else "cpu"
vicuna_tokenizer = AutoTokenizer.from_pretrained(VICUNA_MODEL_NAME)
vicuna_model = AutoModelForCausalLM.from_pretrained(
    VICUNA_MODEL_NAME,
    load_in_8bit=True,
    torch_dtype=torch.float16,
    device_map="auto",
)


# --- ASR Function ---
def transcribe_audio(microphone, state, task="transcribe"):
    if microphone is None:
        return state, state

    asr_pipe.model.config.forced_decoder_ids = [
        [2, transcribe_token_id if task == "transcribe" else translate_token_id]
    ]
    text = asr_pipe(microphone)["text"]

    # --- Vicuna integration ---
    system_prompt = """You are a friendly and enthusiastic tutor for young children (ages 6-9). You answer questions clearly and simply, using age-appropriate language.
You are also a little bit silly and like to make jokes."""
    prompt = f"{system_prompt}\nUser: {text}"

    with torch.no_grad():
        vicuna_input = vicuna_tokenizer(prompt, return_tensors="pt").to(vicuna_device)
        vicuna_output = vicuna_model.generate(**vicuna_input, max_new_tokens=128)
        vicuna_response = vicuna_tokenizer.decode(vicuna_output[0], skip_special_tokens=True)
    # Strip the echoed prompt so only the model's reply remains.
    vicuna_response = vicuna_response.replace(prompt, "").strip()

    updated_state = state + "\n" + vicuna_response
    return updated_state, updated_state


# --- TTS Function (Modified for fairseq) ---
def synthesize_speech(text):
    try:
        sample = TTSHubInterface.get_model_input(tts_task, text)
        # Move input tensors to the correct device; some entries (e.g. prev_output_tokens) may be None.
        sample["net_input"] = {
            k: v.to(tts_device) if torch.is_tensor(v) else v
            for k, v in sample["net_input"].items()
        }
        wav, rate = TTSHubInterface.get_prediction(tts_task, tts_model, tts_generator, sample)
        wav_numpy = wav.cpu().numpy()  # fairseq returns a tensor, not a NumPy array
        return (rate, wav_numpy)  # Gradio's Audio(type="numpy") expects (sample_rate, array)
    except Exception as e:
        print(e)
        return (None, None)


# --- Gradio Interface ---
with gr.Blocks(title="Whisper, Vicuna, & FastSpeech2 Demo") as demo:
    gr.Markdown("# Speech-to-Text-to-Speech Demo with Vicuna")
    gr.Markdown(
        "Speak into your microphone, get a transcription, Vicuna will process it, "
        "and then you'll hear the result!"
    )

    with gr.Tab("Transcribe & Synthesize"):
        mic_input = gr.Audio(source="microphone", type="filepath", label="Speak Here")
        transcription_output = gr.Textbox(lines=5, label="Transcription and Vicuna Response")
        audio_output = gr.Audio(label="Synthesized Speech", type="numpy")
        transcription_state = gr.State(value="")

        mic_input.change(
            fn=transcribe_audio,
            inputs=[mic_input, transcription_state],
            outputs=[transcription_output, transcription_state],
        ).then(
            fn=synthesize_speech,
            inputs=transcription_output,
            outputs=audio_output,
        )

demo.launch(enable_queue=True, share=False)
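# --- Optional: standalone TTS smoke test (illustrative sketch, not part of the demo) ---
# A quick way to check the fairseq FastSpeech2 path without opening the Gradio UI.
# Since demo.launch() above blocks until the server stops, run this snippet separately
# (or temporarily comment out the launch). The output file name is arbitrary.
#
# rate, wav = synthesize_speech("Hello! This is a quick FastSpeech2 test.")
# if wav is not None:
#     sf.write("tts_smoke_test.wav", wav, rate)  # soundfile is already imported as sf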