import torch
import spaces
import whisper
import subprocess
import numpy as np
import gradio as gr
import soundfile as sf
import torchaudio as ta
from model_utils import get_processor, get_model, get_whisper_model_small, get_device
from config import SAMPLING_RATE, CHUNK_LENGTH_S


def _gradio_audio_to_waveform(audio):
    """Convert a Gradio numpy-audio tuple into a mono float32 waveform at SAMPLING_RATE.

    Resampling is done in memory with torchaudio, so no temporary files are
    written (fixed temp filenames would collide between concurrent requests).

    Args:
        audio: ``(sample_rate, data)`` as produced by ``gr.Audio()`` with the
            default numpy output; ``data`` is ``(samples,)`` for mono or
            ``(samples, channels)`` for multi-channel audio.

    Returns:
        A 1-D ``torch.float32`` tensor at ``SAMPLING_RATE``. Integer PCM input
        is scaled into ``[-1, 1]``; float input is passed through unscaled.
    """
    sample_rate, data = audio
    data = np.asarray(data)
    if np.issubdtype(data.dtype, np.integer):
        # Integer PCM (e.g. int16) -> float in [-1, 1], the range Whisper and
        # the HF feature extractor expect (and what ta.load() would return).
        scale = float(np.iinfo(data.dtype).max) + 1.0
        data = data.astype(np.float32) / scale
    waveform = torch.from_numpy(np.ascontiguousarray(data, dtype=np.float32))

    if waveform.dim() == 2:
        # Gradio lays multi-channel audio out as (samples, channels):
        # average across the channel axis, not the time axis.
        waveform = waveform.mean(dim=1)

    sample_rate = int(sample_rate)
    # Skip resampling empty input so callers can report it with a clear error.
    if sample_rate != SAMPLING_RATE and waveform.numel() > 0:
        waveform = ta.functional.resample(waveform, sample_rate, SAMPLING_RATE)
    return waveform


@spaces.GPU
def load_and_resample_audio(file):
    """Load an audio file as a mono waveform at SAMPLING_RATE.

    torchaudio is tried first; soundfile is used as a fallback if it fails.

    Args:
        file: path (or file-like object) accepted by ``ta.load`` / ``sf.read``.

    Returns:
        ``(waveform, SAMPLING_RATE)`` where ``waveform`` is a float tensor of
        shape ``(1, samples)``.

    Raises:
        ValueError: if neither backend can read ``file`` or resampling fails.
    """
    try:
        # First attempt: torchaudio returns (channels, frames).
        waveform, sample_rate = ta.load(file)
    except Exception as e:
        print(f"torchaudio.load() failed: {e}")
        try:
            # Second attempt: soundfile returns (frames,) or (frames, channels);
            # transpose to torchaudio's (channels, frames) layout.
            data, sample_rate = sf.read(file)
            waveform = torch.from_numpy(data.T).float()
            if waveform.dim() == 1:
                waveform = waveform.unsqueeze(0)
        except Exception as e2:
            print(f"soundfile.read() failed: {e2}")
            raise ValueError(f"Failed to load audio file: {file}") from e2

    print(f"Original audio shape: {waveform.shape}, Sample rate: {sample_rate}")

    # Downmix to mono first so only a single channel has to be resampled
    # (both operations are linear, so the order does not change the result).
    if waveform.dim() > 1 and waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    if sample_rate != SAMPLING_RATE:
        try:
            waveform = ta.functional.resample(waveform, sample_rate, SAMPLING_RATE)
        except Exception as e:
            print(f"Resampling failed: {e}")
            raise ValueError(
                f"Failed to resample audio from {sample_rate} to {SAMPLING_RATE}"
            ) from e

    print(f"Processed audio shape: {waveform.shape}, New sample rate: {SAMPLING_RATE}")
    return waveform, SAMPLING_RATE


