# NOTE(review): The six lines below are Hugging Face Spaces page-scrape
# artifacts (status banner, per-line commit hashes, line-number gutter),
# not source code. Commented out so the file parses as valid Python.
# Spaces:
# Runtime error
# Runtime error
# File size: 3,615 Bytes
# da0005f 20d39bb da0005f dc3ecb8 21907eb 20d39bb da0005f cab3a0b da0005f 20d39bb da0005f b79461c 20d39bb cab3a0b da0005f 20d39bb da0005f cab3a0b da0005f 20d39bb da0005f cab3a0b 20d39bb da0005f cab3a0b da0005f 17ded20 da0005f 17ded20 da0005f b79461c da0005f 17ded20 da0005f 17ded20 da0005f 59cfd1e da0005f cab3a0b da0005f 17ded20 20d39bb da0005f |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
#!/usr/bin/env python
# coding: utf-8
# In[ ]:
#Importing all the necessary packages
import nltk
import librosa
import IPython.display
import torch
import gradio as gr
from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
# Download the Punkt sentence tokenizer data required by nltk.sent_tokenize
# (used in correct_casing below).
nltk.download("punkt")
# In[ ]:
# Loading the model and the tokenizer from the Hugging Face hub.
# wav2vec2-base-960h is an English ASR model fine-tuned on 16 kHz audio.
model_name = "facebook/wav2vec2-base-960h"
# Alternative multilingual checkpoint (unused):
#model_name = "facebook/wav2vec2-large-xlsr-53"
tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)
# In[ ]:
def load_data(input_file):
    """Load an audio file and return a mono waveform sampled at 16 kHz.

    Parameters
    ----------
    input_file : str
        Path to the audio file to transcribe.

    Returns
    -------
    numpy.ndarray
        1-D speech signal resampled to 16 kHz, as required by
        wav2vec2-base-960h (pretrained/fine-tuned on 16 kHz audio).
    """
    # Read the file; librosa returns float samples plus the sample rate.
    speech, sample_rate = librosa.load(input_file)
    # Down-mix to mono by summing the two channels if the signal is stereo.
    # NOTE(review): assumes channel-last layout (samples, channels) — the
    # layout librosa.load(mono=False) uses is channel-first; verify callers.
    if len(speech.shape) > 1:
        speech = speech[:, 0] + speech[:, 1]
    # Resample to 16 kHz. Keyword arguments are mandatory: librosa >= 0.10
    # removed the positional form `resample(y, orig_sr, target_sr)`.
    if sample_rate != 16000:
        speech = librosa.resample(speech, orig_sr=sample_rate, target_sr=16000)
    return speech
# In[ ]:
def correct_casing(input_sentence):
    """Capitalize the first letter of each sentence in *input_sentence*.

    Parameters
    ----------
    input_sentence : str
        Transcribed text (typically lower-cased model output).

    Returns
    -------
    str
        The sentences re-joined with single spaces, each starting with an
        upper-case character.
    """
    sentences = nltk.sent_tokenize(input_sentence)
    # Upper-case position 0 directly. The previous
    # `s.replace(s[0], s[0].capitalize(), 1)` raised IndexError on an empty
    # sentence and targeted the first *occurrence* of the character rather
    # than the first position.
    return ' '.join(s[0].upper() + s[1:] if s else s for s in sentences)
# In[ ]:
def asr_transcript(input_file):
    """Generate a transcript for the provided audio file.

    Parameters
    ----------
    input_file : str
        Path to the audio file to transcribe.

    Returns
    -------
    str
        Sentence-cased transcription of the speech content.
    """
    speech = load_data(input_file)
    # Tokenize the raw waveform into model input tensors.
    input_values = tokenizer(speech, return_tensors="pt").input_values
    # Inference only: disable gradient tracking so no autograd graph is
    # built (saves memory and compute).
    with torch.no_grad():
        logits = model(input_values).logits
    # Greedy CTC decoding: most likely token id per frame.
    predicted_ids = torch.argmax(logits, dim=-1)
    # Map predicted ids back to text.
    transcription = tokenizer.decode(predicted_ids[0])
    # Model output is all upper case; lower it, then recase sentence starts.
    transcription = correct_casing(transcription.lower())
    return transcription
# In[ ]:
def asr_transcript_long(input_file, tokenizer=tokenizer, model=model):
    """Transcribe a (possibly long) audio file by streaming fixed-size chunks.

    Parameters
    ----------
    input_file : str
        Path to the audio file to transcribe.
    tokenizer, model
        The wav2vec2 tokenizer and CTC model (default to the module-level
        instances loaded above).

    Returns
    -------
    str
        Concatenated, sentence-cased transcription, truncated to 3800
        characters to keep the UI output bounded.
    """
    transcript = ""
    # Query the native sample rate so each chunk can be resampled to 16 kHz.
    sample_rate = librosa.get_samplerate(input_file)
    # Stream 20-second chunks (block_length one-second frames) rather than
    # loading the full file into memory.
    stream = librosa.stream(
        input_file,
        block_length=20,           # seconds of audio per chunk
        frame_length=sample_rate,  # one-second frames at native rate
        hop_length=sample_rate,    # non-overlapping frames
    )
    for speech in stream:
        # Down-mix stereo to mono by summing the channels.
        if len(speech.shape) > 1:
            speech = speech[:, 0] + speech[:, 1]
        # wav2vec2-base-960h expects 16 kHz input. Keyword arguments are
        # mandatory for librosa >= 0.10 (positional form was removed).
        if sample_rate != 16000:
            speech = librosa.resample(speech, orig_sr=sample_rate,
                                      target_sr=16000)
        input_values = tokenizer(speech, return_tensors="pt").input_values
        # Inference only — no gradients needed.
        with torch.no_grad():
            logits = model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = tokenizer.decode(predicted_ids[0])
        transcript += correct_casing(transcription.lower())
    # Cap the output length for the Gradio textbox.
    return transcript[:3800]
# In[ ]:
# Build and launch the web UI. The gr.inputs / gr.outputs namespaces were
# removed in Gradio 3.x (a likely cause of the Space's "Runtime error");
# top-level components are the supported replacement. The old Gradio 2
# `theme="grass"` string and the `optional=True` flag no longer exist and
# are dropped — TODO confirm against the Space's pinned gradio version.
gr.Interface(
    fn=asr_transcript_long,
    # Microphone alternative:
    # inputs=gr.Audio(sources=["microphone"], type="filepath",
    #                 label="Please record your voice"),
    inputs=gr.Audio(type="filepath", label="Upload your file here"),
    outputs=gr.Textbox(label="Output Text"),
    title="Transcript and Translate",
    description="This application displays transcribed text for given audio input",
    examples=[["Test_File1.wav"], ["Test_File2.wav"], ["Test_File3.wav"]],
).launch()
|