# NOTE: removed non-code page-scrape residue (file-size line, commit hash
# "ea905bc", and a copied line-number gutter 1-58) that was not valid Python.

import gradio as gr
import soundfile as sf
import torch
from transformers import Wav2Vec2ForCTC, AutoProcessor

# Assuming 'transcribe' was defined in a previous cell.
# If not, define it here or import it from the correct module.

# Create a placeholder for ASR_LANGUAGES if it's not defined elsewhere.
# Maps ISO 639-3 language codes to the display names shown in the dropdown;
# the dropdown choices are rendered as "code (Name)", e.g. "eng (English)".
ASR_LANGUAGES = {"eng": "English", "swh": "Swahili"}  # Replace with your actual languages

# ✅ Define or Re-define the `transcribe` function within this cell
MODEL_ID = "facebook/mms-1b-all" # Make sure this is the same model ID used for training
# Load the processor (tokenizer + feature extractor) and the CTC model once at
# module import time; `transcribe` reuses these module-level globals per call.
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)

def transcribe(audio_path, language):
    """Transcribe an audio file with the MMS CTC model in the selected language.

    Args:
        audio_path: Path to the audio file to transcribe (as provided by
            ``gr.Audio(type="filepath")``).
        language: Dropdown selection formatted as ``"code (Name)"``, e.g.
            ``"eng (English)"``; only the leading language code is used.

    Returns:
        The decoded transcription string.
    """
    if language:
        target_lang = language.split(" ")[0]  # "eng (English)" -> "eng"
        processor.tokenizer.set_target_lang(target_lang)
        # Always load the adapter for the selected language. The previous
        # version skipped this for "eng", which left an earlier non-English
        # adapter active when the user switched back to English, producing
        # wrong transcriptions. MMS ships an "eng" adapter as well.
        model.load_adapter(target_lang)

    # NOTE(review): MMS expects mono 16 kHz input; this passes the file's
    # native rate straight through — confirm sources are 16 kHz or resample.
    audio, samplerate = sf.read(audio_path)
    inputs = processor(audio, sampling_rate=samplerate, return_tensors="pt")

    # Inference only — disable autograd to save memory/time.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Greedy CTC decoding: argmax over the vocab, then collapse with the tokenizer.
    predicted_ids = torch.argmax(logits, dim=-1)[0]
    return processor.decode(predicted_ids)


# Gradio UI wrapper around `transcribe`: one audio input, one language picker.
mms_transcribe = gr.Interface(
    fn=transcribe,
    inputs=[
        # type="filepath" is required: `transcribe` reads the file with
        # soundfile, so it needs a path. The default (type="numpy") would
        # pass a (sample_rate, array) tuple and crash sf.read().
        gr.Audio(type="filepath"),
        gr.Dropdown(
            [f"{k} ({v})" for k, v in ASR_LANGUAGES.items()],
            label="Language",
            # Default must match the choices' "code (Name)" format; the
            # previous "eng English" was not a valid choice.
            value="eng (English)",
        ),
    ],
    outputs="text",
    title="Speech-to-Text Transcription",
    description="Transcribe audio input into text.",
    allow_flagging="never",
)

# Mount the Interface inside a Blocks container so more components can be
# added alongside it later.
with gr.Blocks() as demo:
    mms_transcribe.render()

if __name__ == "__main__":
    # queue() serializes concurrent requests so the single shared model/
    # processor globals are not used from multiple requests at once.
    demo.queue()
    demo.launch()