import torch
import spaces
import whisper
import subprocess
import numpy as np
import gradio as gr
import soundfile as sf
import torchaudio as ta
import torchaudio.functional as F
from model_utils import get_processor, get_model, get_whisper_model_small, get_device
from config import SAMPLING_RATE, CHUNK_LENGTH_S
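
# Pipeline overview: Gradio hands audio in as a (sample_rate, numpy array) tuple.
# We resample to SAMPLING_RATE via ffmpeg when needed, detect the language with a
# small Whisper model, then transcribe/translate in CHUNK_LENGTH_S-second chunks.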

def resample_with_ffmpeg(input_file, output_file, target_sr=16000):
    # -y overwrites the temp file left over from a previous request
    command = [
        'ffmpeg', '-y', '-i', input_file, '-ar', str(target_sr), output_file
    ]
    subprocess.run(command, check=True, capture_output=True)

def load_and_resample_audio(file):
    try:
        # First attempt: use torchaudio
        waveform, sample_rate = ta.load(file)
    except Exception as e:
        print(f"torchaudio.load() failed: {e}")
        try:
            # Second attempt: use soundfile
            waveform, sample_rate = sf.read(file)
            waveform = torch.from_numpy(waveform.T).float()
            if waveform.dim() == 1:
                waveform = waveform.unsqueeze(0)
        except Exception as e:
            print(f"soundfile.read() failed: {e}")
            raise ValueError(f"Failed to load audio file: {file}")

    print(f"Original audio shape: {waveform.shape}, Sample rate: {sample_rate}")

    if sample_rate != SAMPLING_RATE:
        try:
            waveform = F.resample(waveform, sample_rate, SAMPLING_RATE)
        except Exception as e:
            print(f"Resampling failed: {e}")
            raise ValueError(f"Failed to resample audio from {sample_rate} to {SAMPLING_RATE}")

    # Ensure the audio is in the correct shape (mono)
    if waveform.dim() > 1 and waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    print(f"Processed audio shape: {waveform.shape}, New sample rate: {SAMPLING_RATE}")
    return waveform, SAMPLING_RATE
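
# NOTE: load_and_resample_audio is an in-process fallback loader (torchaudio,
# then soundfile, then torchaudio.functional.resample); the functions below
# currently shell out to ffmpeg instead.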

def detect_language(audio):
    whisper_model = get_whisper_model_small()

    sample_rate, data = audio
    # Gradio delivers int16 PCM by default; scale to [-1, 1] floats for Whisper
    audio_tensor = torch.from_numpy(data.astype(np.float32))
    if data.dtype == np.int16:
        audio_tensor /= 32768.0
    # Move to (channels, samples) so torchaudio can work with it
    audio_tensor = audio_tensor.t() if audio_tensor.dim() == 2 else audio_tensor.unsqueeze(0)

    # Resample if necessary using ffmpeg
    if sample_rate != SAMPLING_RATE:
        # Save the input audio to a temporary file
        ta.save("input_audio.wav", audio_tensor, sample_rate)
        resample_with_ffmpeg("input_audio.wav", "resampled_audio.wav", target_sr=SAMPLING_RATE)
        audio_tensor, _ = ta.load("resampled_audio.wav")

    # Ensure the audio is in the correct shape (mono)
    if audio_tensor.dim() == 2:
        audio_tensor = audio_tensor.mean(dim=0)

    # Use Whisper's preprocessing
    audio_tensor = whisper.pad_or_trim(audio_tensor)
    print(f"Audio length after pad/trim: {audio_tensor.shape[-1] / SAMPLING_RATE} seconds")
    mel = whisper.log_mel_spectrogram(audio_tensor).to(whisper_model.device)

    # Detect language
    _, probs = whisper_model.detect_language(mel)
    detected_lang = max(probs, key=probs.get)

    print(f"Audio shape: {audio_tensor.shape}")
    print(f"Mel spectrogram shape: {mel.shape}")
    print(f"Detected language: {detected_lang}")
    print("Language probabilities:", probs)

    return detected_lang
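
# Whisper's encoder consumes fixed 30-second windows, so longer recordings are
# split into CHUNK_LENGTH_S-second chunks and decoded chunk by chunk below.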

def process_long_audio(audio, task="transcribe", language=None):
    sample_rate, data = audio
    # Scale int16 PCM from Gradio to [-1, 1] floats and move to (channels, samples)
    waveform = torch.from_numpy(data.astype(np.float32))
    if data.dtype == np.int16:
        waveform /= 32768.0
    waveform = waveform.t() if waveform.dim() == 2 else waveform.unsqueeze(0)

    if sample_rate != SAMPLING_RATE:
        # Save the input audio to a file for ffmpeg processing
        ta.save("input_audio_1.wav", waveform, sample_rate)
        # Resample using ffmpeg
        try:
            resample_with_ffmpeg("input_audio_1.wav", "resampled_audio_2.wav", target_sr=SAMPLING_RATE)
        except subprocess.CalledProcessError as e:
            print(f"ffmpeg failed: {e.stderr}")
            raise
        waveform, _ = ta.load("resampled_audio_2.wav")

    # Ensure the audio is in the correct shape (mono)
    if waveform.dim() == 2:
        waveform = waveform.mean(dim=0)

    print(f"Waveform shape after processing: {waveform.shape}")

    if waveform.numel() == 0:
        raise ValueError("Waveform is empty. Please check the input audio file.")

    # Split the 1D waveform into fixed-length chunks
    input_length = waveform.shape[0]
    chunk_length = int(CHUNK_LENGTH_S * SAMPLING_RATE)
    chunks = [waveform[i:i + chunk_length] for i in range(0, input_length, chunk_length)]

    processor = get_processor()
    model = get_model()
    device = get_device()

    results = []
    for chunk in chunks:
        input_features = processor(chunk.numpy(), sampling_rate=SAMPLING_RATE, return_tensors="pt").input_features.to(device)
        with torch.no_grad():
            if task == "translate":
                forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task="translate")
                generated_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
            else:
                generated_ids = model.generate(input_features)
        transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
        results.extend(transcription)

        # Clear GPU cache between chunks to limit memory growth
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return " ".join(results)

def process_audio(audio):
    if audio is None:
        return "No file uploaded", "", ""
    detected_lang = detect_language(audio)
    transcription = process_long_audio(audio, task="transcribe")
    translation = process_long_audio(audio, task="translate", language=detected_lang)
    return detected_lang, transcription, translation

# Gradio interface
iface = gr.Interface(
    fn=process_audio,
    inputs=gr.Audio(type="numpy"),  # delivers (sample_rate, numpy array)
    outputs=[
        gr.Textbox(label="Detected Language"),
        # elem_classes ties these outputs to the custom css rule below
        gr.Textbox(label="Transcription", lines=5, elem_classes=["output-textbox"]),
        gr.Textbox(label="Translation", lines=5, elem_classes=["output-textbox"])
    ],
    title="Audio Transcription and Translation",
    description="Upload an audio file to detect its language, transcribe it, and translate it.",
    allow_flagging="never",
    css=".output-textbox { font-family: 'Noto Sans Devanagari', sans-serif; font-size: 18px; }"
)

iface.launch()