import spaces
import torch
import gradio as gr
import whisperx
from transformers.pipelines.audio_utils import ffmpeg_read
import tempfile
import gc
import os

# NOTE(review): ffmpeg_read, tempfile and FILE_LIMIT_MB appear unused in the
# visible code; kept in case other tooling relies on them.

# Constants
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 4  # reduce if low on GPU mem
COMPUTE_TYPE = "float32"  # change to "int8" if low on GPU mem
FILE_LIMIT_MB = 1000


def _release_memory():
    """Run the garbage collector and release cached CUDA memory.

    Called between pipeline stages so one stage's model is freed before the
    next stage's model is loaded, and once more when the request finishes.
    """
    gc.collect()
    torch.cuda.empty_cache()


def _format_transcript(segments):
    """Render speaker-labelled segments as one ``"<speaker>: <text>"`` line each.

    Segments without a ``"speaker"`` key are labelled ``"Unknown Speaker"``.
    Segment text is stripped because Whisper output may carry a leading space,
    which would otherwise produce ``"SPEAKER_00:  text"``.
    """
    return "".join(
        f"{segment.get('speaker', 'Unknown Speaker')}: {segment['text'].strip()}\n"
        for segment in segments
    )


@spaces.GPU(duration=200)
def transcribe_audio(inputs, task):
    """Transcribe (or translate) an audio file and label each segment by speaker.

    Pipeline: WhisperX ``large-v3`` transcription -> alignment (skipped for
    ``"translate"``) -> speaker diarization -> speaker assignment.

    Args:
        inputs: Path to the audio file. The ``gr.Audio`` component below uses
            ``type="filepath"``, so uploads and microphone recordings both
            arrive as a path. ``None`` means nothing was submitted.
        task: ``"transcribe"`` or ``"translate"``; a falsy value falls back
            to ``"transcribe"``.

    Returns:
        The transcript as one ``"<speaker>: <text>"`` line per segment.

    Raises:
        gr.Error: if no audio was submitted, ``HF_TOKEN`` is not set, or any
            pipeline stage fails (the original exception is chained).
    """
    if inputs is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")

    # Check the token before loading any model. Without this, a missing token
    # only surfaced after the full transcription, as a bare KeyError message.
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        raise gr.Error("The HF_TOKEN environment variable is not set; it is required for speaker diarization.")

    task = task or "transcribe"

    try:
        # The same loader handles both uploaded files and microphone
        # recordings, since both are file paths.
        audio = whisperx.load_audio(inputs)

        # 1. Transcribe (or translate) with the base Whisper model.
        model = whisperx.load_model("large-v3", device=DEVICE, compute_type=COMPUTE_TYPE)
        result = model.transcribe(audio, batch_size=BATCH_SIZE, task=task)
        del model
        _release_memory()

        # 2. Align whisper output. A translation is not in the spoken
        # language, so force-aligning it with the source-language alignment
        # model (result["language"]) is meaningless. For "translate", keep
        # the segment-level timestamps from step 1 instead.
        if task != "translate":
            model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=DEVICE)
            result = whisperx.align(result["segments"], model_a, metadata, audio, DEVICE, return_char_alignments=False)
            del model_a
            _release_memory()

        # 3. Diarize audio
        diarize_model = whisperx.DiarizationPipeline(use_auth_token=hf_token, device=DEVICE)
        diarize_segments = diarize_model(audio)

        # 4. Assign speaker labels
        result = whisperx.assign_word_speakers(diarize_segments, result)

        return _format_transcript(result["segments"])
    except gr.Error:
        # Already user-facing; don't wrap it a second time.
        raise
    except Exception as e:
        raise gr.Error(f"Error processing audio: {e}") from e
    finally:
        _release_memory()


# Create Gradio interface
demo = gr.Blocks(theme=gr.themes.Ocean())
with demo:
    gr.Markdown("# WhisperX: Advanced Speech Recognition with Speaker Diarization")

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label="Audio Input (Microphone or File Upload)",
            )
            task = gr.Radio(
                ["transcribe", "translate"],
                label="Task",
                value="transcribe",
            )
            submit_button = gr.Button("Process Audio")

        with gr.Column():
            output_text = gr.Textbox(
                label="Transcription with Speaker Diarization",
                lines=10,
                placeholder="Transcribed text will appear here...",
            )

    gr.Markdown("""
### Features:
- High-accuracy transcription using WhisperX
- Automatic speaker diarization
- Support for both microphone recording and file upload
- GPU-accelerated processing

### Note:
Processing may take a few moments depending on the audio length and system resources.
""")

    submit_button.click(
        fn=transcribe_audio,
        inputs=[audio_input, task],
        outputs=output_text,
    )

if __name__ == "__main__":
    demo.queue().launch(ssr_mode=False)