import io
import json
import re
import time

import matplotlib.pyplot as plt
import streamlit as st
import torch
import torchaudio
import whisper
from audio_recorder_streamlit import audio_recorder
from transformers import AutoProcessor

from trainer import SpeechLLMLightning

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Default locations of the fine-tuned checkpoint and the audio processor.
CHECKPOINT_PATH = "checkpoints/pretrained_checkpoint.ckpt"
PROCESSOR_NAME = "facebook/hubert-large-ls960-ft"

# Patterns used by extract_dictionary, compiled once at import time.
# DOTALL lets the object span several lines (e.g. pretty-printed JSON).
_JSON_OBJECT_RE = re.compile(r'\{.*\}', re.DOTALL)
_UNQUOTED_KEY_RE = re.compile(r'(?<=\{|\,)\s*([^\"{}\[\]\s]+)\s*:')
_TRAILING_COMMA_RE = re.compile(r',\s*([\}\]])')


def plot_mel_spectrogram(mel_spec):
    """Render a (log-)mel spectrogram into the Streamlit page.

    Args:
        mel_spec: Tensor holding the spectrogram. Singleton dimensions are
            squeezed away before plotting (presumably leaving
            (n_mels, frames) -- TODO confirm). It may live on any device.
    """
    fig, ax = plt.subplots(figsize=(10, 4))
    image = ax.imshow(
        mel_spec.squeeze().detach().cpu().numpy(), aspect='auto', origin='lower'
    )
    fig.colorbar(image, ax=ax, format='%+2.0f dB')
    ax.set_title('Mel Spectrogram')
    fig.tight_layout()
    st.pyplot(fig)
    # Streamlit re-executes the script on every interaction. Close the figure
    # so matplotlib does not accumulate open figures across reruns.
    plt.close(fig)


def get_or_load_model(ckpt_path=CHECKPOINT_PATH, processor_name=PROCESSOR_NAME):
    """Load the model, tokenizer and processor at most once per session.

    The objects are cached in ``st.session_state`` so that Streamlit reruns
    reuse them instead of reloading the checkpoint.

    Args:
        ckpt_path: Lightning checkpoint for ``SpeechLLMLightning``.
        processor_name: Hugging Face id passed to ``AutoProcessor``.

    Returns:
        ``(model, tokenizer, processor)``.
    """
    state = st.session_state
    if not all(key in state for key in ("model", "tokenizer", "processor")):
        model = SpeechLLMLightning.load_from_checkpoint(ckpt_path, quantize=False)
        tokenizer = model.llm_tokenizer
        model.eval()
        model.freeze()
        model.to(device)
        state.model = model
        state.tokenizer = tokenizer
        state.processor = AutoProcessor.from_pretrained(processor_name)
    return state.model, state.tokenizer, state.processor


def extract_dictionary(input_string):
    """Extract and parse the outermost ``{...}`` object in *input_string*.

    The text from the first ``{`` to the last ``}`` is parsed as JSON. If
    strict parsing fails, two best-effort repairs are applied: quoting bare
    keys and dropping trailing commas. The text is then parsed again.

    Returns:
        The parsed ``dict`` on success. Otherwise a human-readable message
        string, either ``"No valid JSON found."`` or
        ``"Error parsing JSON: ..."``. Callers must check the type.
    """
    json_str_match = _JSON_OBJECT_RE.search(input_string)
    if not json_str_match:
        print(input_string)
        return "No valid JSON found."
    json_str = json_str_match.group(0)

    # Try the untouched text first. The repair regexes below are not
    # string-aware and would corrupt valid values like "one, two: three".
    try:
        return json.loads(json_str)
    except json.JSONDecodeError:
        pass  # not strict JSON; fall through to the best-effort repairs

    json_str = _UNQUOTED_KEY_RE.sub(r'"\1":', json_str)  # Fix unquoted keys
    json_str = _TRAILING_COMMA_RE.sub(r'\1', json_str)  # Remove trailing commas
    try:
        return json.loads(json_str)
    except json.JSONDecodeError as e:
        return f"Error parsing JSON: {str(e)}"


# Default prompt text placed before / after the speech embeddings.
# NOTE(review): these literals were recovered from a flattened copy of the
# file; confirm whether the original prompts contained line breaks.
pre_speech_prompt = '''Instruction: Give me the following information about the speech [Transcript, Gender, Age, Emotion, Accent] Input: '''
post_speech_prompt = ''' Output:'''


