unijoh committed
Commit
ff3a5da
1 Parent(s): 7bce9ab

Update app.py

Files changed (1)
  1. app.py +51 -62
app.py CHANGED
@@ -1,75 +1,64 @@
  import gradio as gr
- import torchaudio
- import torch
- from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
- from transformers import AutoProcessor, AutoModelForSeq2SeqLM
- from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification

- # Load the models
- asr_model = Wav2Vec2ForCTC.from_pretrained("facebook/mms-1b-all")
- asr_processor = Wav2Vec2Processor.from_pretrained("facebook/mms-1b-all")

- # Correct TTS model path
- tts_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mms-tts/models/fao")
- tts_processor = AutoProcessor.from_pretrained("facebook/mms-tts/models/fao")

- lid_model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/mms-lid-1024")
- lid_processor = Wav2Vec2Processor.from_pretrained("facebook/mms-lid-1024")

- # ASR Function
- def asr_transcribe(audio):
-     inputs = asr_processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
-     with torch.no_grad():
-         logits = asr_model(**inputs).logits
-     predicted_ids = torch.argmax(logits, dim=-1)
-     transcription = asr_processor.batch_decode(predicted_ids)
-     return transcription[0]

- # TTS Function
- def tts_synthesize(text):
-     inputs = tts_processor(text, return_tensors="pt", padding=True)
-     with torch.no_grad():
-         generated_ids = tts_model.generate(**inputs)
-     audio = tts_processor.batch_decode(generated_ids, skip_special_tokens=True)
-     return audio[0]

- # Language ID Function
- def identify_language(audio):
-     inputs = lid_processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
-     with torch.no_grad():
-         logits = lid_model(**inputs).logits
-     predicted_ids = torch.argmax(logits, dim=-1)
-     language = lid_processor.batch_decode(predicted_ids)
-     return language[0]

- # Clear Functions
- def clear_audio_input():
-     return None

- def clear_text_input():
-     return ""

- # Define the Gradio interfaces
- with gr.Blocks() as demo:
-     with gr.Tab("ASR"):
-         gr.Markdown("## Automatic Speech Recognition (ASR)")
-         audio_input = gr.Audio(source="microphone", type="numpy")
-         text_output = gr.Textbox(label="Transcription")
-         gr.Button("Clear", fn=clear_audio_input, inputs=[], outputs=audio_input)
-         gr.Button("Submit", fn=asr_transcribe, inputs=audio_input, outputs=text_output)
-
-     with gr.Tab("TTS"):
-         gr.Markdown("## Text-to-Speech (TTS)")
-         text_input = gr.Textbox(label="Text")
-         audio_output = gr.Audio(label="Audio Output")
-         gr.Button("Clear", fn=clear_text_input, inputs=[], outputs=text_input)
-         gr.Button("Submit", fn=tts_synthesize, inputs=text_input, outputs=audio_output)

-     with gr.Tab("Language ID"):
-         gr.Markdown("## Language Identification (LangID)")
-         audio_input = gr.Audio(source="microphone", type="numpy")
-         language_output = gr.Textbox(label="Identified Language")
-         gr.Button("Clear", fn=clear_audio_input, inputs=[], outputs=audio_input)
-         gr.Button("Submit", fn=identify_language, inputs=audio_input, outputs=language_output)

  demo.launch()
 
  import gradio as gr
+ import librosa
+ from asr import transcribe
+ from tts import synthesize

+ def identify(microphone, file_upload):
+     LID_SAMPLING_RATE = 16_000

+     if (microphone is not None) and (file_upload is not None):
+         return "WARNING: Using microphone input. Uploaded file will be ignored."

+     if (microphone is None) and (file_upload is None):
+         return "ERROR: Provide an audio file or use the microphone."

+     audio_fp = microphone if microphone is not None else file_upload
+     inputs = librosa.load(audio_fp, sr=LID_SAMPLING_RATE, mono=True)[0]

+     return {"Faroese": 1.0}

+ demo = gr.Blocks()

+ mms_transcribe = gr.Interface(
+     fn=transcribe,
+     inputs=[
+         gr.Audio(source="microphone", type="filepath"),
+         gr.Audio(source="upload", type="filepath"),
+     ],
+     outputs="text",
+     title="Speech-to-text",
+     description="Transcribe audio!",
+     allow_flagging="never",
+ )

+ mms_synthesize = gr.Interface(
+     fn=synthesize,
+     inputs=[
+         gr.Text(label="Input text"),
+         gr.Slider(minimum=0.1, maximum=4.0, value=1.0, step=0.1, label="Speed"),
+     ],
+     outputs=gr.Audio(label="Generated Audio", type="numpy"),
+     title="Text-to-speech",
+     description="Generate audio!",
+     allow_flagging="never",
+ )

+ mms_identify = gr.Interface(
+     fn=identify,
+     inputs=[
+         gr.Audio(source="microphone", type="filepath"),
+         gr.Audio(source="upload", type="filepath"),
+     ],
+     outputs=gr.Label(num_top_classes=1),
+     title="Language Identification",
+     description="Identify the language of audio!",
+     allow_flagging="never",
+ )

+ with demo:
+     gr.TabbedInterface(
+         [mms_synthesize, mms_transcribe, mms_identify],
+         ["Text-to-speech", "Speech-to-text", "Language Identification"],
+     )

  demo.launch()
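
The rewritten `app.py` pulls `transcribe` from an `asr` module and `synthesize` from a `tts` module, neither of which is part of this commit. For reference, here is a minimal sketch of what `asr.py` might contain, assuming the Space keeps the `facebook/mms-1b-all` checkpoint from the removed code and switches it to the Faroese adapter (`fao`). The two-argument signature mirrors the microphone/upload inputs wired into `mms_transcribe`; everything below is an assumption, not the Space's actual code.

```python
# Hypothetical asr.py -- not part of this commit; a sketch only.
import librosa
import torch
from transformers import AutoProcessor, Wav2Vec2ForCTC

ASR_SAMPLING_RATE = 16_000
MODEL_ID = "facebook/mms-1b-all"  # checkpoint the removed app.py already used

processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)

# Swap in the Faroese adapter weights and vocabulary
# ("fao" is the MMS language code for Faroese).
processor.tokenizer.set_target_lang("fao")
model.load_adapter("fao")


def transcribe(microphone, file_upload):
    # Same input-priority convention as identify() in app.py.
    audio_fp = microphone if microphone is not None else file_upload
    if audio_fp is None:
        return "ERROR: Provide an audio file or use the microphone."

    speech = librosa.load(audio_fp, sr=ASR_SAMPLING_RATE, mono=True)[0]
    inputs = processor(speech, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    ids = torch.argmax(logits, dim=-1)[0]
    return processor.decode(ids)
```

Loading the adapter once at import time keeps per-request latency down; `set_target_lang` switches the tokenizer vocabulary to match the adapter.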
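On the TTS side, the removed code pointed at `facebook/mms-tts/models/fao`, which is not a valid Hub model id, and loaded it with `AutoModelForSeq2SeqLM`, although the released MMS-TTS checkpoints are VITS models. Below is a hedged sketch of a `tts.py` matching the `synthesize(text, speed)` interface implied by the Speed slider. Both the per-language checkpoint name `facebook/mms-tts-fao` and the mapping of the slider onto VITS's `speaking_rate` are assumptions.

```python
# Hypothetical tts.py -- not part of this commit; a sketch only.
import torch
from transformers import AutoTokenizer, VitsModel

MODEL_ID = "facebook/mms-tts-fao"  # assumed per-language MMS-TTS checkpoint

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = VitsModel.from_pretrained(MODEL_ID)


def synthesize(text, speed=1.0):
    # Wiring the Gradio "Speed" slider to the VITS speaking_rate knob is an
    # assumption about how the Space interprets the control.
    model.speaking_rate = speed
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        waveform = model(**inputs).waveform[0]
    # gr.Audio(type="numpy") expects a (sample_rate, samples) tuple.
    return model.config.sampling_rate, waveform.numpy()
```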
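Note that the new `identify()` decodes the audio with librosa but never inspects it: the commit returns a hard-coded `{"Faroese": 1.0}`, so the Language Identification tab is effectively a stub for this single-language Space. The removed implementation was also incorrect, since `batch_decode` is a CTC text decode and does not apply to classification logits; a working version reads the predicted class from `id2label`. If real identification were wanted, a sketch along these lines, reusing the `facebook/mms-lid-1024` checkpoint from the removed code, would do it (hypothetical, not part of the commit):

```python
# Hypothetical real implementation of identify(); the commit ships a stub.
import librosa
import torch
from transformers import AutoFeatureExtractor, Wav2Vec2ForSequenceClassification

LID_SAMPLING_RATE = 16_000
LID_MODEL_ID = "facebook/mms-lid-1024"  # checkpoint the removed app.py pointed at

extractor = AutoFeatureExtractor.from_pretrained(LID_MODEL_ID)
lid_model = Wav2Vec2ForSequenceClassification.from_pretrained(LID_MODEL_ID)


def identify(microphone, file_upload):
    audio_fp = microphone if microphone is not None else file_upload
    if audio_fp is None:
        return "ERROR: Provide an audio file or use the microphone."

    speech = librosa.load(audio_fp, sr=LID_SAMPLING_RATE, mono=True)[0]
    inputs = extractor(speech, sampling_rate=LID_SAMPLING_RATE, return_tensors="pt")
    with torch.no_grad():
        logits = lid_model(**inputs).logits
    probs = torch.softmax(logits, dim=-1)[0]
    top = int(torch.argmax(probs))
    # gr.Label(num_top_classes=1) accepts a {label: confidence} mapping.
    return {lid_model.config.id2label[top]: float(probs[top])}
```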