diva-audio-chat / app.py
Helw150
Spaces GPU Move?
4898006
raw
history blame
5.41 kB
import time
import traceback
from dataclasses import dataclass, field
import gradio as gr
import librosa
import numpy as np
import soundfile as sf
import spaces
import torch
import xxhash
from datasets import Audio
from transformers import AutoModel
import io
# Load the DiVA checkpoint under Gradio's NO_RELOAD guard so the model load is
# not repeated when the script is re-executed in reload mode.
# NOTE(review): trust_remote_code=True runs Python code shipped with the model
# repository -- confirm that revision is trusted/pinned.
if gr.NO_RELOAD:
    diva_model = AutoModel.from_pretrained(
        "WillHeld/DiVA-llama-3-v0-8b", trust_remote_code=True
    )
# `datasets.Audio` feature configured for 16 kHz; diva_audio() round-trips
# microphone audio through it (encode_example -> decode_example) to resample.
resampler = Audio(sampling_rate=16_000)
@spaces.GPU
@torch.no_grad
def diva_audio(audio_input, do_sample=False, temperature=0.001, prev_outs=None):
    """Stream DiVA's reply to a single spoken utterance.

    Args:
        audio_input: ``(sampling_rate, samples)`` pair where ``samples`` is a
            NumPy array of audio samples.
        do_sample: Forwarded to ``diva_model.generate_stream``.
        temperature: Accepted for interface compatibility; it is not
            forwarded to the model by this function.
        prev_outs: Forwarded as ``init_outputs``. ``None`` means a fresh
            conversation, in which case the system prompt is supplied;
            otherwise no prompt is passed.

    Yields:
        Each item produced by ``diva_model.generate_stream`` (called with
        ``return_outputs=True``), unchanged.
    """
    sr, y = audio_input
    # astype() copies, so the in-place scaling below never touches the
    # caller's buffer.
    y = y.astype(np.float32)
    peak = np.max(np.abs(y))
    # Peak-normalise, but skip all-silent input: dividing by a zero peak
    # would fill the signal with NaNs before it reaches the model.
    if peak > 0:
        y /= peak
    a = resampler.decode_example(
        resampler.encode_example({"array": y, "sampling_rate": sr})
    )
    # Identity check: `== None` can be overloaded by array/tensor-like objects.
    system_prompt = (
        "Your name is DiVA, which stands for Distilled Voice Assistant. You were trained with early-fusion training to merge OpenAI's Whisper and Meta AI's Llama 3 8B to provide end-to-end voice processing. You should give brief and helpful answers, in a conversational style. The user is talking to you with their voice and you are responding with text."
        if prev_outs is None
        else None
    )
    yield from diva_model.generate_stream(
        a["array"],
        system_prompt,
        do_sample=do_sample,
        max_new_tokens=256,
        init_outputs=prev_outs,
        return_outputs=True,
    )
@dataclass
class AppState:
stream: np.ndarray | None = None
sampling_rate: int = 0
stopped: bool = False
conversation: list = field(default_factory=list)
model_outs: any = None
def process_audio(audio: tuple, state: AppState):
    """Pass the recorded audio and the session state through unchanged."""
    passthrough = (audio, state)
    return passthrough
@spaces.GPU(duration=40, progress=gr.Progress(track_tqdm=True))
def response(state: AppState, audio: tuple):
    """Generate the assistant's reply to a finished recording, streaming it.

    Args:
        state: Current session state; its ``conversation`` list is extended
            in place with the user turn and the assistant turn.
        audio: ``(sampling_rate, samples)`` pair from the audio component,
            or a falsy value when nothing was recorded.

    Yields:
        ``(state, conversation)`` after every partial reply, then a final
        pair whose state is a fresh ``AppState`` that keeps only the
        conversation and the latest model outputs.
    """
    if not audio:
        # This is a generator, so a returned value would be discarded;
        # finish without emitting any update.
        return

    state.stream = audio[1]
    state.sampling_rate = audio[0]

    # Save the recording under a content-derived name so the chat history
    # can reference the user's turn as an audio file.
    file_name = f"/tmp/{xxhash.xxh32(bytes(state.stream)).hexdigest()}.wav"
    sf.write(file_name, state.stream, state.sampling_rate, format="wav")
    state.conversation.append(
        {"role": "user", "content": {"path": file_name, "mime_type": "audio/wav"}}
    )

    # Fall back to the previous outputs so the final yield stays valid even
    # if the model stream produces no items at all.
    outs = state.model_outs
    started = False
    for resp, outs in diva_audio(
        (state.sampling_rate, state.stream), prev_outs=state.model_outs
    ):
        print(resp)
        if not started:
            # First chunk opens the assistant message ...
            state.conversation.append({"role": "assistant", "content": resp})
            started = True
        else:
            # ... later chunks overwrite it with the longer partial reply.
            state.conversation[-1]["content"] = resp
        yield state, state.conversation

    # Drop the raw audio from the stored state; keep history and model outputs.
    yield AppState(conversation=state.conversation, model_outs=outs), state.conversation
