Create app.py
app.py
ADDED
@@ -0,0 +1,136 @@
import io
import json
import re
import time

import matplotlib.pyplot as plt
import streamlit as st
import torchaudio
import whisper
from audio_recorder_streamlit import audio_recorder
from transformers import AutoProcessor

from trainer import SpeechLLMLightning


# Render a log-mel spectrogram of the recorded audio
def plot_mel_spectrogram(mel_spec):
    plt.figure(figsize=(10, 4))
    plt.imshow(mel_spec.squeeze().cpu().numpy(), aspect='auto', origin='lower')
    plt.colorbar(format='%+2.0f dB')
    plt.title('Mel Spectrogram')
    plt.tight_layout()
    st.pyplot(plt)


# Load the model, tokenizer, and processor once and cache them in session_state
def get_or_load_model():
    if 'model' not in st.session_state or 'tokenizer' not in st.session_state or 'processor' not in st.session_state:
        ckpt_path = "checkpoints/pretrained_checkpoint.ckpt"
        model = SpeechLLMLightning.load_from_checkpoint(ckpt_path)
        tokenizer = model.llm_tokenizer
        model.eval()
        model.freeze()
        model.to('cuda')
        st.session_state.model = model
        st.session_state.tokenizer = tokenizer
        st.session_state.processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
    return st.session_state.model, st.session_state.tokenizer, st.session_state.processor


# Pull the first JSON object out of the model output and parse it
def extract_dictionary(input_string):
    json_str_match = re.search(r'\{.*\}', input_string)
    if not json_str_match:
        print(input_string)
        return "No valid JSON found."

    json_str = json_str_match.group(0)

    json_str = re.sub(r'(?<=\{|\,)\s*([^\"{}\[\]\s]+)\s*:', r'"\1":', json_str)  # Quote unquoted keys
    json_str = re.sub(r',\s*([\}\]])', r'\1', json_str)  # Remove trailing commas

    try:
        data_dict = json.loads(json_str)
        return data_dict
    except json.JSONDecodeError as e:
        return f"Error parsing JSON: {str(e)}"


pre_speech_prompt = '''Instruction:
Give me the following information about the speech [Transcript, Gender, Age, Emotion, Accent]

Input:
<speech>'''

post_speech_prompt = '''</speech>

Output:'''


# Generate a response from the model and measure inference latency
def generate_response(audio_features, pre_speech_prompt, post_speech_prompt, model, tokenizer):
    output_prompt = '\n<s>'

    pre_tokenized_ids = tokenizer(pre_speech_prompt, padding="do_not_pad", return_tensors='pt', truncation=False, add_special_tokens=False)["input_ids"]
    post_tokenized_ids = tokenizer(post_speech_prompt, padding="do_not_pad", return_tensors='pt', truncation=False, add_special_tokens=False)["input_ids"]
    output_tokenized_ids = tokenizer(output_prompt, padding="do_not_pad", return_tensors='pt', truncation=False, add_special_tokens=False)["input_ids"]

    combined_embeds, atts, label_ids = model.encode(audio_features.cuda(), pre_tokenized_ids.cuda(), post_tokenized_ids.cuda(), output_tokenized_ids.cuda())

    start_time = time.time()
    out = model.llm_model.generate(
        inputs_embeds=combined_embeds,
        max_new_tokens=2000,
    ).cpu().tolist()[0]
    end_time = time.time()

    latency = (end_time - start_time) * 1000  # Latency in milliseconds

    output_text = tokenizer.decode(out, skip_special_tokens=True)
    return output_text, latency


# Extract the JSON payload between <s> and </s>, falling back to an empty dict
def extract_prediction_values(input_string):
    json_str_match = re.search(r'<s>\s*\{.*?\}\s*</s>', input_string)
    try:
        json_str = json_str_match.group(0)
    except AttributeError:
        json_str = '{}'
    return extract_dictionary(json_str)


# Load model, tokenizer, and processor once; they are cached in session_state
model, tokenizer, processor = get_or_load_model()

# Streamlit UI components
st.title("Multi-Modal Speech LLM")
st.write("Record an audio clip to get its transcription and other metadata.")

pre_prompt = st.text_area("Pre Speech Prompt:", value=pre_speech_prompt, height=150)
post_prompt = st.text_area("Post Speech Prompt:", value=post_speech_prompt, height=100)

# Audio recording
audio_data = audio_recorder(sample_rate=16000)

# Transcription process
if audio_data is not None:
    with st.spinner('Transcribing...'):
        try:
            # Load the recorded bytes into a waveform tensor
            audio_buffer = io.BytesIO(audio_data)
            st.audio(audio_data, format='audio/wav', start_time=0)
            wav_tensor, sample_rate = torchaudio.load(audio_buffer)
            wav_tensor = wav_tensor.to('cuda')
            audio = wav_tensor.mean(0)  # Downmix to mono
            mel = whisper.log_mel_spectrogram(audio)
            plot_mel_spectrogram(mel)

            # Convert the waveform into model input features
            audio = processor(audio.squeeze(), return_tensors="pt", sampling_rate=16000).input_values

            # Run the model to get the transcription and metadata
            prediction, latency = generate_response(audio.cuda(), pre_prompt, post_prompt, model, tokenizer)
            pred_dict = extract_dictionary(prediction)

            # Tag the predicted transcript as the user turn
            user_utterance = '<user>' + pred_dict['Transcript']

            # Display the transcription and latency
            st.success('Transcription Complete')
            st.text_area("LLM Output:", value=pred_dict, height=200, max_chars=None)
            st.write(f"Inference latency: {latency:.2f} ms")

        except Exception as e:
            st.error(f"An error occurred during transcription: {e}")