Spaces:
Running
on
Zero
Running
on
Zero
import time | |
import traceback | |
from dataclasses import dataclass, field | |
import gradio as gr | |
import librosa | |
import numpy as np | |
import soundfile as sf | |
import spaces | |
import torch | |
import xxhash | |
from datasets import Audio | |
from transformers import AutoModel | |
import io | |
from pydub import AudioSegment | |
import tempfile | |
from utils.vad import VadOptions, collect_chunks, get_speech_timestamps | |
# DiVA checkpoint from the Hugging Face Hub. ``trust_remote_code=True`` executes
# the modelling code shipped in that repository, so it must stay pinned to a
# trusted checkpoint.
diva_model = AutoModel.from_pretrained(
    "WillHeld/DiVA-llama-3-v0-8b",
    trust_remote_code=True,
)

# ``datasets`` Audio feature used only as a resampler: encoding then decoding
# an example converts it to 16 kHz (see diva_audio).
resampler = Audio(sampling_rate=16000)
def diva_audio(audio_input, do_sample=False, temperature=0.001):
    """Stream DiVA's response to a single audio clip.

    NOTE(review): ``spaces`` is imported by this file but never applied; on
    ZeroGPU hardware this function presumably needs ``@spaces.GPU`` to get a
    GPU — confirm against the deployment.

    Args:
        audio_input: ``(sampling_rate, samples)`` tuple, e.g. the value of a
            ``gr.Audio(type="numpy")`` component.
        do_sample: Forwarded to ``diva_model.generate_stream``.
        temperature: Accepted for interface compatibility but currently NOT
            forwarded to the model. NOTE(review): confirm whether the remote
            ``generate_stream`` accepts it before wiring it in.

    Yields:
        Each item produced by ``diva_model.generate_stream`` (at most 256 new
        tokens). The caller in this file uses each item as the whole reply so
        far, replacing the previous one.
    """
    sr, y = audio_input
    # astype always copies, so the in-place scaling below never touches the
    # caller's buffer.
    y = y.astype(np.float32)
    peak = np.max(np.abs(y)) if y.size else 0.0
    # Peak-normalise to [-1, 1]. Skip silent input: dividing by a zero peak
    # would turn the whole clip into NaNs.
    if peak > 0:
        y /= peak
    # Round-trip through the datasets Audio feature to resample to 16 kHz.
    a = resampler.decode_example(
        resampler.encode_example({"array": y, "sampling_rate": sr})
    )
    yield from diva_model.generate_stream(
        a["array"], None, do_sample=do_sample, max_new_tokens=256
    )
def run_vad(ori_audio, sr):
    """Trim non-speech from a PCM clip using the project's VAD.

    Args:
        ori_audio: int16 PCM samples, either as a NumPy array or as a raw
            ``bytes``/``bytearray`` buffer (decoded with ``np.frombuffer`` as
            int16). Samples are scaled by 1/32768, so int16-range input is
            assumed.
        sr: Sampling rate of ``ori_audio`` in Hz. The VAD runs at 16 kHz; the
            kept audio is resampled back to ``sr`` before being returned.

    Returns:
        ``(duration_after_vad, vad_audio_bytes, time_cost)``:
        ``duration_after_vad`` is the seconds of speech kept,
        ``vad_audio_bytes`` is that speech as int16 bytes at ``sr``, and
        ``time_cost`` is the wall-clock seconds spent (rounded to 4 places).
        On any failure the traceback is printed and
        ``(-1, ori_audio, time_cost)`` is returned instead.
    """
    _st = time.time()
    try:
        audio = ori_audio
        if isinstance(audio, (bytes, bytearray)):
            # Raw PCM buffers carry no dtype of their own; read them as int16.
            audio = np.frombuffer(audio, dtype=np.int16)
        audio = audio.astype(np.float32) / 32768.0
        sampling_rate = 16000  # rate the VAD operates at
        if sr != sampling_rate:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=sampling_rate)
        speech_chunks = get_speech_timestamps(audio, VadOptions())
        audio = collect_chunks(audio, speech_chunks)
        duration_after_vad = audio.shape[0] / sampling_rate
        if sr != sampling_rate:
            # resample to original sampling rate
            vad_audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=sr)
        else:
            vad_audio = audio
        # Clip before casting: a full-scale sample (or resampling overshoot)
        # rounds to 32768, which would wrap around to -32768 as int16.
        vad_audio = np.clip(np.round(vad_audio * 32768.0), -32768, 32767).astype(
            np.int16
        )
        return duration_after_vad, vad_audio.tobytes(), round(time.time() - _st, 4)
    except Exception:
        # Raw buffers hold 2 bytes per int16 sample; arrays hold one sample
        # (frame) per element along axis 0.
        bytes_per_sample = 2 if isinstance(ori_audio, (bytes, bytearray)) else 1
        audio_len = (
            len(ori_audio) / (sr * bytes_per_sample) if sr else float("nan")
        )
        msg = f"[asr vad error] audio_len: {audio_len:.3f} s, trace: {traceback.format_exc()}"
        print(msg)
        return -1, ori_audio, round(time.time() - _st, 4)
