Spaces:
Running
on
A10G
Running
on
A10G
import os | |
os.environ["MODELSCOPE_CACHE"] = ".cache/" | |
import string | |
import time | |
from threading import Lock | |
import librosa | |
import numpy as np | |
import opencc | |
import torch | |
from faster_whisper import WhisperModel | |
t2s_converter = opencc.OpenCC("t2s") | |
def load_model(*, device="cuda"): | |
model = WhisperModel( | |
"medium", | |
device=device, | |
compute_type="float16", | |
download_root="faster_whisper", | |
) | |
print("faster_whisper loaded!") | |
return model | |
def batch_asr_internal(model: WhisperModel, audios, sr): | |
resampled_audios = [] | |
for audio in audios: | |
if isinstance(audio, np.ndarray): | |
audio = torch.from_numpy(audio).float() | |
if audio.dim() > 1: | |
audio = audio.squeeze() | |
assert audio.dim() == 1 | |
audio_np = audio.numpy() | |
resampled_audio = librosa.resample(audio_np, orig_sr=sr, target_sr=16000) | |
resampled_audios.append(resampled_audio) | |
trans_results = [] | |
for resampled_audio in resampled_audios: | |
segments, info = model.transcribe( | |
resampled_audio, | |
language=None, | |
beam_size=5, | |
initial_prompt="Punctuation is needed in any language.", | |
) | |
trans_results.append(list(segments)) | |
results = [] | |
for trans_res, audio in zip(trans_results, audios): | |
duration = len(audio) / sr * 1000 | |
huge_gap = False | |
max_gap = 0.0 | |
text = None | |
last_tr = None | |
for tr in trans_res: | |
delta = tr.text.strip() | |
if tr.id > 1: | |
max_gap = max(tr.start - last_tr.end, max_gap) | |
text += delta | |
else: | |
text = delta | |
last_tr = tr | |
if max_gap > 3.0: | |
huge_gap = True | |
break | |
sim_text = t2s_converter.convert(text) | |
results.append( | |
{ | |
"text": sim_text, | |
"duration": duration, | |
"huge_gap": huge_gap, | |
} | |
) | |
return results | |
global_lock = Lock() | |
def batch_asr(model, audios, sr): | |
return batch_asr_internal(model, audios, sr) | |
def is_chinese(text): | |
return True | |
def calculate_wer(text1, text2, debug=False): | |
chars1 = remove_punctuation(text1) | |
chars2 = remove_punctuation(text2) | |
m, n = len(chars1), len(chars2) | |
if m > n: | |
chars1, chars2 = chars2, chars1 | |
m, n = n, m | |
prev = list(range(m + 1)) # row 0 distance: [0, 1, 2, ...] | |
curr = [0] * (m + 1) | |
for j in range(1, n + 1): | |
curr[0] = j | |
for i in range(1, m + 1): | |
if chars1[i - 1] == chars2[j - 1]: | |
curr[i] = prev[i - 1] | |
else: | |
curr[i] = min(prev[i], curr[i - 1], prev[i - 1]) + 1 | |
prev, curr = curr, prev | |
edits = prev[m] | |
tot = max(len(chars1), len(chars2)) | |
wer = edits / tot | |
if debug: | |
print(" gt: ", chars1) | |
print(" pred: ", chars2) | |
print(" edits/tot = wer: ", edits, "/", tot, "=", wer) | |
return wer | |
def remove_punctuation(text): | |
chinese_punctuation = ( | |
" \n\t”“!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—" | |
'‛""„‟…‧﹏' | |
) | |
all_punctuation = string.punctuation + chinese_punctuation | |
translator = str.maketrans("", "", all_punctuation) | |
text_without_punctuation = text.translate(translator) | |
return text_without_punctuation | |
if __name__ == "__main__": | |
model = load_model() | |
audios = [ | |
librosa.load("44100.wav", sr=44100)[0], | |
librosa.load("lengyue.wav", sr=44100)[0], | |
] | |
print(np.array(audios[0])) | |
print(batch_asr(model, audios, 44100)) | |
start_time = time.time() | |
for _ in range(10): | |
print(batch_asr(model, audios, 44100)) | |
print("Time taken:", time.time() - start_time) | |