@spaces.GPU
def detect_language(audio):
    """Detect the spoken language of a Gradio audio clip with Whisper.

    The waveform is padded or trimmed by ``whisper.pad_or_trim`` to Whisper's
    fixed input window before the log-mel spectrogram is computed, so only
    the start of long clips is considered.

    Args:
        audio: ``(sample_rate, data)`` tuple from ``gr.Audio()``.

    Returns:
        The key of Whisper's language-probability dict with the highest score.
    """
    whisper_model = get_whisper_model_small()

    audio_tensor = _gradio_audio_to_waveform(audio)

    # Use Whisper's preprocessing.
    audio_tensor = whisper.pad_or_trim(audio_tensor)
    print(f"Audio length after pad/trim: {audio_tensor.shape[-1] / SAMPLING_RATE} seconds")
    mel = whisper.log_mel_spectrogram(audio_tensor).to(whisper_model.device)

    _, probs = whisper_model.detect_language(mel)
    detected_lang = max(probs, key=probs.get)

    print(f"Audio shape: {audio_tensor.shape}")
    print(f"Mel spectrogram shape: {mel.shape}")
    print(f"Detected language: {detected_lang}")
    print("Language probabilities:", probs)

    return detected_lang


@spaces.GPU
def process_long_audio(audio, task="transcribe", language=None):
    """Transcribe or translate arbitrarily long audio in fixed-length chunks.

    The mono waveform is split into ``CHUNK_LENGTH_S``-second pieces; each
    piece is run through the HF processor/model and the decoded texts are
    joined with single spaces.

    Args:
        audio: ``(sample_rate, data)`` tuple from ``gr.Audio()``.
        task: ``"translate"`` forces the translation prompt; anything else
            transcribes.
        language: optional language passed to
            ``processor.get_decoder_prompt_ids``. Used for translation and,
            when given, also forced for transcription.

    Returns:
        The concatenated text of all chunks.

    Raises:
        ValueError: if the input contains no samples.
    """
    waveform = _gradio_audio_to_waveform(audio)
    print(f"Waveform shape after processing: {waveform.shape}")

    if waveform.numel() == 0:
        raise ValueError("Waveform is empty. Please check the input audio file.")

    chunk_length = int(CHUNK_LENGTH_S * SAMPLING_RATE)
    # waveform is 1-D, so split along its only dimension; the last chunk may be shorter.
    chunks = torch.split(waveform, chunk_length)

    processor = get_processor()
    model = get_model()
    device = get_device()

    # The decoder prompt is the same for every chunk, so build it once.
    if task == "translate":
        forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task="translate")
    elif language is not None:
        forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task="transcribe")
    else:
        forced_decoder_ids = None
    # Only pass the kwarg when set, so the model's default generation config
    # is used untouched otherwise.
    generate_kwargs = {} if forced_decoder_ids is None else {"forced_decoder_ids": forced_decoder_ids}

    results = []
    for chunk in chunks:
        input_features = processor(
            chunk.numpy(), sampling_rate=SAMPLING_RATE, return_tensors="pt"
        ).input_features.to(device)

        with torch.no_grad():
            generated_ids = model.generate(input_features, **generate_kwargs)

        results.extend(processor.batch_decode(generated_ids, skip_special_tokens=True))

        # Release cached GPU memory between chunks.
        torch.cuda.empty_cache()

    return " ".join(results)


@spaces.GPU
def process_audio(audio):
    """Gradio callback: detect language, transcribe, and translate one clip.

    Returns:
        ``(detected_language, transcription, translation)``, or a placeholder
        message and two empty strings when no audio was provided.
    """
    if audio is None:
        return "No file uploaded", "", ""

    detected_lang = detect_language(audio)
    transcription = process_long_audio(audio, task="transcribe")
    translation = process_long_audio(audio, task="translate", language=detected_lang)

    return detected_lang, transcription, translation


# Gradio interface
iface = gr.Interface(
    fn=process_audio,
    inputs=gr.Audio(),
    outputs=[
        gr.Textbox(label="Detected Language"),
        gr.Textbox(label="Transcription", lines=5),
        gr.Textbox(label="Translation", lines=5),
    ],
    title="Audio Transcription and Translation",
    description="Upload an audio file to detect its language, transcribe, and translate it.",
    allow_flagging="never",
    css=".output-textbox { font-family: 'Noto Sans Devanagari', sans-serif; font-size: 18px; }",
)