File size: 9,701 Bytes
9d710fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
import argparse
import json
import threading
import time
from pathlib import Path
from typing import List

import websocket
import os

import librosa
import numpy as np

# Default streaming-transcription WebSocket endpoint; used as the default for the --url option.
DEFAULT_WS_URL = "ws://localhost:8000/v1/ws_transcribe_streaming"


def parse_arguments():
    """
    Define the client's command-line options and parse them from ``sys.argv``.

    Returns:
        argparse.Namespace: audio_file, url, model, language, response_format,
        temperature, vad_filter and chunk_duration.
    """
    cli = argparse.ArgumentParser(
        description="Stream audio to the transcription WebSocket endpoint."
    )

    # (positional names/flags, keyword options) in registration order.
    option_specs = [
        (("audio_file",), dict(help="Path to the input audio file.")),
        (("--url",), dict(default=DEFAULT_WS_URL, help="WebSocket endpoint URL.")),
        (("--model",), dict(type=str, help="Model name to use for transcription.")),
        (("--language",), dict(type=str, help="Language code for transcription.")),
        (
            ("--response_format",),
            dict(
                type=str,
                default="verbose_json",
                choices=["text", "json", "verbose_json"],
                help="Response format.",
            ),
        ),
        (("--temperature",), dict(type=float, default=0.0, help="Temperature for transcription.")),
        (("--vad_filter",), dict(action="store_true", help="Enable voice activity detection filter.")),
        (
            ("--chunk_duration",),
            dict(type=float, default=1.0, help="Duration of each audio chunk in seconds."),
        ),
    ]
    for names, options in option_specs:
        cli.add_argument(*names, **options)

    return cli.parse_args()


# def preprocess_audio(audio_file, target_sr=16000):
#     """
#     Load the audio file, convert to mono 16kHz, and return the audio data.
#     """
#     if audio_file.endswith(".mp3"):
#         # Convert MP3 to WAV using ffmpeg
#         wav_file = audio_file.replace(".mp3", ".wav")
#         if not os.path.exists(wav_file):
#             command = f'ffmpeg -i "{audio_file}" -ac 1 -ar {target_sr} "{wav_file}"'
#             print(f"Converting MP3 to WAV: {command}")
#             os.system(command)
#         audio_file = wav_file
#
#     print(f"Loading audio file {audio_file}")
#     audio_data, sr = librosa.load(audio_file, sr=target_sr, mono=True)
#     return audio_data, sr
#
# def chunk_audio(audio_data, sr, chunk_duration):
#     """
#     Split the audio data into chunks of specified duration.
#     """
#     chunk_samples = int(chunk_duration * sr)
#     total_samples = len(audio_data)
#     chunks = [
#         audio_data[i:i + chunk_samples]
#         for i in range(0, total_samples, chunk_samples)
#     ]
#     print(f"Split audio into {len(chunks)} chunks of {chunk_duration} seconds each.")
#     return chunks


def read_audio_in_chunks(audio_file, target_sr=16000, chunk_duration=1) -> List[np.ndarray]:
    """
    Load an audio file as mono 16-bit PCM and split it into fixed-length chunks.

    A file whose name does not end in ".wav" is first converted with ffmpeg to a
    mono WAV at ``target_sr`` next to the original (same stem, ".wav" suffix).
    An existing WAV with that name is reused without re-converting.

    Args:
        audio_file (str | Path): Path to the input audio file.
        target_sr (int): Required sample rate in Hz (16000 by default).
        chunk_duration (float): Length of each chunk in seconds (1 by default).
            Fractional values are allowed; the last chunk may be shorter.

    Returns:
        List of numpy arrays: consecutive chunks of little-endian 16-bit
        integer samples (dtype '<i2').

    Raises:
        ValueError: If the loaded audio's sample rate is not ``target_sr``, or
            ``chunk_duration`` gives less than one sample per chunk.
        subprocess.CalledProcessError: If the ffmpeg conversion fails.
        FileNotFoundError: If ffmpeg is not installed.
    """
    import shlex
    import subprocess

    # Validate up front so a bad duration fails before any conversion work.
    # int(round(...)) also accepts float durations such as the CLI's 1.0 / 0.5.
    samples_per_chunk = int(round(target_sr * chunk_duration))
    if samples_per_chunk <= 0:
        raise ValueError(
            f"chunk_duration={chunk_duration} gives no samples per chunk at {target_sr} Hz."
        )

    if not str(audio_file).endswith(".wav"):
        wav_file = Path(audio_file).with_suffix(".wav")
        if not wav_file.exists():
            # Argument list, no shell: file names containing quotes or shell
            # metacharacters cannot inject commands.
            command = [
                "ffmpeg", "-i", str(audio_file),
                "-ac", "1", "-ar", str(target_sr),
                str(wav_file),
            ]
            print(f"Converting to WAV: {' '.join(shlex.quote(part) for part in command)}")
            try:
                subprocess.run(command, check=True)
            except subprocess.CalledProcessError:
                # Don't leave a truncated WAV behind; it would be silently reused next run.
                if wav_file.exists():
                    wav_file.unlink()
                raise
        audio_file = wav_file

    # sr=None keeps the file's native rate so a mismatch is reported, not resampled.
    audio_data, sr = librosa.load(audio_file, sr=None, mono=True)
    if sr != target_sr:
        raise ValueError(f"Unexpected sample rate {sr}. Expected {target_sr}.")

    # librosa yields floating-point samples normally within [-1, 1]; clip so any
    # out-of-range peaks saturate instead of wrapping around in the int16 cast.
    audio_data_int16 = (np.clip(audio_data, -1.0, 1.0) * 32767).astype(np.int16)

    # Compare against an explicit little-endian dtype: this is equal to the native
    # int16 dtype only on little-endian hosts, so big-endian hosts get converted.
    if audio_data_int16.dtype == np.dtype("<i2"):
        print("No byte swap needed. Already little-endian.")
        audio_data_little_endian = audio_data_int16
    else:
        print("Byte swap performed to convert to little-endian.")
        audio_data_little_endian = audio_data_int16.astype("<i2")

    return [
        audio_data_little_endian[start:start + samples_per_chunk]
        for start in range(0, len(audio_data_little_endian), samples_per_chunk)
    ]


