unijoh committed on
Commit cfbbfed
1 Parent(s): 322575d

Update asr.py

Files changed (1)
  1. asr.py +12 -11
asr.py CHANGED
@@ -3,16 +3,22 @@ from transformers import Wav2Vec2ForCTC, AutoProcessor
 import torch
 
 ASR_SAMPLING_RATE = 16_000
-MODEL_ID = "facebook/mms-1b-all"
 
+MODEL_ID = "facebook/mms-1b-all"
 processor = AutoProcessor.from_pretrained(MODEL_ID)
 model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
 
-def transcribe(audio):
-    if audio is None:
+def transcribe(audio_source=None, microphone=None, file_upload=None):
+    audio_fp = file_upload if file_upload else microphone
+    if audio_fp is None:
         return "ERROR: You have to either use the microphone or upload an audio file"
 
-    audio_samples = librosa.load(audio, sr=ASR_SAMPLING_RATE, mono=True)[0]
+    audio_samples = librosa.load(audio_fp, sr=ASR_SAMPLING_RATE, mono=True)[0]
+
+    # Set Faroese language
+    processor.tokenizer.set_target_lang("fao")
+    model.load_adapter("fao")
+
     inputs = processor(audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt")
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -21,12 +27,7 @@ def transcribe(audio):
 
     with torch.no_grad():
         outputs = model(**inputs).logits
-
-    # Setting the target language to Faroese (ISO 639-3: fao)
-    processor.tokenizer.set_target_lang("fao")
-    model.load_adapter("fao")
-
-    ids = torch.argmax(outputs, dim=-1)[0]
-    transcription = processor.decode(ids)
+    ids = torch.argmax(outputs, dim=-1)[0]
+    transcription = processor.decode(ids)
 
     return transcription
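
The commit moves processor.tokenizer.set_target_lang("fao") and model.load_adapter("fao") ahead of feature extraction, so the Faroese adapter weights and vocabulary are already active when the logits are computed, and it splits the single audio argument into separate microphone and file-upload inputs, preferring the uploaded file when both are present.

A signature like transcribe(audio_source, microphone, file_upload) matches the usual Gradio Interface wiring on Spaces, where fn receives one value per input component, in order, and inactive sources arrive as None. The app side is not part of this diff, so everything below (component choices, labels, the asr import path) is an assumed sketch, written against Gradio 4.x:

    # Hypothetical wiring for the new transcribe() signature (not in this commit).
    # Gradio calls fn with one value per input component, in order; inactive
    # sources arrive as None, which is why transcribe() falls back from
    # file_upload to microphone.
    import gradio as gr

    from asr import transcribe  # assumed module path

    demo = gr.Interface(
        fn=transcribe,
        inputs=[
            gr.Radio(["Microphone", "Upload file"], label="Audio source"),
            # type="filepath" hands transcribe() a path string, which is what
            # librosa.load() expects.
            gr.Audio(sources=["microphone"], type="filepath", label="Microphone"),
            gr.Audio(sources=["upload"], type="filepath", label="Upload"),
        ],
        outputs="text",
        title="Faroese ASR (MMS, fao adapter)",
    )

    demo.launch()

For a quick check without the UI, transcribe(file_upload="sample.wav") exercises the same path with a local audio file; librosa resamples the input to the 16 kHz rate the model expects.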