# Source: Hugging Face Space (status: running on A10G hardware).
import base64 | |
import ctypes | |
import io | |
import json | |
import os | |
import struct | |
from dataclasses import dataclass | |
from enum import Enum | |
from typing import AsyncGenerator, Union | |
import httpx | |
import numpy as np | |
import ormsgpack | |
import soundfile as sf | |
from .schema import ( | |
ServeMessage, | |
ServeRequest, | |
ServeTextPart, | |
ServeVQGANDecodeRequest, | |
ServeVQGANEncodeRequest, | |
ServeVQPart, | |
) | |
class CustomAudioFrame:
    """Buffer of signed 16-bit audio samples plus its format metadata.

    Args:
        data: Bytes-like object holding the samples. It is copied into an
            internal ``bytearray``.
        sample_rate: Sample rate stored for the frame.
        num_channels: Number of channels stored for the frame.
        samples_per_channel: Number of samples per channel.

    Raises:
        ValueError: If ``data`` holds fewer than
            ``num_channels * samples_per_channel * sizeof(int16)`` bytes.
    """

    def __init__(self, data, sample_rate, num_channels, samples_per_channel):
        required_bytes = (
            num_channels * samples_per_channel * ctypes.sizeof(ctypes.c_int16)
        )
        if len(data) < required_bytes:
            raise ValueError(
                "data length must be >= num_channels * samples_per_channel * sizeof(int16)"
            )
        self._data = bytearray(data)
        self._sample_rate = sample_rate
        self._num_channels = num_channels
        self._samples_per_channel = samples_per_channel

    # The accessors below are properties: ``duration`` and ``__repr__`` read
    # them as attributes, which fails if they are plain methods.

    @property
    def data(self):
        """Return the samples as a ``memoryview`` of native 16-bit integers."""
        return memoryview(self._data).cast("h")

    @property
    def sample_rate(self):
        """Return the sample rate given at construction."""
        return self._sample_rate

    @property
    def num_channels(self):
        """Return the channel count given at construction."""
        return self._num_channels

    @property
    def samples_per_channel(self):
        """Return the per-channel sample count given at construction."""
        return self._samples_per_channel

    @property
    def duration(self):
        """Return ``samples_per_channel / sample_rate``."""
        return self.samples_per_channel / self.sample_rate

    def __repr__(self):
        return (
            f"CustomAudioFrame(sample_rate={self.sample_rate}, "
            f"num_channels={self.num_channels}, "
            f"samples_per_channel={self.samples_per_channel}, "
            f"duration={self.duration:.3f})"
        )
class FishE2EEventType(Enum):
    """Tags identifying what a streamed agent event carries."""

    # Decoded audio for a run of VQ codes.
    SPEECH_SEGMENT = 1
    # A piece of generated text.
    TEXT_SEGMENT = 2
    # Marker emitted once the response stream has been consumed.
    END_OF_TEXT = 3
    # Marker emitted after END_OF_TEXT.
    END_OF_SPEECH = 4
    # Speech-recognition text (not produced by the code in this file).
    ASR_RESULT = 5
    # VQ codes obtained by encoding the user's audio.
    USER_CODES = 6
