Spaces:
Running
on
A10G
Running
on
A10G
File size: 4,033 Bytes
69e8a46 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import os
os.environ["MODELSCOPE_CACHE"] = ".cache/"
import string
import time
from threading import Lock
import librosa
import numpy as np
import opencc
import torch
from faster_whisper import WhisperModel
# Shared OpenCC converter built with the "t2s" configuration. batch_asr_internal
# passes every transcript through it before returning the result (presumably
# Traditional -> Simplified Chinese; confirm against the OpenCC configuration list).
t2s_converter = opencc.OpenCC("t2s")
def load_model(*, device="cuda"):
    """Build the faster-whisper ``"medium"`` model and return it.

    Args:
        device: Keyword-only device string forwarded to ``WhisperModel``.

    Returns:
        The constructed ``WhisperModel``. Weights are requested with
        ``compute_type="float16"`` and ``download_root="faster_whisper"``.
    """
    model_options = {
        "device": device,
        "compute_type": "float16",
        "download_root": "faster_whisper",
    }
    whisper = WhisperModel("medium", **model_options)
    print("faster_whisper loaded!")
    return whisper
@torch.no_grad()
def batch_asr_internal(model: WhisperModel, audios, sr):
    """Transcribe each waveform and summarise the resulting segments.

    Args:
        model: faster-whisper model used for transcription.
        audios: Iterable of waveforms, each a ``numpy.ndarray`` or a
            ``torch.Tensor`` that is one-dimensional after ``squeeze()``.
        sr: Sample rate shared by every entry of ``audios``.

    Returns:
        A list with one dict per input, in input order:
            ``"text"``: concatenated segment texts after ``t2s_converter``
                conversion (empty string when nothing was transcribed).
            ``"duration"``: input length in milliseconds at ``sr``.
            ``"huge_gap"``: ``True`` when two consecutive segments are more
                than 3 seconds apart; collection stops at that segment.

    Raises:
        AssertionError: If a waveform is not one-dimensional after squeezing.
    """
    target_sr = 16000  # rate every waveform is resampled to before transcription
    max_allowed_gap = 3.0  # seconds between consecutive segments

    prepared = []
    for audio in audios:
        if isinstance(audio, np.ndarray):
            audio = torch.from_numpy(audio)
        # detach()/cpu() keep .numpy() valid for tensors that track gradients
        # or live on an accelerator.
        audio = audio.detach().cpu().float()
        if audio.dim() > 1:
            audio = audio.squeeze()
        assert audio.dim() == 1, "each audio must be 1-D after squeeze()"

        # Measure the duration on the squeezed waveform: len() of a (1, N)
        # input would report a single sample.
        duration = audio.shape[0] / sr * 1000
        resampled = librosa.resample(
            audio.numpy(), orig_sr=sr, target_sr=target_sr
        )
        prepared.append((resampled, duration))

    results = []
    for resampled, duration in prepared:
        segments, _info = model.transcribe(
            resampled,
            language=None,
            beam_size=5,
            initial_prompt="Punctuation is needed in any language.",
        )

        pieces = []
        max_gap = 0.0
        huge_gap = False
        previous = None
        # Iterate the segments directly instead of materialising them first,
        # so nothing past the first huge gap has to be consumed.
        for segment in segments:
            if previous is not None:
                max_gap = max(segment.start - previous.end, max_gap)
            pieces.append(segment.text.strip())
            previous = segment
            if max_gap > max_allowed_gap:
                huge_gap = True
                break

        # Joining an empty list yields "", so an audio with no segments no
        # longer hands None to the converter.
        text = "".join(pieces)
        results.append(
            {
                "text": t2s_converter.convert(text),
                "duration": duration,
                "huge_gap": huge_gap,
            }
        )

    return results
# Module-level lock. Nothing in the code visible here acquires it.
global_lock = Lock()


def batch_asr(model, audios, sr):
    """Public entry point: delegate straight to ``batch_asr_internal``.

    Args:
        model: Transcription model, forwarded unchanged.
        audios: Waveforms to transcribe, forwarded unchanged.
        sr: Sample rate of ``audios``, forwarded unchanged.

    Returns:
        Whatever ``batch_asr_internal`` returns for the same arguments.
    """
    transcriptions = batch_asr_internal(model, audios, sr)
    return transcriptions
