siddh4rth committed
Commit 6919ad7 · 1 Parent(s): 1fd29df

pretrained model

Files changed (1)
  1. app.py +21 -11
app.py CHANGED
@@ -3,7 +3,7 @@ import gradio as gr
 import whisper
 import librosa
 import torch
-from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
+from transformers import Wav2Vec2Processor, Wav2Vec2Tokenizer
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -14,18 +14,28 @@ def audio_to_text(audio):
     result = model.transcribe(audio)
 
     return result["text"]
-
-    # audio, rate = librosa.load(audio, sr = 16000)
-
     # tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
-    # model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to(device)
-
-    # input_values = tokenizer(audio, return_tensors="pt").input_values.to(device)
-    # logits = model(input_values).logits
 
-    # prediction = torch.argmax(logits, dim=-1)
-    # transcription = tokenizer.batch_decode(prediction)[0]
-    # return transcription
+    # logits = preprocess(audio)
+
+    # predicted_ids = torch.argmax(logits, dim=-1)
+    # transcriptions = tokenizer.decode(predicted_ids[0])
+    # return transcriptions
+
+def preprocess(audio):
+    model_save_path = "model_save"
+    model_name = "wav2vec2_osr_version_1"
+    speech, rate = librosa.load(audio, sr=16000)
+    model_path = os.path.join(model_save_path, model_name+".pt")
+    pipeline_path = os.path.join(model_save_path, model_name+"_vocab")
+
+    access_token = "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"
+    processor = Wav2Vec2Processor.from_pretrained(pipeline_path, use_auth_token=access_token)
+    model = torch.load(model_path)
+    model.eval()
+    input_values = processor(speech, sampling_rate=rate, return_tensors="pt").input_values.to(device)
+    logits = model(input_values).logits
+    return logits
 
 demo = gr.Interface(
     fn=audio_to_text,
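
The commented-out lines in audio_to_text show how preprocess() is meant to be wired in, but the decoding half is unfinished: preprocess() stops at the logits, and the tokenizer the comments refer to is never constructed in the new code. Below is a minimal sketch of the complete wav2vec2 path, assuming import os is present at the top of app.py (preprocess relies on os.path.join), that the model_save/ checkpoint and vocab directory exist, and using the processor's batch_decode for the greedy CTC step in place of the undefined tokenizer.

import os

import librosa
import torch
from transformers import Wav2Vec2Processor

device = "cuda" if torch.cuda.is_available() else "cpu"

def wav2vec2_transcribe(audio):
    # Paths mirror the committed preprocess(); both files are assumed to exist locally.
    model_save_path = "model_save"
    model_name = "wav2vec2_osr_version_1"
    model_path = os.path.join(model_save_path, model_name + ".pt")
    pipeline_path = os.path.join(model_save_path, model_name + "_vocab")

    processor = Wav2Vec2Processor.from_pretrained(pipeline_path)
    model = torch.load(model_path, map_location=device)  # full pickled model, as in the commit
    model.eval()

    # Same preprocessing as preprocess(): resample to 16 kHz, tensorize, run the model.
    speech, rate = librosa.load(audio, sr=16000)
    input_values = processor(speech, sampling_rate=rate, return_tensors="pt").input_values.to(device)
    with torch.no_grad():
        logits = model(input_values).logits

    # Greedy CTC decode: most likely token per frame, collapsed to text by the processor.
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(predicted_ids)[0]

As committed, audio_to_text still returns the Whisper transcription; the wav2vec2 path only becomes live once the commented lines are restored and a decode step like the one above is added.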