# Hugging Face Spaces page status captured with this file: Runtime error
import json | |
import os.path | |
import tempfile | |
import sys | |
import re | |
import uuid | |
import requests | |
from argparse import ArgumentParser | |
import torchaudio | |
from transformers import WhisperFeatureExtractor, AutoTokenizer | |
from speech_tokenizer.modeling_whisper import WhisperVQEncoder | |
sys.path.insert(0, "./cosyvoice") | |
sys.path.insert(0, "./third_party/Matcha-TTS") | |
from speech_tokenizer.utils import extract_speech_token | |
import gradio as gr | |
import torch | |
audio_token_pattern = re.compile(r"<\|audio_(\d+)\|>") | |
from flow_inference import AudioDecoder | |
from audio_process import AudioStreamProcessor | |
use_local_interface = True | |
if use_local_interface : | |
from model_server import ModelWorker | |
if __name__ == "__main__":
    # Command-line options: where to serve the demo and where the model
    # assets (flow decoder, GLM model, speech tokenizer) are loaded from.
    parser = ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    # The default is an int to match type=int (it used to be the string "8888").
    parser.add_argument("--port", type=int, default=8888)
    parser.add_argument("--flow-path", type=str, default="./glm-4-voice-decoder")
    parser.add_argument("--model-path", type=str, default="THUDM/glm-4-voice-9b")
    parser.add_argument("--tokenizer-path", type=str, default="THUDM/glm-4-voice-tokenizer")
    args = parser.parse_args()

    # Files expected inside the flow decoder directory.
    flow_config = os.path.join(args.flow_path, "config.yaml")
    flow_checkpoint = os.path.join(args.flow_path, "flow.pt")
    hift_checkpoint = os.path.join(args.flow_path, "hift.pt")

    # Module-level state; the None placeholders are filled in by initialize_fn().
    glm_tokenizer = None
    device = "cpu"
    audio_decoder: AudioDecoder | None = None
    whisper_model, feature_extractor = None, None
def initialize_fn():
    """Load the tokenizer, audio decoder and speech tokenizer exactly once.

    Populates the module-level ``glm_tokenizer``, ``audio_decoder``,
    ``whisper_model`` and ``feature_extractor``.  When
    ``use_local_interface`` is true, a ``ModelWorker`` is also created and
    stored in the module-level ``worker``.  A call made after
    ``audio_decoder`` has been set returns immediately.
    """
    global audio_decoder, feature_extractor, whisper_model, glm_tokenizer, worker
    if audio_decoder is not None:
        return

    # GLM
    glm_tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)

    # Flow & Hift
    audio_decoder = AudioDecoder(
        config_path=flow_config,
        flow_ckpt_path=flow_checkpoint,
        hift_ckpt_path=hift_checkpoint,
        device=device,
    )

    # Speech tokenizer
    whisper_model = WhisperVQEncoder.from_pretrained(args.tokenizer_path).eval().to(device)
    feature_extractor = WhisperFeatureExtractor.from_pretrained(args.tokenizer_path)

    if use_local_interface:
        # Use the same model path and device as the rest of the demo instead
        # of a second hard-coded copy, so --model-path is honoured here too.
        # NOTE(review): the second positional argument is passed exactly as
        # before; confirm it maps to the device parameter of ModelWorker.
        worker = ModelWorker(args.model_path, device)
def clear_fn():
    """Return reset values for the seven UI outputs wired to the clear action.

    The result is two fresh empty lists, three empty strings and two
    ``None`` values, in that order.
    """
    chatbot_messages, history = [], []
    blank = ''
    return chatbot_messages, history, blank, blank, blank, None, None
