Update app.py
app.py
CHANGED
@@ -7,12 +7,12 @@ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
 nltk.download("punkt")
 
 ##
-token_value = "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"
+#token_value = "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"
 #Loading the pre-trained model and the tokenizer
-model_name = "moro23/wav2vec-large-xls-r-300-ha-colab_4"
+#model_name = "moro23/wav2vec-large-xls-r-300-ha-colab_4"
 #tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_name, use_auth_token=token_value)
-tokenizer = Wav2Vec2Processor.from_pretrained(model_name, use_auth_token=token_value)
-model = Wav2Vec2ForCTC.from_pretrained(model_name, use_auth_token=token_value)
+#tokenizer = Wav2Vec2Processor.from_pretrained(model_name, use_auth_token=token_value)
+#model = Wav2Vec2ForCTC.from_pretrained(model_name, use_auth_token=token_value)
 
 def load_data(input_file):
 
@@ -23,7 +23,7 @@ def load_data(input_file):
     speech = speech[:,0] + speech[:,1]
     #Resampling the audio at 16KHz
     if sample_rate !=16000:
-        speech = librosa.resample(speech, sample_rate,16000)
+        speech = librosa.resample(speech, sample_rate, 16000)
     return speech
 
 def correct_casing(input_sentence):
@@ -35,15 +35,15 @@ def asr_transcript(input_file):
 
     speech = load_data(input_file)
     #Tokenize
-
+    input_dict = tokenizer(speech, return_tensors="pt", padding=True)
     #Take logits
-    logits = model(input_values).logits
+    logits = model(input_dict.input_values.to("cuda")).logits
     #Take argmax
-    predicted_ids = torch.argmax(logits, dim=-1)
+    predicted_ids = torch.argmax(logits, dim=-1)[0]
     #Get the words from predicted word ids
-    transcription = tokenizer.decode(predicted_ids
+    transcription = tokenizer.decode(predicted_ids)
     #Correcting the letter casing
     transcription = correct_casing(transcription.lower())
     return transcription
 
-gr.Interface(asr_transcript, inputs = gr.
+gr.Interface(asr_transcript, inputs = gr.Audio(source="microphone", type="filepath", optional=True, label="Speaker"), outputs = gr.Textbox(label="Output Text"), title="ASR For Hausa", description = "This application displays transcribed text for given audio input").launch()
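Note that even after this commit the Space can still fail at startup: the lines that load the model and processor are commented out, while asr_transcript still references tokenizer and model, and the .to("cuda") call assumes a GPU runtime. For reference, below is a minimal runnable sketch of the whole app under stated assumptions: the moro23/wav2vec-large-xls-r-300-ha-colab_4 checkpoint is publicly readable (no auth token), the Space runs on CPU, and Gradio 3.x is installed (matching the gr.Audio(source=...) signature used in the commit). The bodies of load_data and correct_casing are reconstructions, since the diff only shows fragments of them.

import nltk
import torch
import librosa
import gradio as gr
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

nltk.download("punkt")

# Assumed public checkpoint; pass a token to from_pretrained if it is gated.
model_name = "moro23/wav2vec-large-xls-r-300-ha-colab_4"
tokenizer = Wav2Vec2Processor.from_pretrained(model_name)  # feature extractor + CTC tokenizer
model = Wav2Vec2ForCTC.from_pretrained(model_name)

def load_data(input_file):
    # Reconstructed: the diff only shows the stereo downmix and resampling lines.
    # mono=True downmixes stereo; sr=None keeps the file's native sample rate.
    speech, sample_rate = librosa.load(input_file, sr=None, mono=True)
    # Resample to the 16 kHz rate the model was trained on
    if sample_rate != 16000:
        speech = librosa.resample(y=speech, orig_sr=sample_rate, target_sr=16000)
    return speech

def correct_casing(input_sentence):
    # Assumed implementation (the diff does not show this function's body):
    # split into sentences with the punkt tokenizer and capitalize each one.
    sentences = nltk.sent_tokenize(input_sentence)
    return " ".join(s.capitalize() for s in sentences)

def asr_transcript(input_file):
    speech = load_data(input_file)
    # Normalize the waveform into model input features
    input_dict = tokenizer(speech, sampling_rate=16000, return_tensors="pt", padding=True)
    # Forward pass on CPU; no gradients needed at inference time
    with torch.no_grad():
        logits = model(input_dict.input_values).logits
    # Greedy CTC decoding: most likely token per frame, batch element 0
    predicted_ids = torch.argmax(logits, dim=-1)[0]
    transcription = tokenizer.decode(predicted_ids)
    # Correct the letter casing
    return correct_casing(transcription.lower())

gr.Interface(
    asr_transcript,
    inputs=gr.Audio(source="microphone", type="filepath", label="Speaker"),
    outputs=gr.Textbox(label="Output Text"),
    title="ASR For Hausa",
    description="This application displays transcribed text for given audio input",
).launch()

Greedy argmax is the simplest CTC decoder: tokenizer.decode then collapses repeated frame-level tokens and strips the blank token to yield the final transcription.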