def is_chinese(text):
    """Placeholder language check.

    The argument is not inspected: every input, including an empty string,
    is reported as Chinese.

    Args:
        text: Ignored.

    Returns:
        Always ``True``.
    """
    return True
def calculate_wer(text1, text2, debug=False):
    """Return a character-level error rate between two texts.

    Both texts are passed through ``remove_punctuation`` and compared
    character by character. The Levenshtein edit distance between them is
    divided by the length of the longer stripped text, so the result lies
    in ``[0, 1]``.

    Args:
        text1: Reference text (printed as ``gt`` in debug output).
        text2: Predicted text (printed as ``pred`` in debug output).
        debug: When true, print both stripped texts and the computation.

    Returns:
        ``edits / max(len(stripped1), len(stripped2))`` as a float, or
        ``0.0`` when both texts are empty after stripping.
    """
    reference = remove_punctuation(text1)
    hypothesis = remove_punctuation(text2)

    # The distance is symmetric, so run the inner loop over the shorter text
    # to keep the two working rows as small as possible. The originals are
    # kept in their own names so the debug labels stay correct.
    shorter, longer = sorted((reference, hypothesis), key=len)

    prev = list(range(len(shorter) + 1))  # distances from the empty prefix
    for j, long_char in enumerate(longer, start=1):
        curr = [j]
        for i, short_char in enumerate(shorter, start=1):
            if short_char == long_char:
                curr.append(prev[i - 1])
            else:
                curr.append(min(prev[i], curr[i - 1], prev[i - 1]) + 1)
        prev = curr

    edits = prev[-1]
    tot = len(longer)
    # Two empty texts are identical; avoid dividing by zero.
    wer = edits / tot if tot else 0.0

    if debug:
        print(" gt: ", reference)
        print(" pred: ", hypothesis)
        print(" edits/tot = wer: ", edits, "/", tot, "=", wer)

    return wer
# Whitespace and non-ASCII punctuation stripped in addition to
# string.punctuation. Written as separate, fully quoted literals: in the
# previous single literal an unescaped double quote ended the string early and
# the "#" that followed turned the rest of the line into a comment, so
# characters such as 、《》【】 were never removed.
_EXTRA_PUNCTUATION = (
    " \n\t"
    "“”‘’‛„‟"
    "。、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿"
    "⦅⦆–—…‧﹏"
    # Full-width white parentheses and half-width CJK punctuation
    # (U+FF5F..U+FF64).
    "\uff5f\uff60\uff61\uff62\uff63\uff64"
)

# Full-width counterparts of the ASCII punctuation marks (U+FF01..U+FF5E sit
# exactly 0xFEE0 above their ASCII originals), e.g. "，" "！" "？" "：".
_FULLWIDTH_PUNCTUATION = "".join(
    chr(ord(mark) + 0xFEE0) for mark in string.punctuation
)

# Built once at import time instead of on every call.
_PUNCTUATION_TABLE = str.maketrans(
    "", "", string.punctuation + _FULLWIDTH_PUNCTUATION + _EXTRA_PUNCTUATION
)


def remove_punctuation(text):
    """Return ``text`` with punctuation and whitespace removed.

    Removes ASCII punctuation (``string.punctuation``), its full-width
    variants, common CJK punctuation, and space, tab and newline characters.

    Args:
        text: String to clean.

    Returns:
        A new string containing only the characters that were not removed.
    """
    return text.translate(_PUNCTUATION_TABLE)
if __name__ == "__main__":
    # Manual smoke test: transcribe two local recordings once, then time ten
    # further passes over the same batch.
    model = load_model()

    audios = []
    for wav_path in ("44100.wav", "lengyue.wav"):
        waveform, _ = librosa.load(wav_path, sr=44100)
        audios.append(waveform)

    first_waveform = np.array(audios[0])
    print(first_waveform)
    print(batch_asr(model, audios, 44100))

    start_time = time.time()
    runs_left = 10
    while runs_left > 0:
        print(batch_asr(model, audios, 44100))
        runs_left -= 1
    elapsed = time.time() - start_time
    print("Time taken:", elapsed)
|