# diva-audio-chat / app.py (Helw150): multi-turn voice chat demo with DiVA.
import time
import traceback
from dataclasses import dataclass, field
from typing import Any

import gradio as gr
import librosa
import numpy as np
import soundfile as sf
import spaces
import torch
import xxhash
from datasets import Audio
from transformers import AutoModel

from utils.vad import VadOptions, collect_chunks, get_speech_timestamps
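
# Load the DiVA checkpoint once at startup; trust_remote_code is required because
# the repository ships custom modeling code alongside the weights.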
diva_model = AutoModel.from_pretrained(
    "WillHeld/DiVA-llama-3-v0-8b", trust_remote_code=True
)
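
# DiVA expects 16 kHz input; re-encode whatever sample rate the browser delivers.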
resampler = Audio(sampling_rate=16_000)
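

# Stream DiVA's reply for the current turn. prev_outs carries the model state cached
# from earlier turns so the conversation stays multi-turn.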
@spaces.GPU
@torch.no_grad
def diva_audio(audio_input, do_sample=False, temperature=0.001, prev_outs=None):
    sr, y = audio_input
    # Normalize to float32 in [-1, 1], then resample to the model's 16 kHz rate.
    y = y.astype(np.float32)
    y /= np.max(np.abs(y))
    a = resampler.decode_example(
        resampler.encode_example({"array": y, "sampling_rate": sr})
    )
    yield from diva_model.generate_stream(
        a["array"],
        None,
        do_sample=do_sample,
        max_new_tokens=256,
        init_outputs=prev_outs,
        return_outputs=True,
    )
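

# Voice-activity detection: trim non-speech from the buffer and report how much
# actual speech it contains. Returns (speech_duration_s, trimmed_int16_bytes,
# elapsed_s), or a duration of -1 if the VAD fails.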
def run_vad(ori_audio, sr):
    _st = time.time()
    try:
        # Incoming samples are int16-scaled; convert to float32 in [-1, 1] for the VAD.
        audio = ori_audio.astype(np.float32) / 32768.0
        sampling_rate = 16000
        if sr != sampling_rate:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=sampling_rate)
        vad_parameters = VadOptions()
        speech_chunks = get_speech_timestamps(audio, vad_parameters)
        audio = collect_chunks(audio, speech_chunks)
        duration_after_vad = audio.shape[0] / sampling_rate
        if sr != sampling_rate:
            # resample back to the original sampling rate
            vad_audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=sr)
        else:
            vad_audio = audio
        vad_audio = np.round(vad_audio * 32768.0).astype(np.int16)
        vad_audio_bytes = vad_audio.tobytes()
        return duration_after_vad, vad_audio_bytes, round(time.time() - _st, 4)
    except Exception:
        msg = f"[asr vad error] audio_len: {len(ori_audio)/(sr*2):.3f} s, trace: {traceback.format_exc()}"
        print(msg)
        return -1, ori_audio, round(time.time() - _st, 4)
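

# Run the VAD pipeline once at startup so later requests start from a warm state.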
def warm_up():
    frames = np.ones(2048)  # 2048 dummy samples (~0.13 s at 16 kHz)
    dur, frames, tcost = run_vad(frames, 16000)
    print(f"warm up done, time_cost: {tcost:.3f} s")


warm_up()
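

# Per-session state threaded through the Gradio callbacks below.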
@dataclass
class AppState:
    stream: np.ndarray | None = None
    sampling_rate: int = 0
    pause_detected: bool = False
    started_talking: bool = False
    stopped: bool = False
    conversation: list = field(default_factory=list)
    model_outs: Any = None


def determine_pause(audio: np.ndarray, sampling_rate: int, state: AppState) -> bool:
    """Take in the accumulated stream and decide whether the speaker has paused."""
    dur_vad, _, time_vad = run_vad(audio, sampling_rate)
    duration = len(audio) / sampling_rate
    if dur_vad > 0.5 and not state.started_talking:
        print("started talking")
        state.started_talking = True
        return False
    print(f"duration_after_vad: {dur_vad:.3f} s, time_vad: {time_vad:.3f} s")
    # Declare a pause once more than one second of the buffered audio is non-speech.
    return (duration - dur_vad) > 1
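

# Called on every streamed microphone chunk: accumulate audio, then stop recording
# once a pause is detected after the user has started talking.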
def process_audio(audio: tuple, state: AppState):
    if state.stream is None:
        state.stream = audio[1]
        state.sampling_rate = audio[0]
    else:
        state.stream = np.concatenate((state.stream, audio[1]))

    pause_detected = determine_pause(state.stream, state.sampling_rate, state)
    state.pause_detected = pause_detected

    if state.pause_detected and state.started_talking:
        # Returning recording=False stops the mic, which fires stop_recording -> response().
        return gr.Audio(recording=False), state
    return None, state
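

# Runs when recording stops: log the user turn, stream DiVA's reply into the chat,
# then reset the audio buffer while keeping the dialogue history and model state.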
def response(state: AppState):
    if not state.pause_detected and not state.started_talking:
        return AppState()

    # Save the user's turn as a wav file so the Chatbot can render an audio player.
    file_name = f"/tmp/{xxhash.xxh32(bytes(state.stream)).hexdigest()}.wav"
    sf.write(file_name, state.stream, state.sampling_rate, format="wav")
    state.conversation.append(
        {"role": "user", "content": {"path": file_name, "mime_type": "audio/wav"}}
    )

    start = False
    for resp, outs in diva_audio(
        (state.sampling_rate, state.stream), prev_outs=state.model_outs
    ):
        if not start:
            state.conversation.append({"role": "assistant", "content": resp})
            start = True
        else:
            state.conversation[-1]["content"] = resp
        yield state, state.conversation

    # Start the next turn with an empty audio buffer but keep the conversation and
    # DiVA's cached outputs so the dialogue continues across turns.
    yield AppState(conversation=state.conversation, model_outs=outs), state.conversation
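

# After the assistant finishes, re-arm the microphone for the next turn unless the
# user pressed "Stop Conversation".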
def start_recording_user(state: AppState):
    if not state.stopped:
        return gr.Audio(recording=True)


theme = gr.themes.Soft(
    primary_hue=gr.themes.Color(
        c100="#82000019",
        c200="#82000033",
        c300="#8200004c",
        c400="#82000066",
        c50="#8200007f",
        c500="#8200007f",
        c600="#82000099",
        c700="#820000b2",
        c800="#820000cc",
        c900="#820000e5",
        c950="#820000f2",
    ),
    secondary_hue="rose",
    neutral_hue="stone",
)

with gr.Blocks(theme=theme) as demo:
    with gr.Row():
        with gr.Column():
            input_audio = gr.Audio(
                label="Input Audio", sources="microphone", type="numpy"
            )
        with gr.Column():
            chatbot = gr.Chatbot(label="Conversation", type="messages")
    state = gr.State(value=AppState())

    # Feed microphone chunks to process_audio every half second while recording.
    stream = input_audio.stream(
        process_audio,
        [input_audio, state],
        [input_audio, state],
        stream_every=0.50,
        time_limit=30,
    )
    # When recording stops (pause detected), generate a reply, then re-arm the mic.
    respond = input_audio.stop_recording(response, [state], [state, chatbot])
    respond.then(start_recording_user, [state], [input_audio])

    cancel = gr.Button("Stop Conversation", variant="stop")
    cancel.click(
        lambda: (AppState(stopped=True), gr.Audio(recording=False)),
        None,
        [state, input_audio],
        cancels=[respond, stream],
    )

demo.launch(share=True)