Spaces: Runtime error
Vaibhav Srivastav committed on
Commit · b8af00e · 1 Parent(s): 851eb15

adding decoding w lm

Files changed:
- 4gram_small.arpa.gz +3 -0
- app.py +24 -2
4gram_small.arpa.gz ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f4c4fe64751abecdeb7040fe6ed7f2440c2d3f36ed35c43e3510f7cf95578f2a
+size 18358716
app.py CHANGED
@@ -42,6 +42,28 @@ def predict_and_ctc_decode(input_file, model_name):
 
     return transcribed_text
 
+def predict_and_ctc_lm_decode(input_file, model_name):
+    processor, model = return_processor_and_model(model_name)
+    speech = load_and_fix_data(input_file)
+
+    input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values
+    logits = model(input_values).logits.cpu().detach().numpy()[0]
+
+    vocab_list = list(processor.tokenizer.get_vocab().keys())
+    vocab_dict = processor.tokenizer.get_vocab()
+    sorted_dict = {k.lower(): v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}
+
+    decoder = build_ctcdecoder(
+        list(sorted_dict.keys()),
+        "4gram_small.arpa.gz",
+    )
+
+    pred = decoder.decode(logits)
+
+    transcribed_text = fix_transcription_casing(pred.lower())
+
+    return transcribed_text
+
 def predict_and_greedy_decode(input_file, model_name):
     processor, model = return_processor_and_model(model_name)
     speech = load_and_fix_data(input_file)
@@ -57,12 +79,12 @@ def predict_and_greedy_decode(input_file, model_name):
     return transcribed_text
 
 def return_all_predictions(input_file, model_name):
-    return predict_and_ctc_decode(input_file, model_name), predict_and_greedy_decode(input_file, model_name)
+    return predict_and_ctc_decode(input_file, model_name), predict_and_ctc_lm_decode(input_file, model_name), predict_and_greedy_decode(input_file, model_name)
 
 
 gr.Interface(return_all_predictions,
              inputs = [gr.inputs.Audio(source="microphone", type="filepath", label="Record/ Drop audio"), gr.inputs.Dropdown(["facebook/wav2vec2-base-960h", "facebook/hubert-large-ls960-ft"], label="Model Name")],
-             outputs = [gr.outputs.Textbox(label="Beam CTC decoding"), gr.outputs.Textbox(label="Greedy decoding")],
+             outputs = [gr.outputs.Textbox(label="Beam CTC decoding"), gr.outputs.Textbox(label="Beam CTC decoding w/ LM"), gr.outputs.Textbox(label="Greedy decoding")],
              title="ASR using Wav2Vec2/ Hubert & pyctcdecode",
              description = "Comparing greedy decoder with beam search CTC decoder, record/ drop your audio!",
              layout = "horizontal",
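For context, the LM-boosted path added by this commit boils down to the standalone sketch below: run the acoustic model, then pass its logits to a pyctcdecode beam-search decoder built from the tokenizer vocabulary and the bundled 4-gram KenLM file. It mirrors predict_and_ctc_lm_decode but is self-contained; the audio file name (sample.wav), the use of librosa in place of the app's load_and_fix_data helper, and defaulting to the wav2vec2 checkpoint are illustrative assumptions, not part of the commit.

# Minimal sketch of beam-search CTC decoding with a KenLM language model,
# mirroring predict_and_ctc_lm_decode() from app.py. Assumes sample.wav exists
# locally and 4gram_small.arpa.gz (added by this commit) sits alongside it.
import librosa
import torch
from pyctcdecode import build_ctcdecoder
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

model_name = "facebook/wav2vec2-base-960h"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)

# Load audio and resample to the 16 kHz rate the model expects.
speech, _ = librosa.load("sample.wav", sr=16000)

input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values
with torch.no_grad():
    logits = model(input_values).logits.cpu().numpy()[0]

# Build the decoder from the tokenizer vocabulary, sorted by token id and
# lower-cased to match the ARPA model, plus the 4-gram KenLM file.
vocab_dict = processor.tokenizer.get_vocab()
sorted_vocab = [k.lower() for k, _ in sorted(vocab_dict.items(), key=lambda item: item[1])]
decoder = build_ctcdecoder(sorted_vocab, "4gram_small.arpa.gz")

# Beam-search decode the per-frame logits into a transcript.
print(decoder.decode(logits))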