pratikshahp committed on
Commit
c520d9a
·
verified ·
1 Parent(s): a1c0ebf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -17
app.py CHANGED
@@ -1,30 +1,34 @@
1
  import torch
2
- from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
3
  import streamlit as st
4
  from audio_recorder_streamlit import audio_recorder
5
 
6
- audio_bytes = audio_recorder(pause_threshold=3.0, sample_rate=16_000)
7
-
8
- if audio_bytes:
9
- st.audio(audio_bytes, format="audio/wav")
10
-
11
- # Load pre-trained model and tokenizer
12
- tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
13
  model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
14
 
15
- # Tokenize the audio input
16
- input_values = tokenizer(audio_bytes, return_tensors='pt').input_values
17
-
18
- # Perform inference
19
  logits = model(input_values).logits
20
  predicted_ids = torch.argmax(logits, dim=-1)
 
 
 
21
 
22
- # Decode the audio to generate text
23
- transcriptions = tokenizer.decode(predicted_ids[0])
 
 
 
 
 
 
 
24
 
25
- if transcriptions is not None:
26
- st.write(transcriptions)
 
27
  else:
28
- st.write("Error: Failed to decode audio.")
29
  else:
30
  st.write("No audio recorded.")
 
1
  import torch
2
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
3
  import streamlit as st
4
  from audio_recorder_streamlit import audio_recorder
5
 
6
+ # Function to transcribe audio to text
7
+ def transcribe_audio(audio_bytes):
8
+ processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
 
 
 
 
9
  model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
10
 
11
+ input_values = processor(audio_bytes, return_tensors="pt", sampling_rate=16000).input_values
 
 
 
12
  logits = model(input_values).logits
13
  predicted_ids = torch.argmax(logits, dim=-1)
14
+ transcription = processor.decode(predicted_ids[0])
15
+
16
+ return transcription
17
 
18
+ # Streamlit app
19
+ st.title("Audio to Text Transcription")
20
+
21
+ audio_bytes = audio_recorder(pause_threshold=3.0, sample_rate=16_000)
22
+
23
+ if audio_bytes:
24
+ st.audio(audio_bytes, format="audio/wav")
25
+
26
+ transcription = transcribe_audio(audio_bytes)
27
 
28
+ if transcription:
29
+ st.write("Transcription:")
30
+ st.write(transcription)
31
  else:
32
+ st.write("Error: Failed to transcribe audio.")
33
  else:
34
  st.write("No audio recorded.")