moro23 commited on
Commit
deeb5cc
·
1 Parent(s): e406f4c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -7,12 +7,12 @@ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
7
  nltk.download("punkt")
8
 
9
  ##
10
- token_value = "hf_********************"  # SECURITY: leaked HF access token redacted — revoke this credential immediately
11
  #Loading the pre-trained model and the tokenizer
12
- model_name = "moro23/wav2vec-large-xls-r-300-ha-colab_4"
13
  #tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_name, use_auth_token=token_value)
14
- tokenizer = Wav2Vec2Processor.from_pretrained(model_name, use_auth_token=token_value)
15
- model = Wav2Vec2ForCTC.from_pretrained(model_name, use_auth_token=token_value)
16
 
17
  def load_data(input_file):
18
 
@@ -23,7 +23,7 @@ def load_data(input_file):
23
  speech = speech[:,0] + speech[:,1]
24
  #Resampling the audio at 16KHz
25
  if sample_rate !=16000:
26
- speech = librosa.resample(speech, sample_rate,16000)
27
  return speech
28
 
29
  def correct_casing(input_sentence):
@@ -35,15 +35,15 @@ def asr_transcript(input_file):
35
 
36
  speech = load_data(input_file)
37
  #Tokenize
38
- input_values = tokenizer(speech, return_tensors="pt").input_values
39
  #Take logits
40
- logits = model(input_values).logits
41
  #Take argmax
42
- predicted_ids = torch.argmax(logits, dim=-1)
43
  #Get the words from predicted word ids
44
- transcription = tokenizer.decode(predicted_ids[0])
45
  #Correcting the letter casing
46
  transcription = correct_casing(transcription.lower())
47
  return transcription
48
 
49
- gr.Interface(asr_transcript, inputs = gr.inputs.Audio(source="microphone", type="filepath", optional=True, label="Speaker"), outputs = gr.outputs.Textbox(label="Output Text"), title="ASR For Hausa", description = "This application displays transcribed text for given audio input", theme="grass").launch(share=True)
 
7
  nltk.download("punkt")
8
 
9
  ##
10
+ #token_value = "hf_********************"  # SECURITY: leaked HF access token redacted — revoke this credential immediately
11
  #Loading the pre-trained model and the tokenizer
12
+ #model_name = "moro23/wav2vec-large-xls-r-300-ha-colab_4"
13
  #tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_name, use_auth_token=token_value)
14
+ #tokenizer = Wav2Vec2Processor.from_pretrained(model_name, use_auth_token=token_value)
15
+ #model = Wav2Vec2ForCTC.from_pretrained(model_name, use_auth_token=token_value)
16
 
17
  def load_data(input_file):
18
 
 
23
  speech = speech[:,0] + speech[:,1]
24
  #Resampling the audio at 16KHz
25
  if sample_rate !=16000:
26
+ speech = librosa.resample(speech, sample_rate, 16000)
27
  return speech
28
 
29
  def correct_casing(input_sentence):
 
35
 
36
  speech = load_data(input_file)
37
  #Tokenize
38
+ input_dict = tokenizer(speech, return_tensors="pt", padding=True)
39
  #Take logits
40
+ logits = model(input_dict.input_values.to("cuda")).logits
41
  #Take argmax
42
+ predicted_ids = torch.argmax(logits, dim=-1)[0]
43
  #Get the words from predicted word ids
44
+ transcription = tokenizer.decode(predicted_ids)
45
  #Correcting the letter casing
46
  transcription = correct_casing(transcription.lower())
47
  return transcription
48
 
49
+ gr.Interface(asr_transcript, inputs = gr.Audio(source="microphone", type="filepath", optional=True, label="Speaker"), outputs = gr.Textbox(label="Output Text"), title="ASR For Hausa", description = "This application displays transcribed text for given audio input").launch()