def warm_up():
    """Run the VAD once on a short clip of silence and print how long it took.

    Runs at import time so any one-off cost inside ``run_vad`` is paid before
    the first real request. NOTE(review): presumably this means loading the
    VAD model in ``utils.vad`` — confirm.
    """
    # 2048 int16 samples of silence (0.128 s at 16 kHz). Passed as an array:
    # a bytes object has no ``.astype`` and would only exercise run_vad's
    # error path.
    silence = np.zeros(2048, dtype=np.int16)
    _, _, tcost = run_vad(silence, 16000)
    print(f"warm up done, time_cost: {tcost:.3f} s")


warm_up()
class AppState: | |
stream: np.ndarray | None = None | |
sampling_rate: int = 0 | |
pause_detected: bool = False | |
started_talking: bool = False | |
stopped: bool = False | |
conversation: list = field(default_factory=list) | |
def determine_pause(audio: np.ndarray, sampling_rate: int, state: AppState) -> bool:
    """Decide whether the speaker has paused, running VAD on the turn so far.

    Side effect: sets ``state.started_talking`` the first time the VAD keeps
    more than 0.5 s of speech. That call always returns ``False``.

    Args:
        audio: All samples recorded in the current turn (int16 PCM assumed,
            as ``run_vad`` requires).
        sampling_rate: Sampling rate of ``audio`` in Hz.
        state: Session state whose ``started_talking`` flag may be set.

    Returns:
        ``True`` when the recording contains more than 1 s of non-speech in
        total (recorded duration minus VAD-kept duration; not necessarily
        trailing silence). ``False`` otherwise, including when the VAD fails.
    """
    dur_vad, _, time_vad = run_vad(audio, sampling_rate)
    if dur_vad < 0:
        # run_vad returns -1 on failure. Treating that as "no speech at all"
        # would flag a pause on every chunk, so report no pause instead.
        return False
    if dur_vad > 0.5 and not state.started_talking:
        print("started talking")
        state.started_talking = True
        return False
    print(f"duration_after_vad: {dur_vad:.3f} s, time_vad: {time_vad:.3f} s")
    duration = len(audio) / sampling_rate
    return (duration - dur_vad) > 1
def process_audio(audio: tuple, state: AppState):
    """Stream callback: append one microphone chunk and stop on a pause.

    Args:
        audio: ``(sampling_rate, samples)`` chunk from the streaming
            ``gr.Audio`` input.
        state: Session state. Its ``stream`` buffer grows by one chunk, and
            ``pause_detected`` is refreshed.

    Returns:
        ``(gr.Audio(recording=False), state)`` once a pause follows detected
        speech, which stops the recording. Otherwise ``(None, state)``.
    """
    chunk_rate, chunk = audio[0], audio[1]
    if state.stream is not None:
        state.stream = np.concatenate((state.stream, chunk))
    else:
        # First chunk of the turn fixes the sampling rate for the whole buffer.
        state.stream = chunk
        state.sampling_rate = chunk_rate

    state.pause_detected = determine_pause(state.stream, state.sampling_rate, state)

    finished_turn = state.pause_detected and state.started_talking
    audio_update = gr.Audio(recording=False) if finished_turn else None
    return audio_update, state
