import gradio as gr
import websockets
import asyncio
import json
import base64
# WebSocket endpoint of the OmniAudio inference server this UI talks to.
OMNI_WS_URL = "wss://nexa-omni.nexa4ai.com/ws/process-audio/"


async def process_audio_stream(audio_path, max_tokens):
    """
    Process audio with streaming response via WebSocket.

    Sends the audio file at ``audio_path`` (base64-encoded) followed by a JSON
    parameter message to the server at ``OMNI_WS_URL``. Each time a new token
    arrives, it yields the response text accumulated so far, so Gradio can
    stream it into the output box.

    Args:
        audio_path: Path to the uploaded/recorded audio file, or a falsy value
            when the user has not provided one.
        max_tokens: Maximum number of tokens to request; forwarded unchanged
            in the JSON parameter message.

    Yields:
        str: The accumulated response text, or a single human-readable error
        message. Errors are yielded rather than raised so the UI displays them.
    """
    if not audio_path:
        yield "Please upload or record an audio file first."
        return

    # Read the file before connecting so a local I/O problem is reported as
    # such, not as a server/connection failure.
    try:
        with open(audio_path, "rb") as f:
            base64_bytes = base64.b64encode(f.read())
    except OSError as e:
        yield f"Error reading audio file: {e}"
        return

    try:
        async with websockets.connect(OMNI_WS_URL) as websocket:
            # Protocol as used here: the first frame is the raw base64 audio
            # bytes, the second is a JSON string with generation parameters.
            await websocket.send(base64_bytes)
            await websocket.send(json.dumps({
                "prompt": "",
                "max_tokens": max_tokens,
            }))

            response = ""
            async for message in websocket:
                try:
                    data = json.loads(message)
                except ValueError:
                    # Skip frames that are not valid JSON. ValueError covers
                    # both JSONDecodeError and undecodable bytes frames.
                    continue
                if not isinstance(data, dict):
                    # Only JSON objects carry a "status"; ignore anything else.
                    continue

                status = data.get("status")
                if status == "generating":
                    response += data.get("token", "")
                    yield response
                elif status == "complete":
                    break
                elif status == "error":
                    yield f"Error: {data.get('error', 'unknown server error')}"
                    break
    except Exception as e:
        # UI boundary: surface connection/protocol failures as output text.
        yield f"Error connecting to server: {e}"
# Create Gradio interface: audio input plus a max-token slider, with the
# handler's streamed text shown in a read-only textbox.
demo = gr.Interface(
    fn=process_audio_stream,
    inputs=[
        gr.Audio(
            type="filepath",  # the handler opens the audio by file path
            label="Upload or Record Audio",
            sources=["upload", "microphone"],
        ),
        gr.Slider(
            minimum=50,
            maximum=200,
            value=50,
            step=1,
            label="Max Tokens",
        ),
    ],
    outputs=gr.Textbox(label="Response", interactive=False),
    title="NEXA OmniAudio-2.6B",
    # There is no prompt input in this UI (the handler sends an empty prompt),
    # so the description must not ask the user for one.
    description="""
OmniAudio-2.6B is a compact audio-language model optimized for edge deployment.
Model Repo: NexaAIDev/OmniAudio-2.6B
Blog: OmniAudio-2.6B Blog
Upload or record an audio file to analyze its content.""",
    # Each example row supplies (audio_path, max_tokens) in input order.
    examples=[
        ["example_audios/voice_qa.mp3", 200],
        ["example_audios/voice_in_conversation.mp3", 200],
        ["example_audios/creative_content_generation.mp3", 200],
        ["example_audios/record_summary.mp3", 200],
        ["example_audios/change_tone.mp3", 200],
    ],
)
if __name__ == "__main__":
    # Enable Gradio's request queue (used by the streaming generator handler),
    # then serve on all network interfaces at port 7860.
    queued_app = demo.queue()
    queued_app.launch(server_port=7860, server_name="0.0.0.0")