Futuresony committed on
Commit
ea905bc
·
verified ·
1 Parent(s): 832536b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -0
app.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ import soundfile as sf
4
+ import torch
5
+ from transformers import Wav2Vec2ForCTC, AutoProcessor
6
+
7
+ # `transcribe` is defined further down in this file rather than in an earlier notebook cell.
8
+ # If it is moved elsewhere, import it from the correct module.
9
+
10
# Languages offered in the UI, keyed by language code with the display name
# as the value. Placeholder set -- replace with your actual languages.
ASR_LANGUAGES = {
    "eng": "English",
    "swh": "Swahili",
}

# Checkpoint shared by the processor and the CTC model below; keep it in sync
# with the model ID used for training.
MODEL_ID = "facebook/mms-1b-all"

# Both objects are created once at import time and reused by `transcribe`.
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
17
+
18
def transcribe(audio_path, language):
    """Transcribe an audio file with the module-level MMS model.

    Args:
        audio_path: Path to an audio file readable by ``soundfile``. ``None``
            (no recording supplied) yields an explanatory message instead of
            an exception.
        language: Dropdown label whose first space-separated token is the
            language code, e.g. ``"eng (English)"``. A falsy value leaves the
            currently selected language untouched.

    Returns:
        The decoded transcription, a message when no audio was supplied, or
        an empty string when the file contains no samples.
    """
    if audio_path is None:
        return "ERROR: Please record or upload an audio file first."

    if language:
        target_lang = language.split(" ")[0]  # Leading token is the code.
        processor.tokenizer.set_target_lang(target_lang)
        # Always load the adapter that matches the tokenizer, English
        # included: otherwise switching back to "eng" after another language
        # would keep the previous language's adapter weights active.
        model.load_adapter(target_lang)

    audio, samplerate = sf.read(audio_path, dtype="float32")

    # Multi-channel files come back as (frames, channels); the model expects
    # a single 1-D waveform, so average the channels.
    if audio.ndim > 1:
        audio = audio.mean(axis=1)

    if len(audio) == 0:
        return ""

    # The feature extractor rejects audio whose rate differs from the rate it
    # was configured with, so resample before handing the waveform over.
    target_rate = processor.feature_extractor.sampling_rate
    if samplerate != target_rate:
        new_length = max(1, round(len(audio) * target_rate / samplerate))
        waveform = torch.from_numpy(audio).reshape(1, 1, -1)
        # Linear interpolation keeps this free of extra dependencies; swap in
        # a band-limited resampler if one is available in the environment.
        audio = (
            torch.nn.functional.interpolate(
                waveform, size=new_length, mode="linear", align_corners=False
            )
            .reshape(-1)
            .numpy()
        )

    inputs = processor(audio, sampling_rate=target_rate, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs).logits

    # Greedy CTC decoding: most likely token per frame for the first (only)
    # item in the batch.
    ids = torch.argmax(outputs, dim=-1)[0]
    return processor.decode(ids)
35
+
36
+
37
# Gradio interface wiring `transcribe` to an audio input and a language picker.
mms_transcribe = gr.Interface(
    fn=transcribe,
    inputs=[
        # `transcribe` reads the recording from disk with soundfile, so ask
        # Gradio for a file path instead of its default NumPy payload.
        gr.Audio(type="filepath"),
        gr.Dropdown(
            [f"{k} ({v})" for k, v in ASR_LANGUAGES.items()],
            label="Language",
            # The default must be one of the "code (Name)" choices generated
            # above, otherwise the dropdown starts with an invalid selection.
            value="eng (English)",
        ),
    ],
    outputs="text",
    title="Speech-to-Text Transcription",
    description="Transcribe audio input into text.",
    allow_flagging="never",
)
52
+
53
# Host the transcription interface inside a Blocks container so it can be
# launched as a single app.
with gr.Blocks() as demo:
    mms_transcribe.render()

# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    demo.queue().launch()