def start_recording_user(state: AppState):
    """Return ``None`` regardless of ``state``.

    The UI wiring in this file sends this return value to the audio input
    component once a response has finished.
    """
    return
# Soft theme whose primary palette is the colour #820000 at varying alpha
# values (the last two hex digits of each shade).
# NOTE(review): c50 carries the same value as c500 (alpha 7f) although the
# other shades rise steadily from c100 to c950 -- confirm whether a lighter
# value was intended.
theme = gr.themes.Soft(
    neutral_hue="stone",
    secondary_hue="rose",
    primary_hue=gr.themes.Color(
        c50="#8200007f",
        c100="#82000019",
        c200="#82000033",
        c300="#8200004c",
        c400="#82000066",
        c500="#8200007f",
        c600="#82000099",
        c700="#820000b2",
        c800="#820000cc",
        c900="#820000e5",
        c950="#820000f2",
    ),
)
# JavaScript passed to gr.Blocks(js=...). It loads onnxruntime-web 1.14.0 and,
# once that script has loaded, @ricky0123/vad-web 0.0.7. When the VAD bundle is
# ready it relabels the element matching '.record-button', creates a MicVAD
# instance and starts it: detected speech start clicks '.record-button',
# detected speech end clicks '.stop-button', so recording is hands-free.
# NOTE(review): the relabelling step dereferences the '.record-button' lookup
# without a null check, unlike the two VAD callbacks -- confirm the button is
# always present by the time the bundle finishes loading.
js = """
async function main() {
const script1 = document.createElement("script");
script1.src = "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.14.0/dist/ort.js";
document.head.appendChild(script1)
const script2 = document.createElement("script");
script2.onload = async () => {
console.log("vad loaded") ;
var record = document.querySelector('.record-button');
record.textContent = "Just Start Talking!"
record.style = "width: 11vw"
const myvad = await vad.MicVAD.new({
onSpeechStart: () => {
var record = document.querySelector('.record-button');
if (record != null) {
console.log(record);
record.click();
}
},
onSpeechEnd: (audio) => {
var stop = document.querySelector('.stop-button');
if (stop != null) {
console.log(stop);
stop.click();
}
}
})
myvad.start()
}
script2.src = "https://cdn.jsdelivr.net/npm/@ricky0123/vad-web@0.0.7/dist/bundle.min.js";
script1.onload = () => {
console.log("onnx loaded")
document.head.appendChild(script2)
};
}
"""
# JavaScript run after each response (attached via js= on the last step of the
# `restart` event chain): re-applies the "Just Start Talking!" label and width
# to the element matching '.record-button'.
js_reset = """
() => {
var record = document.querySelector('.record-button');
record.textContent = "Just Start Talking!"
record.style = "width: 11vw"
}
"""
# Build the UI: a microphone input, the chat transcript, per-session state,
# and the event chain that turns each finished recording into a reply.
with gr.Blocks(theme=theme, js=js) as demo:
    with gr.Row():
        input_audio = gr.Audio(
            label="Input Audio",
            sources=["microphone"],
            type="numpy",
            streaming=False,
        )
    with gr.Row():
        chatbot = gr.Chatbot(label="Conversation", type="messages")
    state = gr.State(value=AppState())

    # Recording started: pass audio and state straight through.
    stream = input_audio.start_recording(
        fn=process_audio,
        inputs=[input_audio, state],
        outputs=[input_audio, state],
    )
    # Recording stopped: stream the model's reply into the chatbot.
    respond = input_audio.stop_recording(
        fn=response,
        inputs=[state, input_audio],
        outputs=[state, chatbot],
    )
    # After the reply: clear the audio input, then run js_reset in the browser.
    restart = respond.then(
        fn=start_recording_user,
        inputs=[state],
        outputs=[input_audio],
    ).then(
        fn=lambda state: state,
        inputs=state,
        outputs=state,
        js=js_reset,
    )

    # Restart button: cancel any in-flight reply and reset state and recorder.
    cancel = gr.Button("Restart Conversation", variant="stop")
    cancel.click(
        fn=lambda: (AppState(stopped=True), gr.Audio(recording=False)),
        inputs=None,
        outputs=[state, input_audio],
        cancels=[respond, restart],
    )
# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()