def generate_response(mel, pre_speech_prompt, post_speech_prompt, model,
                      tokenizer, max_new_tokens=2000):
    """Run the speech LLM on one utterance.

    Args:
        mel: Audio features handed to ``model.encode``. The UI passes the
            processor's ``input_values`` here, despite the parameter name.
        pre_speech_prompt: Text encoded before the audio.
        post_speech_prompt: Text encoded after the audio.
        model: Loaded ``SpeechLLMLightning`` instance.
        tokenizer: The model's LLM tokenizer.
        max_new_tokens: Upper bound on generated tokens.

    Returns:
        ``(output_text, latency_ms)``. The latency covers only the
        ``llm_model.generate`` call plus copying its result to the CPU.
    """
    def _ids(text):
        # Raw token ids with no padding, truncation or special tokens.
        return tokenizer(
            text, padding="do_not_pad", return_tensors='pt',
            truncation=False, add_special_tokens=False,
        )["input_ids"].to(device)

    output_prompt = '\n'
    with torch.no_grad():  # inference only; avoid building autograd graphs
        combined_embeds, _atts, _label_ids = model.encode(
            mel.to(device),
            _ids(pre_speech_prompt),
            _ids(post_speech_prompt),
            _ids(output_prompt),
        )
        # perf_counter is monotonic and meant for measuring durations.
        start_time = time.perf_counter()
        out = model.llm_model.generate(
            inputs_embeds=combined_embeds,
            max_new_tokens=max_new_tokens,
        ).cpu().tolist()[0]
        latency = (time.perf_counter() - start_time) * 1000  # milliseconds
    output_text = tokenizer.decode(out, skip_special_tokens=True)
    return output_text, latency


def extract_prediction_values(self, input_string):
    """Parse the first (non-nested) ``{...}`` object found in *input_string*.

    ``self`` is accepted only for signature compatibility and is ignored.
    If no braces are found, ``'{}'`` is parsed instead, yielding ``{}``.
    Returns whatever ``extract_dictionary`` returns.
    """
    match = re.search(r'\s*\{.*?\}\s*', input_string)
    json_str = match.group(0) if match is not None else '{}'
    return extract_dictionary(json_str)


# Load model and tokenizer once and store them in session_state
model, tokenizer, processor = get_or_load_model()
# Streamlit UI components
st.title("Multi-Modal Speech LLM")
st.write("Record an audio file to get its transcription and other metadata.")

pre_prompt = st.text_area("Pre Speech Prompt:", value=pre_speech_prompt, height=150)
post_prompt = st.text_area("Post Speech Prompt:", value=post_speech_prompt, height=100)

# Rate requested from the recorder and passed to the audio processor below.
# Recordings at any other rate are resampled to it.
TARGET_SAMPLE_RATE = 16000

# Audio recording
audio_data = audio_recorder(sample_rate=TARGET_SAMPLE_RATE)

# Transcription process
if audio_data is not None:
    with st.spinner('Transcribing...'):
        try:
            st.audio(audio_data, format='audio/wav', start_time=0)

            # Decode the recorded bytes and downmix all channels to mono.
            wav_tensor, sample_rate = torchaudio.load(io.BytesIO(audio_data))
            audio = wav_tensor.mean(0)
            # The recorder's requested rate is not guaranteed; resample when
            # the decoded rate differs from what the processor is told below.
            if sample_rate != TARGET_SAMPLE_RATE:
                audio = torchaudio.functional.resample(
                    audio, sample_rate, TARGET_SAMPLE_RATE
                )
            audio = audio.to(device)

            mel = whisper.log_mel_spectrogram(audio)
            plot_mel_spectrogram(mel)

            # The feature extractor documents NumPy/list input; hand it a
            # CPU array rather than a (possibly CUDA) tensor.
            input_values = processor(
                audio.squeeze().cpu().numpy(),
                return_tensors="pt",
                sampling_rate=TARGET_SAMPLE_RATE,
            ).input_values

            # Process audio to get transcription
            prediction, latency = generate_response(
                input_values.to(device), pre_prompt, post_prompt, model, tokenizer
            )
            pred_dict = extract_dictionary(prediction)

            if not isinstance(pred_dict, dict):
                # extract_dictionary reports failures as a message string.
                st.error(f"Could not parse the model output: {pred_dict}")
                st.text_area("Raw LLM Output:", value=prediction, height=200)
            else:
                user_utterance = str(pred_dict.get('Transcript', ''))

                # Display the transcription and latency
                st.success('Transcription Complete')
                st.text_area(
                    "LLM Output:",
                    value=json.dumps(pred_dict, indent=2, ensure_ascii=False),
                    height=200,
                    max_chars=None,
                )
                st.write(f"Latency on {device.type.upper()}: {latency:.2f} ms")
        except Exception as e:
            # Top-level UI boundary: surface any failure to the user.
            st.error(f"An error occurred during transcription: {e}")