class FishE2EEvent: | |
type: FishE2EEventType | |
frame: np.ndarray = None | |
text: str = None | |
vq_codes: list[list[int]] = None | |
# Module-level async HTTP client configured with no request timeout and with
# every connection-pool limit (total connections, keep-alive connections,
# keep-alive expiry) set to None.
_unbounded_limits = httpx.Limits(
    max_connections=None,
    max_keepalive_connections=None,
    keepalive_expiry=None,
)
client = httpx.AsyncClient(timeout=None, limits=_unbounded_limits)
class FishE2EAgent: | |
def __init__(self): | |
self.llm_url = "http://localhost:8080/v1/chat" | |
self.vqgan_url = "http://localhost:8080" | |
self.client = httpx.AsyncClient(timeout=None) | |
async def get_codes(self, audio_data, sample_rate): | |
audio_buffer = io.BytesIO() | |
sf.write(audio_buffer, audio_data, sample_rate, format="WAV") | |
audio_buffer.seek(0) | |
# Step 1: Encode audio using VQGAN | |
encode_request = ServeVQGANEncodeRequest(audios=[audio_buffer.read()]) | |
encode_request_bytes = ormsgpack.packb( | |
encode_request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC | |
) | |
encode_response = await self.client.post( | |
f"{self.vqgan_url}/v1/vqgan/encode", | |
data=encode_request_bytes, | |
headers={"Content-Type": "application/msgpack"}, | |
) | |
encode_response_data = ormsgpack.unpackb(encode_response.content) | |
codes = encode_response_data["tokens"][0] | |
return codes | |
async def stream( | |
self, | |
system_audio_data: np.ndarray | None, | |
user_audio_data: np.ndarray | None, | |
sample_rate: int, | |
num_channels: int, | |
chat_ctx: dict | None = None, | |
) -> AsyncGenerator[bytes, None]: | |
if system_audio_data is not None: | |
sys_codes = await self.get_codes(system_audio_data, sample_rate) | |
else: | |
sys_codes = None | |
if user_audio_data is not None: | |
user_codes = await self.get_codes(user_audio_data, sample_rate) | |
# Step 2: Prepare LLM request | |
if chat_ctx is None: | |
sys_parts = [ | |
ServeTextPart( | |
text='您是由 Fish Audio 设计的语音助手,提供端到端的语音交互,实现无缝用户体验。首先转录用户的语音,然后使用以下格式回答:"Question: [用户语音]\n\nAnswer: [你的回答]\n"。' | |
), | |
] | |
if system_audio_data is not None: | |
sys_parts.append(ServeVQPart(codes=sys_codes)) | |
chat_ctx = { | |
"messages": [ | |
ServeMessage( | |
role="system", | |
parts=sys_parts, | |
), | |
], | |
} | |
else: | |
if chat_ctx["added_sysaudio"] is False and sys_codes: | |
chat_ctx["added_sysaudio"] = True | |
chat_ctx["messages"][0].parts.append(ServeVQPart(codes=sys_codes)) | |
prev_messages = chat_ctx["messages"].copy() | |
if user_audio_data is not None: | |
yield FishE2EEvent( | |
type=FishE2EEventType.USER_CODES, | |
vq_codes=user_codes, | |
) | |
else: | |
user_codes = None | |
request = ServeRequest( | |
messages=prev_messages | |
+ ( | |
[ | |
ServeMessage( | |
role="user", | |
parts=[ServeVQPart(codes=user_codes)], | |
) | |
] | |
if user_codes | |
else [] | |
), | |
streaming=True, | |
num_samples=1, | |
) | |
# Step 3: Stream LLM response and decode audio | |
buffer = b"" | |
vq_codes = [] | |
current_vq = False | |
async def decode_send(): | |
nonlocal current_vq | |
nonlocal vq_codes | |
data = np.concatenate(vq_codes, axis=1).tolist() | |
# Decode VQ codes to audio | |
decode_request = ServeVQGANDecodeRequest(tokens=[data]) | |
decode_response = await self.client.post( | |
f"{self.vqgan_url}/v1/vqgan/decode", | |
data=ormsgpack.packb( | |
decode_request, | |
option=ormsgpack.OPT_SERIALIZE_PYDANTIC, | |
), | |
headers={"Content-Type": "application/msgpack"}, | |
) | |
decode_data = ormsgpack.unpackb(decode_response.content) | |
# Convert float16 audio data to int16 | |
audio_data = np.frombuffer(decode_data["audios"][0], dtype=np.float16) | |
audio_data = (audio_data * 32768).astype(np.int16).tobytes() | |
audio_frame = CustomAudioFrame( | |
data=audio_data, | |
samples_per_channel=len(audio_data) // 2, | |
sample_rate=44100, | |
num_channels=1, | |
) | |
yield FishE2EEvent( | |
type=FishE2EEventType.SPEECH_SEGMENT, | |
frame=audio_frame, | |
vq_codes=data, | |
) | |
current_vq = False | |
vq_codes = [] | |
async with self.client.stream( | |
"POST", | |
self.llm_url, | |
data=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC), | |
headers={"Content-Type": "application/msgpack"}, | |
) as response: | |
async for chunk in response.aiter_bytes(): | |
buffer += chunk | |
while len(buffer) >= 4: | |
read_length = struct.unpack("I", buffer[:4])[0] | |
if len(buffer) < 4 + read_length: | |
break | |
body = buffer[4 : 4 + read_length] | |
buffer = buffer[4 + read_length :] | |
data = ormsgpack.unpackb(body) | |
if data["delta"] and data["delta"]["part"]: | |
if current_vq and data["delta"]["part"]["type"] == "text": | |
async for event in decode_send(): | |
yield event | |
if data["delta"]["part"]["type"] == "text": | |
yield FishE2EEvent( | |
type=FishE2EEventType.TEXT_SEGMENT, | |
text=data["delta"]["part"]["text"], | |
) | |
elif data["delta"]["part"]["type"] == "vq": | |
vq_codes.append(np.array(data["delta"]["part"]["codes"])) | |
current_vq = True | |
if current_vq and vq_codes: | |
async for event in decode_send(): | |
yield event | |
yield FishE2EEvent(type=FishE2EEventType.END_OF_TEXT) | |
yield FishE2EEvent(type=FishE2EEventType.END_OF_SPEECH) | |
# Example usage:
async def main():
    """Run one example file through the agent, printing text and saving audio.

    Loads ``uz_story_en.m4a`` as the user's audio, streams the reply, prints
    text events, and appends the samples of every speech segment to
    ``audio_segment.wav``.

    NOTE: the samples are appended as raw int16 data; no WAV header is
    written, despite the file name.
    """
    import torchaudio

    agent = FishE2EAgent()

    # Replace this with actual audio data loading
    waveform, sample_rate = torchaudio.load("uz_story_en.m4a")
    # torchaudio returns (channels, frames) by default, while soundfile
    # (used when encoding) expects (frames, channels), hence the transpose.
    samples = waveform.numpy().T
    # Clip before the cast: a sample of 1.0 scales to 32768, which overflows
    # int16.
    audio_data = np.clip(samples * 32768, -32768, 32767).astype(np.int16)

    # No system audio; the loaded file is the user's turn.
    stream = agent.stream(None, audio_data, sample_rate, audio_data.shape[1])

    if os.path.exists("audio_segment.wav"):
        os.remove("audio_segment.wav")

    try:
        async for event in stream:
            if event.type == FishE2EEventType.SPEECH_SEGMENT:
                # Handle speech segment (e.g., play audio or save to file)
                with open("audio_segment.wav", "ab") as f:
                    f.write(event.frame.data)
            elif event.type == FishE2EEventType.ASR_RESULT:
                print(event.text, flush=True)
            elif event.type == FishE2EEventType.TEXT_SEGMENT:
                print(event.text, flush=True, end="")
            elif event.type == FishE2EEventType.END_OF_TEXT:
                print("\nEnd of text reached.")
            elif event.type == FishE2EEventType.END_OF_SPEECH:
                print("End of speech reached.")
    finally:
        # Release the agent's HTTP connections once the example is done.
        await agent.client.aclose()
# Script entry point: run the example coroutine only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    import asyncio

    asyncio.run(main())