File size: 2,330 Bytes
8fde49e
 
98939b7
8fde49e
 
 
c634e64
fea3916
 
6875818
1da4128
fea3916
2c467e4
26a3f0d
3eee1de
 
fea3916
 
 
ebb3ff5
fea3916
 
 
 
 
deeb5cc
fea3916
 
 
 
 
 
 
 
 
 
 
daabb0b
fea3916
3d9afa3
900aeb7
fea3916
deeb5cc
fea3916
deeb5cc
fea3916
 
723ece3
fea3916
 
12713e2
fa8c187
010bbd1
12713e2
 
 
 
de12b8d
12713e2
 
 
 
 
3cae08e
 
f081cd3
3cae08e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#Importing all the necessary packages
import os

import nltk
import soundfile
import librosa
import torch
import gradio as gr
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

# Sentence tokenizer data used by correct_casing().
nltk.download("punkt")

## Hugging Face access token for the gated model repo.
# SECURITY: the token was previously hard-coded in this file (a leaked secret
# that should be revoked). Read it from the environment instead so it never
# lands in version control; set HF_TOKEN before launching the app.
token_value = os.environ.get("HF_TOKEN", "")

#Loading the pre-trained model and the processor (feature extractor + tokenizer)
model_name = "moro23/wav2vec-large-xls-r-300-ha-colab_4"
tokenizer = Wav2Vec2Processor.from_pretrained(model_name, use_auth_token=token_value)
model = Wav2Vec2ForCTC.from_pretrained(model_name, use_auth_token=token_value)

def load_data(input_file):
  """Load an audio file and return a mono waveform at 16 kHz.

  Parameters
  ----------
  input_file : str
      Path to the audio file to transcribe.

  Returns
  -------
  numpy.ndarray
      1-D float waveform resampled to 16 kHz, ready for the Wav2Vec2 processor.
  """
  speech, sample_rate = librosa.load(input_file)
  # Down-mix stereo to mono. Averaging (rather than summing, as before) keeps
  # the signal in its original amplitude range and avoids clipping.
  if len(speech.shape) > 1:
      speech = (speech[:, 0] + speech[:, 1]) / 2
  # The model expects 16 kHz input; resample if the file uses another rate.
  # librosa >= 0.10 requires the sample rates as keyword arguments.
  if sample_rate != 16000:
    speech = librosa.resample(speech, orig_sr=sample_rate, target_sr=16000)
  return speech

def correct_casing(input_sentence):
  """Upper-case the first letter of every sentence in *input_sentence*.

  Parameters
  ----------
  input_sentence : str
      Text to re-case (typically the lower-cased model transcription).

  Returns
  -------
  str
      The sentences re-joined with single spaces, each starting upper-case.
  """
  sentences = nltk.sent_tokenize(input_sentence)
  # Slice-based capitalization touches only position 0. The previous
  # str.replace(s[0], ..., 1) form raised IndexError on an empty sentence;
  # the `if s` guard handles that edge case.
  return ' '.join([s[0].upper() + s[1:] if s else s for s in sentences])
  
def asr_transcript(input_file):
  """Transcribe an audio file to Hausa text.

  Parameters
  ----------
  input_file : str
      Path to the recorded audio (Gradio passes a temp-file path).

  Returns
  -------
  str
      The transcription with sentence-initial capitalization.
  """
  speech = load_data(input_file)
  # Convert the raw waveform into model-ready tensors.
  input_dict = tokenizer(speech, return_tensors="pt", sampling_rate=16000, padding=True)
  # Inference only: disable autograd so no gradient graph is built per request.
  with torch.no_grad():
    logits = model(input_dict.input_values).logits
  # Greedy CTC decoding: most likely token id at every frame, first batch item.
  predicted_ids = torch.argmax(logits, dim=-1)[0]
  # Map the predicted ids back to text.
  transcription = tokenizer.decode(predicted_ids)
  # Model output is case-less; lower-case then re-capitalize sentence starts.
  transcription = correct_casing(transcription.lower())
  return transcription
  
################### Gradio Web APP ################################ 
# Flagged samples are saved to a Hugging Face dataset repo ("Hausa-ASR-flags").
# NOTE(review): this reuses the hard-coded token_value defined above — the
# token should be rotated and loaded from the environment instead.
hf_writer = gr.HuggingFaceDatasetSaver(token_value, "Hausa-ASR-flags")

title = "Hausa Automatic Speech Recognition"
 
# Example clips, paths relative to the app root.
# NOTE(review): `examples` is defined but never passed to gr.Interface below —
# confirm whether it should be wired in via the `examples=` parameter.
examples = [["Sample/sample1.mp3"], ["Sample/sample2.mp3"], ["Sample/sample3.mp3"]]
 
# Microphone input delivered to asr_transcript as a file path.
# NOTE(review): `source=` was renamed to `sources=[...]` in Gradio 4 — this
# assumes a Gradio 3.x runtime; confirm the pinned version.
Input = gr.Audio(source="microphone", type="filepath", label="Please Record Your Voice")

Output = gr.Textbox(label="Hausa Script")

description = "This application displays transcribed text for given audio input"

# Manual flagging with three options; flags are persisted through hf_writer.
demo = gr.Interface(fn = asr_transcript, inputs = Input, outputs = Output, title = title, flagging_options=["incorrect", "worst", "ambiguous"],
        allow_flagging="manual",flagging_callback=hf_writer,description= description)

# share=True creates a temporary public link when run outside HF Spaces.
demo.launch(share=True)