def inference_fn(
    temperature: float,
    top_p: float,
    max_new_token: int,
    input_mode,
    audio_path: str | None,
    input_text: str | None,
    history: list[dict],
    previous_input_tokens: str,
    previous_completion_tokens: str,
):
    """Run one chat turn against the generation server and stream the reply.

    This is a generator.  While tokens arrive it yields 6-tuples
    ``(history, inputs, '', '', audio_bytes, None)`` whenever the audio
    stream processor returns a chunk.  The last yield is
    ``(history, inputs, complete_text, '', None, (22050, waveform))``, or
    the same with ``None`` in the final slot when no audio was produced.

    Args:
        temperature: Sampling temperature forwarded to the server.
        top_p: Nucleus sampling value forwarded to the server.
        max_new_token: Forwarded to the server as ``max_new_tokens``.
        input_mode: ``"audio"`` to use ``audio_path``; any other value uses
            ``input_text``.
        audio_path: Path of the user's audio file (audio mode).
        input_text: The user's text (text mode).
        history: Chat messages; appended to in place.
        previous_input_tokens: Prompt text of the previous turn.
        previous_completion_tokens: Completion text of the previous turn.

    Raises:
        gr.Error: If the required input is missing, no audio tokens could be
            extracted, or the generation server answers with an HTTP error.
    """
    sample_rate = 22050  # passed to torchaudio.save and returned with the waveform

    if input_mode == "audio":
        # Raise a UI-visible error; an assert would be stripped under -O.
        if audio_path is None:
            raise gr.Error("No audio input provided")
        speech_tokens = extract_speech_token(
            whisper_model, feature_extractor, [audio_path]
        )[0]
        if len(speech_tokens) == 0:
            raise gr.Error("No audio tokens extracted")
        # Record the user turn only after validation so a failed request
        # does not leave a dangling message in the history.
        history.append({"role": "user", "content": {"path": audio_path}})
        user_input = (
            "<|begin_of_audio|>"
            + "".join(f"<|audio_{x}|>" for x in speech_tokens)
            + "<|end_of_audio|>"
        )
        system_prompt = "User will provide you with a speech instruction. Do it step by step. First, think about the instruction and respond in a interleaved manner, with 13 text token followed by 26 audio tokens. "
    else:
        if input_text is None:
            raise gr.Error("No text input provided")
        history.append({"role": "user", "content": input_text})
        user_input = input_text
        system_prompt = "User will provide you with a text instruction. Do it step by step. First, think about the instruction and respond in a interleaved manner, with 13 text token followed by 26 audio tokens."

    # Gather history
    inputs = (previous_input_tokens + previous_completion_tokens).strip()
    if "<|system|>" not in inputs:
        inputs += f"<|system|>\n{system_prompt}"
    inputs += f"<|user|>\n{user_input}<|assistant|>streaming_transcription\n"

    with torch.no_grad():
        response = requests.post(
            "http://localhost:10000/generate_stream",
            data=json.dumps({
                "prompt": inputs,
                "temperature": temperature,
                "top_p": top_p,
                "max_new_tokens": max_new_token,
            }),
            stream=True,
        )
        try:
            response.raise_for_status()
        except requests.HTTPError as err:
            raise gr.Error(f"Generation server returned an error: {err}") from err

        text_tokens, audio_tokens = [], []
        audio_offset = glm_tokenizer.convert_tokens_to_ids('<|audio_0|>')
        end_token_id = glm_tokenizer.convert_tokens_to_ids('<|user|>')
        complete_tokens = []
        prompt_speech_feat = torch.zeros(1, 0, 80).to(device)
        flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int64).to(device)
        this_uuid = str(uuid.uuid4())
        tts_speechs = []
        tts_mels = []
        prev_mel = None
        is_finalize = False
        # Audio is decoded in blocks; the block size steps up through this
        # list after every decoded block and then stays at the last entry.
        block_size_list = [25, 50, 100, 150, 200]
        block_size_idx = 0
        block_size = block_size_list[block_size_idx]
        audio_processor = AudioStreamProcessor()

        for token_id in _iter_token_ids(response.iter_lines(), end_token_id):
            if token_id == end_token_id:
                is_finalize = True
            if len(audio_tokens) >= block_size or (is_finalize and audio_tokens):
                if block_size_idx < len(block_size_list) - 1:
                    block_size_idx += 1
                    block_size = block_size_list[block_size_idx]

                tts_token = torch.tensor(audio_tokens, device=device).unsqueeze(0)
                if prev_mel is not None:
                    # Condition the decoder on every mel produced so far.
                    prompt_speech_feat = torch.cat(tts_mels, dim=-1).transpose(1, 2)

                tts_speech, tts_mel = audio_decoder.token2wav(
                    tts_token,
                    uuid=this_uuid,
                    prompt_token=flow_prompt_speech_token.to(device),
                    prompt_feat=prompt_speech_feat.to(device),
                    finalize=is_finalize,
                )
                prev_mel = tts_mel

                audio_bytes = audio_processor.process(
                    tts_speech.clone().cpu().numpy()[0], last=is_finalize
                )
                tts_speechs.append(tts_speech.squeeze())
                tts_mels.append(tts_mel)
                if audio_bytes:
                    yield history, inputs, '', '', audio_bytes, None
                flow_prompt_speech_token = torch.cat((flow_prompt_speech_token, tts_token), dim=-1)
                audio_tokens = []

            if not is_finalize:
                complete_tokens.append(token_id)
                if token_id >= audio_offset:
                    audio_tokens.append(token_id - audio_offset)
                else:
                    text_tokens.append(token_id)

    complete_text = glm_tokenizer.decode(complete_tokens, spaces_between_special_tokens=False)
    # skip_special_tokens is the documented keyword; the previous
    # ignore_special_tokens spelling is not a recognised decode option.
    text_reply = glm_tokenizer.decode(text_tokens, skip_special_tokens=False)

    if not tts_speechs:
        # Text-only reply: torch.cat on an empty list would raise.
        history.append({"role": "assistant", "content": text_reply})
        yield history, inputs, complete_text, '', None, None
        return

    tts_speech = torch.cat(tts_speechs, dim=-1).cpu()
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        torchaudio.save(f, tts_speech.unsqueeze(0), sample_rate, format="wav")
    history.append({"role": "assistant", "content": {"path": f.name, "type": "audio/wav"}})
    history.append({"role": "assistant", "content": text_reply})
    yield history, inputs, complete_text, '', None, (sample_rate, tts_speech.numpy())


