# Web file-viewer header captured with this file (not part of the program):
#   uploader: Futuresony
#   commit:   0b5f65d (verified) - "Update asr.py"
#   size:     2.35 kB
import librosa
import torch
import numpy as np
import langid
from transformers import Wav2Vec2ForCTC, AutoProcessor
# Presumably the audio sample rate in Hz expected by the model -- confirm
# against the audio preprocessing code that consumes it.
ASR_SAMPLING_RATE = 16_000
MODEL_ID = "facebook/mms-1b-all"  # Or your model ID

# Load the MMS processor and model once at import time so that every
# transcription call reuses the same objects instead of reloading them.
try:
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
    model.eval()  # inference only
except Exception as e:
    print(f"Error loading initial model: {e}")
    # `exit()` is injected by the `site` module for interactive use and is
    # not guaranteed to exist; raising SystemExit directly always works and
    # keeps the same exit status (1).
    raise SystemExit(1) from e
def detect_language(text):
    """Classify *text* with langid and return "en" or "sw".

    Any language other than English ("en") or Swahili ("sw") reported by
    ``langid.classify`` is mapped to "en".
    """
    predicted_code, _score = langid.classify(text)
    if predicted_code in ("en", "sw"):
        return predicted_code
    return "en"
# Adapter used for the first (language-agnostic) pass. Activated explicitly
# on every call so the first pass never depends on which adapter a previous
# call happened to leave loaded.
_FIRST_PASS_LANG_CODE = "eng"


def _load_audio_samples(audio_data):
    """Turn caller-supplied audio into mono float32 samples at ASR_SAMPLING_RATE.

    NOTE(review): the original file replaced this step with a placeholder
    comment, leaving `inputs` undefined. This reconstruction assumes the
    caller passes either a ``(sample_rate, samples)`` tuple or a path string
    -- confirm against the actual caller.

    Raises:
        TypeError: if *audio_data* is neither a tuple nor a string.
    """
    if isinstance(audio_data, tuple):
        sample_rate, samples = audio_data
        samples = np.asarray(samples)
        if np.issubdtype(samples.dtype, np.integer):
            # Assumes signed integer PCM; scale into [-1.0, 1.0).
            full_scale = float(np.iinfo(samples.dtype).max) + 1.0
            samples = samples.astype(np.float32) / full_scale
        else:
            samples = samples.astype(np.float32)
        if samples.ndim > 1:
            # Assumes (n_samples, n_channels) layout -- TODO confirm.
            samples = samples.mean(axis=1)
        if sample_rate != ASR_SAMPLING_RATE:
            samples = librosa.resample(
                samples, orig_sr=sample_rate, target_sr=ASR_SAMPLING_RATE
            )
        return samples
    if isinstance(audio_data, str):
        samples, _ = librosa.load(audio_data, sr=ASR_SAMPLING_RATE, mono=True)
        return samples
    raise TypeError(f"Invalid Audio Input Instance: {type(audio_data)}")


def _activate_language(lang_code):
    """Point both the tokenizer and the model adapter at *lang_code*."""
    processor.tokenizer.set_target_lang(lang_code)
    # force_load=False skips the reload when this adapter is already active.
    model.load_adapter(lang_code, force_load=False)


def _greedy_transcribe(inputs):
    """Run the model on *inputs* and decode the arg-max token ids to text."""
    with torch.no_grad():
        logits = model(**inputs).logits
    token_ids = torch.argmax(logits, dim=-1)[0]
    return processor.decode(token_ids)


def transcribe_auto(audio_data=None) -> str:
    """Transcribe audio, choosing the English or Swahili adapter automatically.

    The audio is first transcribed with the English adapter, the language of
    that text is detected with ``detect_language``, and, if the result is not
    English, the audio is transcribed again with the Swahili adapter.

    Args:
        audio_data: the audio to transcribe; see ``_load_audio_samples`` for
            the accepted forms. Falsy values are treated as empty input.

    Returns:
        A string with the detected language and the transcription, or a
        ``<<ERROR: ...>>`` string describing what went wrong. Exceptions are
        reported in the returned string rather than raised.
    """
    if not audio_data:
        return "<<ERROR: Empty Audio Input>>"

    try:
        samples = _load_audio_samples(audio_data)
        inputs = processor(
            samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt"
        )

        # Step 1: first pass with a known adapter.
        _activate_language(_FIRST_PASS_LANG_CODE)
        raw_transcription = _greedy_transcribe(inputs)

        # Step 2: detect the language from the first-pass text.
        detected_lang = detect_language(raw_transcription)
        lang_code = "eng" if detected_lang == "en" else "swh"

        if lang_code == _FIRST_PASS_LANG_CODE:
            # Same adapter and same inputs: a second pass would reproduce
            # the first-pass text, so reuse it.
            final_transcription = raw_transcription
        else:
            # Step 3: switch to the adapter for the detected language.
            try:
                _activate_language(lang_code)
            except Exception as adapter_error:
                print(f"Error loading adapter for {detected_lang}: {adapter_error}")
                return f"<<ERROR: Could not load adapter for {detected_lang}>>"

            # Step 4: transcribe again with the correct adapter.
            final_transcription = _greedy_transcribe(inputs)

        return f"Detected Language: {detected_lang.upper()}\n\nTranscription:\n{final_transcription}"
    except Exception as overall_error:
        # Boundary handler: callers receive an error string, never an exception.
        print(f"An error occurred during transcription: {overall_error}")
        return f"<<ERROR: {overall_error}>>"