def build_query_params(args):
    """
    Translate parsed command-line options into WebSocket query parameters.

    model, language and response_format are copied when truthy; temperature is
    included (stringified) whenever it is not None; vad_filter becomes "true"
    when set. Keys appear in that order.
    """
    query = {}

    for key in ("model", "language", "response_format"):
        value = getattr(args, key)
        if value:
            query[key] = value

    if args.temperature is not None:
        query["temperature"] = str(args.temperature)

    if args.vad_filter:
        query["vad_filter"] = "true"

    return query


def websocket_url_with_params(base_url, params):
    """
    Return ``base_url`` with ``params`` URL-encoded into its query string.

    If ``base_url`` already carries a query (e.g. ``?token=abc``), the new
    parameters are appended with ``&`` rather than starting a second ``?``, and
    any ``#fragment`` stays at the end. With empty or None ``params`` the URL is
    returned unchanged.

    Args:
        base_url (str): WebSocket URL, optionally with a query and/or fragment.
        params (dict): Query parameters, encoded with ``urllib.parse.urlencode``.

    Returns:
        str: The combined URL.
    """
    from urllib.parse import urlencode, urlsplit, urlunsplit

    if not params:
        return base_url

    parts = urlsplit(base_url)
    extra = urlencode(params)
    query = f"{parts.query}&{extra}" if parts.query else extra
    return urlunsplit(parts._replace(query=query))


def on_message(ws, message):
    """
    WebSocketApp message callback: decode a server message and accumulate it.

    JSON objects are collected on ``ws.transcriptions``:
    - with ``response_format == "verbose_json"``, each element of the message's
      ``segments`` list is appended;
    - otherwise the whole decoded object is appended.
    Non-JSON payloads and JSON values that are not objects are printed and
    otherwise ignored, so one malformed message does not abort processing.

    Args:
        ws: The WebSocketApp; ``ws.args`` (parsed CLI options) and
            ``ws.transcriptions`` (a list) must already be set, as done by
            run_websocket_client and on_open.
        message: Raw message payload (str, or bytes containing UTF-8 JSON).
    """
    try:
        data = json.loads(message)
    except (json.JSONDecodeError, UnicodeDecodeError):
        print(f"Received non-JSON message: {message}")
        return

    if not isinstance(data, dict):
        # Valid JSON but not an object (e.g. a bare string): nothing to look up.
        print(f"Received unexpected JSON message: {message}")
        return

    if ws.args.response_format == "verbose_json":
        # `or []` also covers an explicit "segments": null.
        segments = data.get("segments") or []
        ws.transcriptions.extend(segments)
        for segment in segments:
            print(f"Received segment: {segment.get('text', '')}")
    else:
        # 'json' / 'text' formats: keep the whole decoded object.
        ws.transcriptions.append(data)
        if "text" in data:
            print(f"Transcription: {data['text']}")
        else:
            print(f"Received message without text: {data}")


def on_error(ws, error):
    """WebSocketApp error callback: report ``error`` on stdout (``ws`` is unused)."""
    report = f"WebSocket error: {error}"
    print(report)


def on_close(ws, close_status_code, close_msg):
    """
    WebSocketApp close callback: print a notice.

    The socket, close status code and close message arguments are not used.
    """
    notice = "WebSocket connection closed"
    print(notice)


def on_open(ws):
    """
    WebSocketApp open callback: announce the connection and reset the
    ``ws.transcriptions`` list that on_message appends to.
    """
    print("WebSocket connection opened")
    setattr(ws, "transcriptions", [])