def response(state: AppState):
    """Send the recorded turn to DiVA and stream the reply into the chat.

    This is a generator used as a Gradio event handler. Every yield is
    ``(state, chatbot_messages)``.

    Args:
        state: Session state holding the recorded ``stream``, its
            ``sampling_rate``, and the ``conversation`` so far.

    Yields:
        The updated state and conversation after each streamed model output.
        At the end, a fresh ``AppState`` that keeps only the conversation.
    """
    if not state.pause_detected and not state.started_talking:
        # Nothing to answer. Because this is a generator, a plain
        # ``return AppState()`` would have its value discarded by Gradio and
        # the stale stream/flags would leak into the next turn. So yield the
        # reset state explicitly, keeping the conversation on screen.
        # If the user pressed Stop, that handler already installed a fresh
        # state, so leave it alone rather than undoing ``stopped``.
        if not state.stopped:
            yield AppState(conversation=state.conversation), state.conversation
        return

    import os  # only needed here, to build the temp-file path

    # Content-addressed file name: identical recordings map to the same file.
    digest = xxhash.xxh32(bytes(state.stream)).hexdigest()
    file_name = os.path.join(tempfile.gettempdir(), f"{digest}.wav")
    sf.write(file_name, state.stream, state.sampling_rate, format="wav")
    state.conversation.append(
        {"role": "user", "content": {"path": file_name, "mime_type": "audio/wav"}}
    )

    start = False
    for resp in diva_audio((state.sampling_rate, state.stream)):
        if not start:
            state.conversation.append({"role": "assistant", "content": resp})
            start = True
        else:
            # Each streamed item replaces the assistant message wholesale.
            state.conversation[-1]["content"] = resp
        yield state, state.conversation

    # Start the next turn with an empty audio buffer but the same history.
    yield AppState(conversation=state.conversation), state.conversation
def start_recording_user(state: AppState):
    """Re-arm the microphone after a reply, unless the user pressed Stop.

    Returns:
        ``gr.Audio(recording=True)`` to restart recording, or ``None`` when
        ``state.stopped`` is set.
    """
    if state.stopped:
        return None
    return gr.Audio(recording=True)
# Soft theme with a custom primary palette: one crimson (#820000) at rising
# alpha, from c100 (0x19) up to c950 (0xf2).
# NOTE(review): c50 uses the same alpha as c500 (0x7f) instead of the lightest
# value. This looks unintended — confirm before relying on c50.
_primary_shades = {
    "c50": "#8200007f",
    "c100": "#82000019",
    "c200": "#82000033",
    "c300": "#8200004c",
    "c400": "#82000066",
    "c500": "#8200007f",
    "c600": "#82000099",
    "c700": "#820000b2",
    "c800": "#820000cc",
    "c900": "#820000e5",
    "c950": "#820000f2",
}
theme = gr.themes.Soft(
    primary_hue=gr.themes.Color(**_primary_shades),
    secondary_hue="rose",
    neutral_hue="stone",
)
with gr.Blocks(theme=theme) as demo:
    # Layout: microphone on the left, chat transcript on the right.
    # Gradio places components in creation order, so keep this ordering.
    with gr.Row():
        with gr.Column():
            input_audio = gr.Audio(
                type="numpy",
                sources="microphone",
                label="Input Audio",
            )
        with gr.Column():
            chatbot = gr.Chatbot(type="messages", label="Conversation")
    state = gr.State(value=AppState())

    # Feed microphone chunks to process_audio every 0.5 s (max 30 s per
    # recording). process_audio stops the recording when it detects a pause.
    stream = input_audio.stream(
        fn=process_audio,
        inputs=[input_audio, state],
        outputs=[input_audio, state],
        stream_every=0.50,
        time_limit=30,
    )
    # When recording stops, stream the model reply, then re-arm the mic.
    respond = input_audio.stop_recording(
        fn=response,
        inputs=[state],
        outputs=[state, chatbot],
    )
    respond.then(fn=start_recording_user, inputs=[state], outputs=[input_audio])

    # Stop button: mark the session stopped, halt recording, and cancel any
    # in-flight streaming or reply generation.
    cancel = gr.Button("Stop Conversation", variant="stop")
    cancel.click(
        fn=lambda: (AppState(stopped=True), gr.Audio(recording=False)),
        inputs=None,
        outputs=[state, input_audio],
        cancels=[respond, stream],
    )

demo.launch(share=True)