unijoh committed
Commit
9aa67b4
1 Parent(s): 3291f24

Update asr.py

Files changed (1)
  1. asr.py +5 -9
asr.py CHANGED
@@ -3,23 +3,19 @@ from transformers import Wav2Vec2ForCTC, AutoProcessor
 import torch
 
 ASR_SAMPLING_RATE = 16_000
+
 MODEL_ID = "facebook/mms-1b-all"
 
 processor = AutoProcessor.from_pretrained(MODEL_ID)
 model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
 
-def transcribe(audio_source=None, microphone=None, file_upload=None):
-    audio_fp = file_upload if "upload" in str(audio_source or "").lower() else microphone
-    if audio_fp is None:
+def transcribe(audio=None):
+    if audio is None:
         return "ERROR: You have to either use the microphone or upload an audio file"
 
-    audio_samples = librosa.load(audio_fp, sr=ASR_SAMPLING_RATE, mono=True)[0]
-    processor.tokenizer.set_target_lang("fao")  # Set Faroese language
-    model.load_adapter("fao")
-
+    audio_samples = librosa.load(audio, sr=ASR_SAMPLING_RATE, mono=True)[0]
     inputs = processor(audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt")
 
-    # Set device
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model.to(device)
     inputs = inputs.to(device)
@@ -29,5 +25,5 @@ def transcribe(audio_source=None, microphone=None, file_upload=None):
 
     ids = torch.argmax(outputs, dim=-1)[0]
     transcription = processor.decode(ids)
-
+
     return transcription
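
For context, here is a minimal sketch of what asr.py looks like after this commit. The librosa import and the torch.no_grad() forward pass that produces outputs sit outside the hunks shown above, so their exact wording is an assumption; everything else follows the new side of the diff.

    import librosa  # assumed: imported above the hunk, since librosa.load is used below
    import torch
    from transformers import Wav2Vec2ForCTC, AutoProcessor

    ASR_SAMPLING_RATE = 16_000

    MODEL_ID = "facebook/mms-1b-all"

    processor = AutoProcessor.from_pretrained(MODEL_ID)
    model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)

    def transcribe(audio=None):
        if audio is None:
            return "ERROR: You have to either use the microphone or upload an audio file"

        # Decode and resample the input to 16 kHz mono, as the model expects.
        audio_samples = librosa.load(audio, sr=ASR_SAMPLING_RATE, mono=True)[0]
        inputs = processor(audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt")

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        inputs = inputs.to(device)

        # Assumed (not part of this diff): run the CTC model to get per-frame logits.
        with torch.no_grad():
            outputs = model(**inputs).logits

        # Greedy CTC decoding: pick the most likely token per frame, then collapse.
        ids = torch.argmax(outputs, dim=-1)[0]
        transcription = processor.decode(ids)

        return transcription

Called with the path to an audio file, e.g. transcribe("example.wav") (a hypothetical file name), it returns the decoded text; called with no argument it returns the error string.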