def send_audio_chunks(ws, audio_chunks, sr):
    """
    Stream audio chunks to the server as binary frames, then close the socket.

    Each chunk is serialized in little-endian byte order while keeping its
    sample dtype, so the int16 chunks from read_audio_in_chunks go out as
    16-bit little-endian PCM. A fixed 0.1 s pause follows each frame. After the
    last chunk the function waits 2 s so trailing server messages can arrive,
    then closes. The socket is closed even if sending fails, so the
    ``run_forever`` thread can exit.

    Args:
        ws: Connected websocket.WebSocketApp.
        audio_chunks: Sized sequence of numpy arrays (``len()`` is used for
            progress output).
        sr: Sample rate of the chunks; currently unused (pacing is fixed).

    Raises:
        Whatever ``ws.send`` raises; the socket is closed first.
    """
    total = len(audio_chunks)
    try:
        for idx, chunk in enumerate(audio_chunks, start=1):
            # Only force little-endian byte order; do not change the sample type.
            # Casting int16 PCM to float32 would send integer magnitudes as floats.
            payload = chunk.astype(chunk.dtype.newbyteorder("<"), copy=False).tobytes()
            ws.send(payload, opcode=websocket.ABNF.OPCODE_BINARY)
            print(f"Sent chunk {idx}/{total}")
            time.sleep(0.1)  # Small delay to simulate real-time streaming
        print("All audio chunks sent")
        # Give the server time to deliver remaining transcription messages.
        time.sleep(2)
    finally:
        ws.close()
    print("Closed WebSocket connection")



def format_timestamp(seconds):
    """
    Convert a time offset in seconds to an SRT timestamp ``HH:MM:SS,mmm``.

    The offset is rounded to the nearest millisecond. Truncation would turn
    e.g. 1.001 s into 1000 ms because of binary floating-point error. Negative
    offsets are clamped to zero, since SRT has no negative timestamps. Hours
    are zero-padded to two digits but not capped.
    """
    total_milliseconds = max(0, int(round(seconds * 1000)))
    hours, remainder = divmod(total_milliseconds, 3600 * 1000)
    minutes, remainder = divmod(remainder, 60 * 1000)
    secs, milliseconds = divmod(remainder, 1000)
    return f"{hours:02}:{minutes:02}:{secs:02},{milliseconds:03}"


def generate_srt(transcriptions):
    """
    Print transcription segments to stdout as SRT subtitle entries.

    Each segment must be a mapping with 'start' and 'end' (seconds) and 'text'.
    Entries are numbered from 1 in list order, and each is followed by a blank line.
    """
    print("\nGenerated SRT:")
    for number, seg in enumerate(transcriptions, start=1):
        begin = format_timestamp(seg['start'])
        finish = format_timestamp(seg['end'])
        body = seg['text']
        # One print per entry; the trailing "\n" plus print's own newline
        # produces the blank separator line.
        entry = f"{number}\n{begin} --> {finish}\n{body}\n"
        print(entry)


def run_websocket_client(args):
    """
    Stream ``args.audio_file`` to ``args.url`` and print the transcription.

    The WebSocketApp's ``run_forever`` runs on a background thread. This thread
    waits for the socket to connect, sends the audio chunks, and then joins the
    background thread. Errors are reported on stdout rather than raised. An SRT
    listing is printed only for ``verbose_json``, the one format accumulated as
    timed segments by on_message.

    Args:
        args: Parsed command-line namespace from parse_arguments().
    """
    # Pre-bind so the cleanup code below is safe even if setup fails early.
    ws = None
    ws_thread = None
    try:
        # NOTE: args.chunk_duration is parsed but not forwarded here; the audio
        # is split with read_audio_in_chunks' default chunk length.
        audio_chunks = read_audio_in_chunks(args.audio_file)

        # NOTE: query parameters are currently not sent to the server.
        # params = build_query_params(args)
        # ws_url = websocket_url_with_params(args.url, params)
        ws_url = args.url

        ws = websocket.WebSocketApp(
            ws_url,
            on_open=on_open,
            on_message=on_message,
            on_error=on_error,
            on_close=on_close,
        )
        ws.args = args  # Attach args to ws to access inside callbacks

        # Run the WebSocket in a separate thread to allow sending and receiving simultaneously
        ws_thread = threading.Thread(target=ws.run_forever)
        ws_thread.start()

        # Wait for the connection to open. If run_forever has already returned
        # (e.g. connection refused), stop instead of polling forever.
        while True:
            sock = ws.sock  # read once; the worker thread may reset it to None
            if sock is not None and sock.connected:
                break
            if not ws_thread.is_alive():
                raise ConnectionError(f"Could not open WebSocket connection to {ws_url}")
            time.sleep(0.1)

        send_audio_chunks(ws, audio_chunks, 16000)
    except Exception as e:
        print(f"An error occurred: {e}")
        if ws is not None:
            # Make run_forever return so the join below cannot hang.
            ws.close()

    if ws_thread is not None:
        ws_thread.join()

    # on_open creates ws.transcriptions; it is absent if the socket never opened.
    transcriptions = getattr(ws, "transcriptions", None)
    if not transcriptions:
        print("No transcriptions received.")
    elif args.response_format == "verbose_json":
        generate_srt(transcriptions)
    else:
        print("SRT output requires --response_format verbose_json; skipping SRT generation.")


if __name__ == "__main__":
    # Script entry point: parse the command line, then stream the audio file and print the results.
    args = parse_arguments()
    run_websocket_client(args)