def _iter_token_ids(lines, end_token_id):
    """Yield token ids parsed from streamed JSON lines, ending with ``end_token_id``.

    Blank lines are skipped.  Iteration stops after the first
    ``end_token_id``.  If the stream ends without one, a final
    ``end_token_id`` is yielded so the caller still flushes and finalizes
    any buffered audio tokens.
    """
    for line in lines:
        if not line:
            continue
        token_id = json.loads(line)["token_id"]
        yield token_id
        if token_id == end_token_id:
            return
    yield end_token_id
def update_input_interface(input_mode):
    """Return visibility updates for the audio and text input widgets.

    The first update is visible only when ``input_mode`` equals
    ``"audio"``; the second is visible in every other case.
    """
    audio_selected = input_mode == "audio"
    return [gr.update(visible=audio_selected), gr.update(visible=not audio_selected)]
# Create the Gradio interface
with gr.Blocks(title="GLM-4-Voice Demo", fill_height=True) as demo:
    # Sampling controls.
    with gr.Row():
        temperature = gr.Number(label="Temperature", value=0.2)
        top_p = gr.Number(label="Top p", value=0.8)
        max_new_token = gr.Number(label="Max new tokens", value=2000)

    chatbot = gr.Chatbot(
        elem_id="chatbot",
        bubble_full_width=False,
        type="messages",
        scale=1,
    )

    with gr.Row():
        # Input widgets; update_input_interface toggles which one is visible.
        with gr.Column():
            input_mode = gr.Radio(["audio", "text"], label="Input Mode", value="audio")
            audio = gr.Audio(label="Input audio", type='filepath', show_download_button=True, visible=True)
            text_input = gr.Textbox(label="Input text", placeholder="Enter your text here...", lines=2, visible=False)

        # Action buttons and audio outputs.
        with gr.Column():
            submit_btn = gr.Button("Submit")
            reset_btn = gr.Button("Clear")
            output_audio = gr.Audio(
                label="Play",
                streaming=True,
                autoplay=True,
                show_download_button=False,
            )
            complete_audio = gr.Audio(label="Last Output Audio (If Any)", show_download_button=True)

    gr.Markdown("""## Debug Info""")
    with gr.Row():
        input_tokens = gr.Textbox(label="Input Tokens", interactive=False)
        completion_tokens = gr.Textbox(label="Completion Tokens", interactive=False)

    detailed_error = gr.Textbox(label="Detailed Error", interactive=False)

    history_state = gr.State([])

    # Components fed to inference_fn, in its parameter order.
    turn_inputs = [
        temperature,
        top_p,
        max_new_token,
        input_mode,
        audio,
        text_input,
        history_state,
        input_tokens,
        completion_tokens,
    ]
    # Components receiving the six values inference_fn yields.
    turn_outputs = [
        history_state,
        input_tokens,
        completion_tokens,
        detailed_error,
        output_audio,
        complete_audio,
    ]

    respond = submit_btn.click(inference_fn, inputs=turn_inputs, outputs=turn_outputs)
    # Copy the stored history into the chatbot as a follow-up step.
    respond.then(lambda s: s, [history_state], chatbot)

    # clear_fn returns seven values: the chatbot plus the six turn outputs.
    reset_btn.click(clear_fn, outputs=[chatbot, *turn_outputs])
    mode_changed = input_mode.input(clear_fn, outputs=[chatbot, *turn_outputs])
    mode_changed.then(update_input_interface, inputs=[input_mode], outputs=[audio, text_input])

initialize_fn()
# Launch the interface
demo.launch(
    server_port=args.port,
    server_name=args.host,
)