File size: 3,601 Bytes
c106aba
 
b5b3a67
c106aba
38368a5
6b8cfb9
487d42a
ea8b34f
 
 
 
 
 
296f52e
ea8b34f
c106aba
38368a5
6b8cfb9
38368a5
 
 
 
c106aba
06628a1
144bafa
f2d1246
1e9cc4e
f2d1246
 
 
 
 
1e9cc4e
4f9fa63
346d452
4f9fa63
c043038
 
 
 
 
 
 
 
 
 
4f9fa63
23ba586
c043038
a958451
c043038
07d59d1
4f9fa63
 
 
8b8bd52
6c461f0
69a550c
ea8b34f
8b8bd52
 
 
4f9fa63
c45c6ec
 
 
 
 
 
4f9fa63
8b8bd52
3d18c9e
 
 
c106aba
 
 
 
 
4f9fa63
 
 
 
ea8b34f
 
c106aba
0b02e4c
0de7587
0b02e4c
 
 
 
38368a5
4f9fa63
 
90f5dcf
0b02e4c
06628a1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import contextlib
import os
import subprocess

import gradio as gr
import scipy.signal as sps
import soundfile as sf
import sox
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM

def convert(inputfile, outfile):
    """Re-encode ``inputfile`` with sox and write the result to ``outfile``.

    The output is a WAV file: 16 kHz, mono, 16-bit signed-integer PCM.
    """
    transformer = sox.Transformer()
    transformer.set_output_format(
        file_type="wav",
        rate=16000,
        bits=16,
        channels=1,
        encoding="signed-integer",
    )
    transformer.build(inputfile, outfile)

def read_file(wav, target_rate=16000):
    """Downmix an in-memory recording and resample it to ``target_rate`` Hz.

    Args:
        wav: ``(sample_rate, signal)`` pair. ``signal`` may be 1-D (mono) or
            multi-dimensional. Channels are averaged over the last axis, so
            a 2-D signal is presumably laid out as (n_samples, n_channels).
        target_rate: Output sample rate in Hz. Defaults to 16000.

    Returns:
        The resampled 1-D signal, produced by ``scipy.signal.resample``.
    """
    sample_rate, signal = wav
    # Only downmix when there is a channel axis. Calling mean(-1) on a 1-D
    # mono signal would collapse it to a scalar, and len() below would fail.
    if signal.ndim > 1:
        signal = signal.mean(-1)
    number_of_samples = round(len(signal) * float(target_rate) / sample_rate)
    return sps.resample(signal, number_of_samples)


def resampler(input_file_path, output_file_path):
    """Re-encode an audio file with the ffmpeg CLI as 16 kHz mono.

    Args:
        input_file_path: Path of the source audio file.
        output_file_path: Destination path. It is overwritten if it exists.

    Returns:
        ffmpeg's exit status; 0 means success. ffmpeg's own logging is
        silenced (``-loglevel panic``), so callers that care about failures
        should check this value.
    """
    # Pass an argv list rather than a shell string. Paths containing spaces,
    # quotes or shell metacharacters are then handled safely, and cannot be
    # used to inject shell commands.
    command = [
        "ffmpeg", "-hide_banner", "-loglevel", "panic",
        "-y",  # overwrite without prompting; a prompt would block the call
        "-i", input_file_path,
        "-ar", "16000", "-ac", "1", "-bits_per_raw_sample", "16", "-vn",
        output_file_path,
    ]
    # Detach stdin so ffmpeg can never wait on interactive input.
    completed = subprocess.run(command, stdin=subprocess.DEVNULL, check=False)
    return completed.returncode

def parse_transcription_with_lm(wav_file):
    """Transcribe ``wav_file`` using the LM-backed processor for decoding.

    Uses the module-level ``model`` and ``processor_with_LM``. Any literal
    ``<s>`` markers are removed from the decoded text before it is returned.
    """
    features = read_file_and_process(wav_file)

    with torch.no_grad():
        logits = model(**features).logits

    # batch_decode works on a batch. The single input file forms one entry,
    # so only the first decoded string is used.
    decoded = processor_with_LM.batch_decode(logits.cpu().numpy())
    first_text = decoded.text[0]
    return first_text.replace('<s>', '')


def read_file_and_process(wav_file):
    """Resample ``wav_file`` to 16 kHz mono and convert it to model inputs.

    The resampled audio is written to ``<stem>16k.wav`` next to the input,
    read back with soundfile, and then deleted.

    Returns:
        The output of the module-level ``processor``, with PyTorch tensors.
    """
    # splitext strips only the final extension. The old split('.')[0] cut
    # at the first dot anywhere in the path: "./rec.wav" became "" and
    # "/tmp/a.b/rec.wav" became "/tmp/a".
    stem, _ = os.path.splitext(wav_file)
    resampled_path = stem + "16k.wav"
    resampler(wav_file, resampled_path)
    try:
        speech, _ = sf.read(resampled_path)
    finally:
        # The 16 kHz copy is only an intermediate. Remove it so repeated
        # requests do not leave files behind. It may be missing if
        # resampling failed.
        with contextlib.suppress(FileNotFoundError):
            os.remove(resampled_path)

    inputs = processor(speech, sampling_rate=16_000, return_tensors="pt", padding=True)
    return inputs

def parse(wav_file, applyLM):
    """Gradio callback: transcribe ``wav_file``.

    Uses LM-based decoding when ``applyLM`` is truthy, and plain argmax
    decoding otherwise.
    """
    transcribe = parse_transcription_with_lm if applyLM else parse_transcription
    return transcribe(wav_file)

def parse_transcription(wav_file):
    """Transcribe ``wav_file`` without a language model.

    Picks the highest-scoring token at every frame (argmax), then turns the
    ids into text with the module-level ``processor``, skipping special tokens.
    """
    features = read_file_and_process(wav_file)
    with torch.no_grad():
        logits = model(**features).logits
    best_ids = logits.argmax(dim=-1)
    return processor.decode(best_ids[0], skip_special_tokens=True)
    
# All three objects load from the same checkpoint:
# - the CTC acoustic model;
# - the plain processor, used for feature extraction and argmax decoding;
# - the LM-backed processor, used when "Apply LM" is checked.
model_id = "Harveenchadha/vakyansh-wav2vec2-hindi-him-4200"
model = Wav2Vec2ForCTC.from_pretrained(model_id)
processor = Wav2Vec2Processor.from_pretrained(model_id)
processor_with_LM = Wav2Vec2ProcessorWithLM.from_pretrained(model_id)
    

    
# Microphone recording, delivered to `parse` as a path to a temporary file.
input_ = gr.Audio(source="microphone", type="filepath")
txtbox = gr.Textbox(label="Output from model will appear here:", lines=5)
chkbox = gr.Checkbox(label="Apply LM", value=False)

# NOTE(review): `source=`, `show_tips`, `enable_queue` and `streaming` are
# gradio-3-era arguments. Confirm they still apply to the pinned gradio
# version.
demo = gr.Interface(
    parse,
    inputs=[input_, chkbox],
    outputs=txtbox,
    streaming=True,
    interactive=True,
    analytics_enabled=False,
    show_tips=False,
    enable_queue=True,
)
demo.launch(inline=False)