amritsar committed (verified)
Commit dd6458a · Parent: 0c9ec44

Update app.py

Files changed (1)
  1. app.py +18 -50
app.py CHANGED
@@ -1,51 +1,19 @@
-import gradio as gr
-from transformers import Wav2Vec2Processor, Wav2Vec2BertForCTC
+import soundfile as sf
 import torch
-import librosa
-import numpy as np
-
-# Load the correct processor and model
-model_id = "kdcyberdude/w2v-bert-punjabi"
-processor = Wav2Vec2Processor.from_pretrained(model_id)
-model = Wav2Vec2BertForCTC.from_pretrained(model_id)
-
-def transcribe_audio(audio_file):
-    try:
-        # Load and preprocess the audio
-        audio, rate = librosa.load(audio_file, sr=16000)  # Resample to 16 kHz
-        if len(audio.shape) > 1:  # If stereo, convert to mono
-            audio = np.mean(audio, axis=1)
-
-        # Normalize audio to match expected input range [-1, 1]
-        audio = librosa.util.normalize(audio)
-
-        # Split into manageable chunks (30 seconds each)
-        chunk_size = int(30 * rate)  # 30 seconds in samples
-        transcription = []
-
-        for i in range(0, len(audio), chunk_size):
-            chunk = audio[i:i + chunk_size]
-            input_values = processor(chunk, sampling_rate=16000, return_tensors="pt").input_values
-
-            # Perform inference
-            with torch.no_grad():
-                logits = model(input_values).logits
-
-            # Decode predicted IDs to text
-            predicted_ids = torch.argmax(logits, dim=-1)
-            transcription.append(processor.batch_decode(predicted_ids)[0])
-
-        return " ".join(transcription)
-    except Exception as e:
-        return f"Error: {str(e)}"
-
-# Gradio interface setup
-iface = gr.Interface(
-    fn=transcribe_audio,
-    inputs=gr.Audio(type="filepath"),
-    outputs=gr.Textbox(label="Punjabi Transcription"),
-    title="Punjabi Audio Transcription",
-    description="Upload an audio file to transcribe Punjabi speech."
-)
-
-iface.launch()
+from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
+import argparse
+def parse_transcription(wav_file):
+    # load pretrained model
+    processor = Wav2Vec2Processor.from_pretrained("addy88/wav2vec2-punjabi-stt")
+    model = Wav2Vec2ForCTC.from_pretrained("addy88/wav2vec2-punjabi-stt")
+    # load audio
+    audio_input, sample_rate = sf.read(wav_file)
+    # pad input values and return pt tensor
+    input_values = processor(audio_input, sampling_rate=sample_rate, return_tensors="pt").input_values
+    # INFERENCE
+    # retrieve logits & take argmax
+    logits = model(input_values).logits
+    predicted_ids = torch.argmax(logits, dim=-1)
+    # transcribe
+    transcription = processor.decode(predicted_ids[0], skip_special_tokens=True)
+    print(transcription)
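
The new app.py only defines parse_transcription() and never calls it, even though argparse is imported. A minimal sketch of a command-line entry point that could drive it is shown below; the separate run.py file, the --wav flag name, and the import from app are illustrative assumptions, not part of the commit. Note also that the function passes the file's native sample rate to the processor and does not resample, so the input WAV should already be 16 kHz mono, which Wav2Vec2 feature extractors expect (the removed code handled resampling with librosa).

# run.py -- hypothetical CLI wrapper around parse_transcription() from app.py
# Assumes app.py sits in the same directory and the WAV file is already 16 kHz mono.
import argparse

from app import parse_transcription

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Transcribe a Punjabi WAV file")
    parser.add_argument("--wav", required=True, help="path to a 16 kHz mono WAV file")
    args = parser.parse_args()
    parse_transcription(args.wav)

Usage under these assumptions would be: python run.py --wav sample.wav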