pratikshahp committed on
Commit
a006d14
·
verified ·
1 Parent(s): c520d9a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -1
app.py CHANGED
@@ -2,13 +2,17 @@ import torch
2
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
3
  import streamlit as st
4
  from audio_recorder_streamlit import audio_recorder
 
5
 
6
  # Function to transcribe audio to text
7
  def transcribe_audio(audio_bytes):
8
  processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
9
  model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
10
 
11
- input_values = processor(audio_bytes, return_tensors="pt", sampling_rate=16000).input_values
 
 
 
12
  logits = model(input_values).logits
13
  predicted_ids = torch.argmax(logits, dim=-1)
14
  transcription = processor.decode(predicted_ids[0])
 
2
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
3
  import streamlit as st
4
  from audio_recorder_streamlit import audio_recorder
5
+ import numpy as np
6
 
7
  # Function to transcribe audio to text
8
  def transcribe_audio(audio_bytes):
9
  processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
10
  model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
11
 
12
+ # Convert bytes to numpy array
13
+ audio_array = np.frombuffer(audio_bytes, dtype=np.int16)
14
+
15
+ input_values = processor(audio_array, return_tensors="pt", sampling_rate=16000).input_values
16
  logits = model(input_values).logits
17
  predicted_ids = torch.argmax(logits, dim=-1)
18
  transcription = processor.decode(predicted_ids[0])