File size: 2,924 Bytes
c106aba
 
b5b3a67
c106aba
38368a5
6b8cfb9
ea8b34f
 
 
 
 
 
296f52e
ea8b34f
c106aba
38368a5
6b8cfb9
38368a5
 
 
 
c106aba
06628a1
1e9cc4e
 
4f9fa63
346d452
4f9fa63
 
a1f28e4
6a4c9da
2ce792c
31c8a67
2ce792c
4f9fa63
 
 
8b8bd52
6c461f0
 
ea8b34f
8b8bd52
 
 
4f9fa63
c45c6ec
 
 
 
 
 
4f9fa63
8b8bd52
3d18c9e
 
 
c106aba
 
 
 
 
4f9fa63
 
 
 
ea8b34f
 
c106aba
0b02e4c
0de7587
0b02e4c
 
 
 
38368a5
4f9fa63
 
90f5dcf
0b02e4c
06628a1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import soundfile as sf
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor,Wav2Vec2ProcessorWithLM
import gradio as gr
import scipy.signal as sps
import sox

def convert(inputfile, outfile):
    """Transcode ``inputfile`` with sox and write the result to ``outfile``.

    The output format is fixed: a WAV file with one channel, a rate of
    16000, and 16-bit signed-integer encoding.

    Args:
        inputfile: Path of the audio file to read.
        outfile: Path the converted WAV file is written to.
    """
    output_spec = {
        "file_type": "wav",
        "channels": 1,
        "encoding": "signed-integer",
        "rate": 16000,
        "bits": 16,
    }
    transformer = sox.Transformer()
    transformer.set_output_format(**output_spec)
    transformer.build(inputfile, outfile)

def read_file(wav, target_rate=16000):
    """Downmix an in-memory recording to mono and resample it.

    Args:
        wav: A ``(sample_rate, signal)`` pair. ``signal`` is expected to be
            a NumPy-style array, either 1-D (already mono) or with the
            channels on the last axis.
            NOTE(review): presumably the tuple Gradio produces for
            ``type="numpy"`` audio inputs -- confirm against the caller.
        target_rate: Sample rate of the returned signal. Defaults to 16000,
            the rate that was previously hard-coded.

    Returns:
        The mono signal resampled to ``target_rate`` with
        ``scipy.signal.resample``.
    """
    sample_rate, signal = wav
    if signal.ndim > 1:
        # Average the channels (last axis) down to a single track.
        signal = signal.mean(-1)
    else:
        # Already mono: averaging over the only axis would collapse the
        # whole recording into one scalar and break len() below.
        signal = signal.astype(float)
    number_of_samples = round(len(signal) * float(target_rate) / sample_rate)
    return sps.resample(signal, number_of_samples)




def parse_transcription_with_lm(wav_file):
    """Transcribe ``wav_file`` by decoding the logits with ``processor_with_LM``.

    Args:
        wav_file: Path of the audio file to transcribe.

    Returns:
        The decoded text with every literal ``<s>`` removed.
    """
    features = read_file_and_process(wav_file)

    with torch.no_grad():
        output = model(**features)
        # Only the first batch entry is decoded, as a NumPy array.
        logits = output.logits[0].cpu().numpy()
    print(logits)

    decoded = processor_with_LM.decode(logits=logits, output_word_offsets=False)
    print(decoded)

    return decoded.text.replace('<s>', '')


def read_file_and_process(wav_file):
    """Convert ``wav_file`` with ``convert()`` and turn it into model inputs.

    The converted copy is written next to the original as
    ``<path without extension>16k.wav``, read back with soundfile, and passed
    through ``processor``.

    Args:
        wav_file: Path of the audio file to process.

    Returns:
        The object returned by ``processor(...)`` when called with
        ``return_tensors="pt"`` and ``padding=True``.
    """
    import os  # function-scope: the module's import block does not provide os

    # Strip only the final extension. Splitting on the first "." in the whole
    # path breaks on inputs such as "./rec.wav" or "/tmp/a.b/rec.wav", where
    # the dot belongs to a directory component.
    stem, _ext = os.path.splitext(wav_file)
    converted_path = stem + "16k.wav"

    convert(wav_file, converted_path)
    # NOTE(review): the converted file is left on disk after reading --
    # confirm whether it should be cleaned up.
    speech, _ = sf.read(converted_path)
    return processor(speech, sampling_rate=16_000, return_tensors="pt", padding=True)

def parse(wav_file, applyLM):
    """Transcribe ``wav_file``, choosing the decoder from ``applyLM``.

    Args:
        wav_file: Path of the audio file to transcribe.
        applyLM: When truthy, use ``parse_transcription_with_lm``; otherwise
            use ``parse_transcription``.

    Returns:
        The transcription produced by the selected function.
    """
    transcribe = parse_transcription_with_lm if applyLM else parse_transcription
    return transcribe(wav_file)

def parse_transcription(wav_file):
    """Transcribe ``wav_file`` by argmax decoding of the model logits.

    Args:
        wav_file: Path of the audio file to transcribe.

    Returns:
        The text decoded by ``processor`` with special tokens skipped.
    """
    features = read_file_and_process(wav_file)

    with torch.no_grad():
        output = model(**features)

    # Keep the highest-scoring entry along the last logits axis.
    best_ids = output.logits.argmax(dim=-1)

    # Only the first batch entry is decoded.
    return processor.decode(best_ids[0], skip_special_tokens=True)
    
# Checkpoint identifier shared by every from_pretrained() call below, so the
# two processors and the model are all loaded from the same source.
model_id = "Harveenchadha/vakyansh-wav2vec2-hindi-him-4200"
# Used by read_file_and_process() to build the model inputs and by
# parse_transcription() to decode argmax token ids.
processor = Wav2Vec2Processor.from_pretrained(model_id)
# Used by parse_transcription_with_lm() to decode the raw logits.
processor_with_LM = Wav2Vec2ProcessorWithLM.from_pretrained(model_id)
# Network invoked by both transcription functions to produce logits.
model = Wav2Vec2ForCTC.from_pretrained(model_id)
    

    
# Microphone input; parse() treats the value it receives as a file path.
input_ = gr.Audio(source="microphone", type="filepath")

txtbox = gr.Textbox(label="Output from model will appear here:", lines=5)

chkbox = gr.Checkbox(label="Apply LM", value=False)

# Wire the widgets to parse() and start the app.
demo = gr.Interface(
    parse,
    inputs=[input_, chkbox],
    outputs=txtbox,
    streaming=True,
    interactive=True,
    analytics_enabled=False,
    show_tips=False,
    enable_queue=True,
)
demo.launch(inline=False)