|
|
|
|
|
import argparse
import io
import os
from datetime import datetime, timedelta, timezone
from queue import Queue
from sys import platform
from tempfile import NamedTemporaryFile
from time import sleep

import speech_recognition as sr
import torch
import whisperx
|
|
|
|
|
def _parse_args():
    """Build the command-line parser and return the parsed arguments.

    ``--non_english`` is still accepted so existing command lines keep
    working, but nothing in this script reads it. ``--default_microphone``
    is only registered on Linux.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="Vietnamese_ASR/ct2ranslate",
                        help="Size of model or the local path for model ",
                        type=str)
    parser.add_argument("--non_english", action='store_true',
                        help="Don't use the English model.")
    parser.add_argument("--language", default="vi",
                        help="The language to infer the model with whisper",
                        type=str)
    parser.add_argument("--device", default="cpu",
                        help="Choose device for inference ", type=str)
    parser.add_argument("--energy_threshold", default=900,
                        help="Energy level for mic to detect.", type=int)
    parser.add_argument("--record_timeout", default=0.6,
                        help="How real-time the recording is in seconds.",
                        type=float)
    parser.add_argument("--phrase_timeout", default=3,
                        help="How much empty space between recordings before we "
                             "consider it a new line in the transcription.",
                        type=float)
    if 'linux' in platform:
        parser.add_argument("--default_microphone", default='pulse',
                            help="Default microphone name for SpeechRecognition. "
                                 "Run this with 'list' to view available Microphones.",
                            type=str)
    return parser.parse_args()


def _select_microphone(args):
    """Return the microphone to record from, or ``None`` after listing devices.

    On non-Linux platforms the default input device is used. On Linux the
    first device whose name contains ``args.default_microphone`` is chosen;
    an empty name or ``'list'`` prints the available devices and returns
    ``None``.

    Raises:
        SystemExit: if no device name contains the requested name. The
            original code left ``source`` unbound in that case and failed
            later with an ``UnboundLocalError``.
    """
    if 'linux' not in platform:
        return sr.Microphone(sample_rate=16000)

    mic_name = args.default_microphone
    device_names = sr.Microphone.list_microphone_names()

    if not mic_name or mic_name == 'list':
        print("Available microphone devices are: ")
        for name in device_names:
            print(f"Microphone with name \"{name}\" found")
        return None

    for index, name in enumerate(device_names):
        if mic_name in name:
            return sr.Microphone(sample_rate=16000, device_index=index)

    raise SystemExit(
        f"No microphone matching \"{mic_name}\" was found. "
        "Run with --default_microphone list to view available Microphones."
    )


def _transcribe(audio_model, wav_path, language):
    """Transcribe the WAV file at ``wav_path`` and return the recognised text.

    All returned segments are joined (the original kept only the first one),
    and an empty string is returned when the model reports no segments
    instead of raising ``IndexError``.
    """
    result = audio_model.transcribe(wav_path, language=language, batch_size=8)
    pieces = (segment['text'].strip() for segment in result['segments'])
    return " ".join(piece for piece in pieces if piece)


def main():
    """Transcribe microphone input in near real time and print it.

    Audio is captured on a background thread, accumulated per phrase, written
    to a temporary WAV file and re-transcribed each time new audio arrives.
    A pause longer than ``--phrase_timeout`` starts a new transcription line.
    Ctrl+C stops the loop and prints the full transcription.
    """
    args = _parse_args()

    # Timestamp of the last time audio was pulled from the queue.
    phrase_time = None
    # Raw bytes of the phrase currently being recorded.
    last_sample = bytes()
    # Thread-safe hand-off from the recording callback to the main loop.
    data_queue = Queue()

    recorder = sr.Recognizer()
    recorder.energy_threshold = args.energy_threshold
    recorder.dynamic_energy_threshold = False

    source = _select_microphone(args)
    if source is None:
        # Only a device listing was requested.
        return

    # float16 is only requested on CUDA devices; on CPU it is rejected by the
    # backend, so fall back to int8 there.
    compute_type = "float16" if args.device.startswith("cuda") else "int8"
    audio_model = whisperx.load_model(args.model, device=args.device,
                                      compute_type=compute_type,
                                      language=args.language)

    record_timeout = args.record_timeout
    phrase_timeout = args.phrase_timeout

    # Reserve a path that survives until we delete it ourselves; the file is
    # rewritten on every transcription pass.
    with NamedTemporaryFile(suffix=".wav", delete=False) as handle:
        temp_file = handle.name
    transcription = ['']

    with source:
        recorder.adjust_for_ambient_noise(source)

    def record_callback(_, audio: sr.AudioData) -> None:
        """
        Threaded callback function to receive audio data when recordings finish.
        audio: An AudioData containing the recorded bytes.
        """
        data_queue.put(audio.get_raw_data())

    # Keep the returned stopper so the listener can be shut down on exit.
    stop_listening = recorder.listen_in_background(
        source, record_callback, phrase_time_limit=record_timeout)

    print("Model loaded.\n")

    try:
        while True:
            now = datetime.now(timezone.utc)

            if data_queue.empty():
                # Nothing new to transcribe; sleep so the loop does not spin.
                sleep(0.25)
                continue

            phrase_complete = False
            # A long enough gap since the last audio means the previous
            # phrase is finished: start a fresh buffer and a new line.
            if (phrase_time is not None
                    and now - phrase_time > timedelta(seconds=phrase_timeout)):
                last_sample = bytes()
                phrase_complete = True
            phrase_time = now

            # Drain everything recorded so far in one concatenation.
            chunks = []
            while not data_queue.empty():
                chunks.append(data_queue.get())
            last_sample += b"".join(chunks)

            audio_data = sr.AudioData(last_sample, source.SAMPLE_RATE,
                                      source.SAMPLE_WIDTH)
            with open(temp_file, 'w+b') as f:
                f.write(audio_data.get_wav_data())

            # Use the language the model was loaded with rather than a
            # hard-coded "en".
            text = _transcribe(audio_model, temp_file, args.language)

            if phrase_complete:
                transcription.append(text)
            else:
                transcription[-1] = text

            # Redraw the whole transcription.
            os.system('cls' if os.name=='nt' else 'clear')
            for line in transcription:
                print(line)
            print('', end='', flush=True)
    except KeyboardInterrupt:
        pass
    finally:
        # Stop the background recorder and remove the scratch WAV file.
        stop_listening(wait_for_stop=False)
        try:
            os.remove(temp_file)
        except OSError:
            pass

    print("\n\nTranscription:")
    for line in transcription:
        print(line)
|
|
|
|
|
# Run the transcription loop only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()