"""Near-real-time Vietnamese speech transcription from the microphone.

Audio is captured in short chunks with SpeechRecognition, buffered in a queue,
and repeatedly transcribed with a LoRA fine-tuned Whisper model
(DuyTa/Vietnamese_ASR) loaded through PEFT on top of its base checkpoint.
Requires torch, transformers, peft, and SpeechRecognition (with PyAudio for
microphone access).
"""

import io
import os

import torch
from transformers import (
    AutomaticSpeechRecognitionPipeline,
    WhisperForConditionalGeneration,
    WhisperProcessor,
)
from peft import PeftModel, PeftConfig

import speech_recognition as sr
from datetime import datetime, timedelta
from queue import Queue
from tempfile import NamedTemporaryFile
from time import sleep


def main():
    peft_model_id = "DuyTa/Vietnamese_ASR"
    language = "Vietnamese"
    task = "transcribe"
    default_energy_threshold = 900  # mic energy level treated as speech
    default_record_timeout = 0.6    # seconds of audio per background chunk
    default_phrase_timeout = 3      # seconds of silence that end a phrase

    # Rolling state for the transcription loop.
    phrase_time = None     # when the last audio chunk arrived
    last_sample = bytes()  # raw audio accumulated for the current phrase
    data_queue = Queue()   # thread-safe buffer filled by the recorder callback

    recorder = sr.Recognizer()
    recorder.energy_threshold = default_energy_threshold
    # Dynamic energy compensation can keep lowering the threshold until the
    # recorder never stops recording, so keep it disabled.
    recorder.dynamic_energy_threshold = False

    source = sr.Microphone(sample_rate=16000)  # Whisper expects 16 kHz audio

    # Load the base Whisper checkpoint, then attach the LoRA adapter weights.
    peft_config = PeftConfig.from_pretrained(peft_model_id)
    model = WhisperForConditionalGeneration.from_pretrained(
        peft_config.base_model_name_or_path
    )
    model = PeftModel.from_pretrained(model, peft_model_id)

    # Fall back to CPU when CUDA is unavailable.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model.to(device)

    processor = WhisperProcessor.from_pretrained(
        peft_config.base_model_name_or_path, language=language, task=task
    )
    pipe = AutomaticSpeechRecognitionPipeline(
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        batch_size=8,
        torch_dtype=torch.float32,
        device=device,
    )
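
    # Optional tweak: PEFT's merge_and_unload() can fold the LoRA adapter into
    # the base weights for slightly faster inference; the adapter is left
    # unmerged here so it stays swappable.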

    record_timeout = default_record_timeout
    phrase_timeout = default_phrase_timeout

    # Scratch WAV file reused for every inference call; only its path is kept.
    temp_file = NamedTemporaryFile().name
    transcription = ['']

    # Calibrate against ambient noise once before listening starts.
    with source:
        recorder.adjust_for_ambient_noise(source)

    def record_callback(_, audio: sr.AudioData) -> None:
        """
        Threaded callback that receives audio data when a recording finishes.
        audio: an AudioData instance containing the recorded bytes.
        """
        data = audio.get_raw_data()
        data_queue.put(data)

    # Start a background thread that keeps pushing audio chunks of up to
    # record_timeout seconds onto the queue.
    recorder.listen_in_background(source, record_callback, phrase_time_limit=record_timeout)
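    # listen_in_background also returns a stop function; it is not kept here,
    # so the background listener runs until the process exits.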

    print("Model loaded.\n")

    while True:
        try:
            now = datetime.utcnow()

            if not data_queue.empty():
                phrase_complete = False
                # If enough silence has passed since the last chunk, treat the
                # buffered audio as a finished phrase and start a new one.
                if phrase_time and now - phrase_time > timedelta(seconds=phrase_timeout):
                    last_sample = bytes()
                    phrase_complete = True
                phrase_time = now

                # Drain everything the background thread has queued so far.
                while not data_queue.empty():
                    data = data_queue.get()
                    last_sample += data

                # Wrap the raw bytes so they can be written out as a WAV file.
                audio_data = sr.AudioData(last_sample, source.SAMPLE_RATE, source.SAMPLE_WIDTH)
                wav_data = io.BytesIO(audio_data.get_wav_data())
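                # NOTE: each pass re-transcribes the phrase from its start, so
                # update latency grows with phrase length until a pause longer
                # than phrase_timeout resets the buffer.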

                with open(temp_file, 'w+b') as f:
                    f.write(wav_data.read())

                # Run ASR on the accumulated phrase; audio longer than 30 s is
                # handled by the pipeline's chunking.
                text = pipe(
                    temp_file,
                    chunk_length_s=30,
                    return_timestamps=False,
                    generate_kwargs={"language": language, "task": task},
                )["text"]

                # A completed phrase starts a new line; otherwise keep
                # overwriting the current line with the refined transcription.
                if phrase_complete:
                    transcription.append(text)
                else:
                    transcription[-1] = text

                # Clear the console and redraw the updated transcription.
                os.system('cls' if os.name == 'nt' else 'clear')
                for line in transcription:
                    print(line)
                print('', end='', flush=True)

            # Give the background recorder thread time to fill the queue.
            sleep(0.25)
        except KeyboardInterrupt:
            break

    print("\n\nTranscription:")
    for line in transcription:
        print(line)


if __name__ == "__main__":
    main()