import base64
import io
import tempfile
import time
import traceback
from dataclasses import dataclass, field
from queue import Queue
from threading import Thread, Event

import gradio as gr
import librosa
import numpy as np
import requests
from gradio_webrtc import StreamHandler, WebRTC
from huggingface_hub import snapshot_download
from pydub import AudioSegment

from server import serve
from utils.vad import VadOptions, collect_chunks, get_speech_timestamps

# Download the mini-omni checkpoint and start the inference server in a background thread.
repo_id = "gpt-omni/mini-omni"
snapshot_download(repo_id, local_dir="./checkpoint", revision="main")

IP = "0.0.0.0"
PORT = 60808

thread = Thread(target=serve, daemon=True)
thread.start()

API_URL = "http://0.0.0.0:60808/chat"
# API_URL = "https://freddyaboulton-omni-backend.hf.space/chat"

# recording parameters
IN_CHANNELS = 1
IN_RATE = 24000
IN_CHUNK = 1024
IN_SAMPLE_WIDTH = 2
VAD_STRIDE = 0.5

# playing parameters
OUT_CHANNELS = 1
OUT_RATE = 24000
OUT_SAMPLE_WIDTH = 2
OUT_CHUNK = 20 * 4096


def run_vad(ori_audio, sr):
    """Run voice activity detection on int16 audio and return
    (speech duration in seconds, speech-only audio bytes, elapsed time)."""
    _st = time.time()
    try:
        audio = ori_audio.astype(np.float32) / 32768.0
        sampling_rate = 16000
        if sr != sampling_rate:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=sampling_rate)

        vad_parameters = VadOptions()
        speech_chunks = get_speech_timestamps(audio, vad_parameters)
        audio = collect_chunks(audio, speech_chunks)
        duration_after_vad = audio.shape[0] / sampling_rate

        if sr != sampling_rate:
            # resample back to the original sampling rate
            vad_audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=sr)
        else:
            vad_audio = audio
        vad_audio = np.round(vad_audio * 32768.0).astype(np.int16)
        vad_audio_bytes = vad_audio.tobytes()

        return duration_after_vad, vad_audio_bytes, round(time.time() - _st, 4)
    except Exception:
        msg = f"[asr vad error] audio_len: {len(ori_audio) / sr:.3f} s, trace: {traceback.format_exc()}"
        print(msg)
        return -1, ori_audio, round(time.time() - _st, 4)


def warm_up():
    # Run VAD once on silence so the first real chunk is not slowed down by model loading.
    frames = np.zeros(1600)  # 0.1 s of silence at 16 kHz
    _, frames, tcost = run_vad(frames, 16000)
    print(f"warm up done, time_cost: {tcost:.3f} s")


warm_up()


@dataclass
class AppState:
    stream: np.ndarray | None = None
    sampling_rate: int = 0
    pause_detected: bool = False
    started_talking: bool = False
    responding: bool = False
    stopped: bool = False
    buffer: np.ndarray | None = None


def determine_pause(audio: np.ndarray, sampling_rate: int, state: AppState) -> bool:
    """Take in the stream and determine whether the user has paused."""
    duration = len(audio) / sampling_rate

    dur_vad, _, _ = run_vad(audio, sampling_rate)

    if duration >= 0.60:
        if dur_vad > 0.2 and not state.started_talking:
            print("started talking")
            state.started_talking = True
        if state.started_talking:
            if state.stream is None:
                state.stream = audio
            else:
                state.stream = np.concatenate((state.stream, audio))
            state.buffer = None
        if dur_vad < 0.1 and state.started_talking:
            # Little speech left in the latest window: export the accumulated
            # stream to a temporary wav file and signal a pause.
            segment = AudioSegment(
                state.stream.tobytes(),
                frame_rate=sampling_rate,
                sample_width=audio.dtype.itemsize,
                channels=(1 if len(state.stream.shape) == 1 else state.stream.shape[1]),
            )
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
                segment.export(f.name, format="wav")
                print("input file written", f.name)
            return True
    return False


def speaking(audio_bytes: bytes):
    """Send the recorded audio to the chat endpoint and stream back the PCM reply."""
    base64_encoded = str(base64.b64encode(audio_bytes), encoding="utf-8")
    files = {"audio": base64_encoded}
    byte_buffer = b""

    with requests.post(API_URL, json=files, stream=True) as response:
        try:
            for chunk in response.iter_content(chunk_size=OUT_CHUNK):
                if chunk:
                    byte_buffer += chunk
                    # Wrap the raw PCM chunk in an AudioSegment, padding to an
                    # even number of bytes so the 16-bit samples stay aligned.
                    audio_segment = AudioSegment(
                        (chunk + b"\x00") if len(chunk) % 2 != 0 else chunk,
                        frame_rate=OUT_RATE,
                        sample_width=OUT_SAMPLE_WIDTH,
                        channels=OUT_CHANNELS,
                    )

                    # Convert the audio segment to a numpy array and yield it
                    audio_np = np.array(audio_segment.get_array_of_samples())
                    yield audio_np.reshape(1, -1)

            # Also write the full response to a temporary wav file.
            all_output_audio = AudioSegment(
                byte_buffer,
                frame_rate=OUT_RATE,
                sample_width=OUT_SAMPLE_WIDTH,
                channels=1,
            )
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
                all_output_audio.export(f.name, format="wav")
                print("output file written", f.name)
        except Exception as e:
            raise gr.Error(f"Error during audio streaming: {e}")


def process_audio(audio: tuple, state: AppState) -> None:
    # Accumulate incoming audio in the buffer and check whether the user has paused.
    frame_rate, array = audio
    array = np.squeeze(array)
    if not state.sampling_rate:
        state.sampling_rate = frame_rate
    if state.buffer is None:
        state.buffer = array
    else:
        state.buffer = np.concatenate((state.buffer, array))

    pause_detected = determine_pause(state.buffer, state.sampling_rate, state)
    state.pause_detected = pause_detected


def response(state: AppState):
    if not state.pause_detected and not state.started_talking:
        return None

    # Encode the accumulated speech as wav and stream the model's spoken reply.
    audio_buffer = io.BytesIO()
    segment = AudioSegment(
        state.stream.tobytes(),
        frame_rate=state.sampling_rate,
        sample_width=state.stream.dtype.itemsize,
        channels=(1 if len(state.stream.shape) == 1 else state.stream.shape[1]),
    )
    segment.export(audio_buffer, format="wav")

    for numpy_array in speaking(audio_buffer.getvalue()):
        yield (OUT_RATE, numpy_array, "mono")


class OmniHandler(StreamHandler):
    """StreamHandler that buffers microphone audio until a pause is detected,
    then streams the model's audio reply back over WebRTC."""

    def __init__(self) -> None:
        super().__init__(
            expected_layout="mono", output_sample_rate=OUT_RATE, output_frame_size=480
        )
        self.chunk_queue = Queue()
        self.state = AppState()
        self.generator = None
        self.duration = 0

    def receive(self, frame: tuple[int, np.ndarray]) -> None:
        # Ignore input while a reply is being played back.
        if self.state.responding:
            return
        process_audio(frame, self.state)
        if self.state.pause_detected:
            self.chunk_queue.put(True)

    def reset(self):
        self.generator = None
        self.state = AppState()
        self.duration = 0

    def emit(self):
        # Block until a pause is detected, then pull frames from the response generator.
        if not self.generator:
            self.chunk_queue.get()
            self.state.responding = True
            self.generator = response(self.state)
        try:
            return next(self.generator)
        except StopIteration:
            self.reset()


with gr.Blocks() as demo:
    gr.HTML(
        """