import logging
# Configure the root logger to WARNING to suppress debug messages from other libraries
logging.basicConfig(level=logging.WARNING)
# Create a file handler instead of a console handler
file_handler = logging.FileHandler("gradio_webrtc.log")
file_handler.setLevel(logging.DEBUG)
# Create a formatter that adds a timestamp to each file log entry
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
file_handler.setFormatter(formatter)
# Configure the gradio_webrtc logger to send DEBUG output to the file
logger = logging.getLogger("gradio_webrtc")
logger.setLevel(logging.DEBUG)
logger.addHandler(file_handler)
import base64
import io
import os
import tempfile
from threading import Thread

import gradio as gr
import numpy as np
import requests
from gradio_webrtc import ReplyOnPause, WebRTC
from huggingface_hub import snapshot_download
from pydub import AudioSegment
from twilio.rest import Client

from server import serve
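
# Download the mini-omni checkpoint locally so the server can load it at startup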
repo_id = "gpt-omni/mini-omni"
snapshot_download(repo_id, local_dir="./checkpoint", revision="main")
IP = "0.0.0.0"
PORT = 60808
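
# Run the mini-omni inference server in a background daemon thread so it
# serves requests from the same process as the Gradio app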
thread = Thread(target=serve, daemon=True)
thread.start()
API_URL = f"http://{IP}:{PORT}/chat"
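
# When Twilio credentials are available, fetch TURN servers so WebRTC can
# connect across restrictive NATs; otherwise leave the RTC configuration unset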
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
if account_sid and auth_token:
client = Client(account_sid, auth_token)
token = client.tokens.create()
rtc_configuration = {
"iceServers": token.ice_servers,
"iceTransportPolicy": "relay",
}
else:
rtc_configuration = None
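
# Output audio format: mono 16-bit PCM at 24 kHz, read in 80 KiB chunks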
OUT_CHANNELS = 1
OUT_RATE = 24000
OUT_SAMPLE_WIDTH = 2
OUT_CHUNK = 20 * 4096
def speaking(audio_bytes: bytes):
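    """POST base64-encoded WAV audio to the chat endpoint and stream the reply.

    Yields the server's PCM response as (1, n) int16 numpy arrays as it arrives.
    """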
    base64_encoded = base64.b64encode(audio_bytes).decode("utf-8")
    payload = {"audio": base64_encoded}
    byte_buffer = b""
    with requests.post(API_URL, json=payload, stream=True) as response:
try:
            for chunk in response.iter_content(chunk_size=OUT_CHUNK):
                if chunk:
                    byte_buffer += chunk
                    # Build an audio segment from the raw PCM bytes, padding
                    # odd-length chunks to a whole 16-bit sample
                    audio_segment = AudioSegment(
                        chunk + b"\x00" if len(chunk) % 2 != 0 else chunk,
                        frame_rate=OUT_RATE,
                        sample_width=OUT_SAMPLE_WIDTH,
                        channels=OUT_CHANNELS,
                    )
                    # Convert the audio segment to a numpy array
                    audio_np = np.array(audio_segment.get_array_of_samples())
                    yield audio_np.reshape(1, -1)
            # Debug aid: write the full concatenated response to a temp wav file
            all_output_audio = AudioSegment(
                byte_buffer,
                frame_rate=OUT_RATE,
                sample_width=OUT_SAMPLE_WIDTH,
                channels=OUT_CHANNELS,
            )
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
                all_output_audio.export(f.name, format="wav")
                print("output file written", f.name)
except Exception as e:
raise gr.Error(f"Error during audio streaming: {e}")
def response(audio: tuple[int, np.ndarray]):
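    """ReplyOnPause handler: encode the caller's turn as WAV and stream the reply."""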
sampling_rate, audio_np = audio
audio_np = audio_np.squeeze()
audio_buffer = io.BytesIO()
    segment = AudioSegment(
        audio_np.tobytes(),
        frame_rate=sampling_rate,
        sample_width=audio_np.dtype.itemsize,
        channels=1,
    )
    segment.export(audio_buffer, format="wav")
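    # Relay the model's streamed reply back to the client as mono frames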
for numpy_array in speaking(audio_buffer.getvalue()):
yield (OUT_RATE, numpy_array, "mono")
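
# UI: a single WebRTC component handles both microphone capture and playback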
with gr.Blocks() as demo:
gr.HTML(
"""
<h1 style='text-align: center'>
Omni Chat (Powered by WebRTC ⚡️)
</h1>
"""
)
with gr.Column():
with gr.Group():
audio = WebRTC(
label="Stream",
rtc_configuration=rtc_configuration,
mode="send-receive",
modality="audio",
)
        audio.stream(
            fn=ReplyOnPause(
                response,
                output_sample_rate=OUT_RATE,
                output_frame_size=480,
            ),
            inputs=[audio],
            outputs=[audio],
            time_limit=60,
        )
demo.launch(ssr_mode=False)