Harveenchadha
committed on
Commit
•
818190c
1
Parent(s):
a48ce15
Update app.py
Browse files
app.py
CHANGED
@@ -1,41 +1,17 @@
|
|
1 |
import soundfile as sf
|
2 |
import torch
|
3 |
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
|
4 |
-
import argparse
|
5 |
-
from glob import glob
|
6 |
-
import subprocess
|
7 |
import gradio as gr
|
8 |
|
9 |
|
10 |
-
def get_filename(wav_file):
    """Convert *wav_file* to a 16 kHz, 16-bit, mono WAV via sox.

    Parameters
    ----------
    wav_file : str
        Path of the source audio file (typically a gradio upload).

    Returns
    -------
    str
        Path of the converted file under /tmp. The sox exit status is
        deliberately not checked (best-effort, matching the original).
    """
    import os
    import shlex

    # splitext strips any extension length; the old `[:-4]` silently
    # assumed exactly four characters (".wav").
    stem = os.path.splitext(os.path.basename(wav_file))[0]
    converted = '/tmp/' + stem + '_16.wav'

    # shlex.quote prevents shell injection: the filename comes from an
    # end-user upload and was previously interpolated unquoted.
    subprocess.call(
        "sox {} -r 16000 -b 16 -c 1 {}".format(
            shlex.quote(wav_file), shlex.quote(converted)
        ),
        shell=True,
    )
    return converted
|
17 |
-
|
18 |
-
|
19 |
|
20 |
def parse_transcription(wav_file):
    """Transcribe an uploaded audio file with the module-level Wav2Vec2
    `processor` and `model` (defined elsewhere in this file).

    Parameters: wav_file — an uploaded-file object exposing `.name`.
    Returns: the decoded transcription string.
    """
    # Read raw samples and their sample rate from the uploaded file.
    speech, rate = sf.read(wav_file.name)

    # Feature-extract into a padded pytorch tensor.
    features = processor(speech, sampling_rate=rate, return_tensors="pt")

    # Greedy CTC decoding: argmax over the vocabulary axis.
    predicted = torch.argmax(model(features.input_values).logits, dim=-1)
    return processor.decode(predicted[0], skip_special_tokens=True)
|
41 |
|
|
|
1 |
import soundfile as sf
|
2 |
import torch
|
3 |
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
|
|
|
|
|
|
|
4 |
import gradio as gr
|
5 |
|
6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
def parse_transcription(wav_file):
    """Return the transcription of *wav_file* using the globals
    `processor` and `model` (initialized elsewhere in this file).

    `wav_file` is an uploaded-file object exposing `.name`.
    """
    # Load the audio samples from disk.
    audio, sr = sf.read(wav_file.name)
    # Tokenize/pad into a pytorch tensor batch of size 1.
    inputs = processor(audio, sampling_rate=sr, return_tensors="pt").input_values
    # Forward pass, then pick the most likely token per frame.
    ids = torch.argmax(model(inputs).logits, dim=-1)
    # Collapse CTC ids back into text.
    return processor.decode(ids[0], skip_special_tokens=True)
|
17 |
|