# ASR_gradio / audio_processing.py
# Last commit by Kr08: 983e536 "Update audio_processing.py" (verified), 7.04 kB
import whisperx
import torch
import numpy as np
from scipy.signal import resample
from pyannote.audio import Pipeline
import os
from dotenv import load_dotenv
load_dotenv()
import logging
import time
from difflib import SequenceMatcher
import spaces
# Hugging Face access token (loaded from the environment / .env above); it is
# passed to pyannote when downloading the diarization pipeline.
hf_token = os.getenv("HF_TOKEN")

# Window length and overlap, in seconds, used by preprocess_audio().
CHUNK_LENGTH = 5
OVERLAP = 2

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables for models.
# Execution device and numeric precision are picked once at import time;
# load_models() may later downgrade both to CPU / int8.
if torch.cuda.is_available():
    device = "cuda"
    compute_type = "float16"
else:
    device = "cpu"
    compute_type = "int8"

# Lazily populated by load_models() / load_diarization_pipeline().
whisper_model = None
diarization_pipeline = None
def load_models(model_size="small"):
    """Load the WhisperX ASR model into the module-level ``whisper_model``.

    The model is first loaded with the module-level ``device`` and
    ``compute_type``. If that raises ``RuntimeError`` on a non-CPU device, a
    warning is logged, the module-level ``device``/``compute_type`` are
    switched to ``"cpu"``/``"int8"`` for the rest of the process, and loading
    is retried. When already on CPU there is nothing to fall back to, so the
    error is re-raised instead of repeating an identical attempt.

    Args:
        model_size: Model name forwarded to ``whisperx.load_model``
            (default ``"small"``).

    Returns:
        The loaded model (also stored in the global ``whisper_model``).

    Raises:
        RuntimeError: if loading fails on CPU, either directly or after the
            CPU fallback.
    """
    global whisper_model, device, compute_type
    try:
        whisper_model = whisperx.load_model(model_size, device, compute_type=compute_type)
    except RuntimeError as e:
        if device == "cpu":
            # Retrying with exactly the same arguments cannot succeed.
            raise
        logger.warning(f"Failed to load Whisper model on {device}. Falling back to CPU. Error: {str(e)}")
        device = "cpu"
        compute_type = "int8"
        whisper_model = whisperx.load_model(model_size, device, compute_type=compute_type)
    return whisper_model
def load_diarization_pipeline():
    """Best-effort initialisation of the pyannote speaker-diarization pipeline.

    On success the pipeline is stored in the module-level
    ``diarization_pipeline``; it is moved to the GPU when ``device == "cuda"``.
    Any exception during download or device placement is logged as a warning
    and leaves ``diarization_pipeline`` as ``None``, so callers can continue
    without diarization.
    """
    global diarization_pipeline
    try:
        pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=hf_token)
        if device == "cuda":
            pipeline = pipeline.to(torch.device(device))
        diarization_pipeline = pipeline
    except Exception as err:
        logger.warning(f"Diarization pipeline initialization failed: {str(err)}. Diarization will not be available.")
        diarization_pipeline = None
def preprocess_audio(audio, chunk_size=CHUNK_LENGTH * 16000, overlap=OVERLAP * 16000):
    """Split ``audio`` into fixed-length windows that overlap by ``overlap`` samples.

    A new window starts every ``chunk_size - overlap`` samples. A window that
    runs past the end of the signal is zero-padded to ``chunk_size``. Windowing
    stops as soon as a window reaches the end of the signal, so no trailing
    window consists only of samples already covered by its predecessor.

    Args:
        audio: Sample array. Assumed 1-D (TODO confirm for multi-channel input).
        chunk_size: Window length in samples (default ``CHUNK_LENGTH`` seconds
            at 16000 samples/s).
        overlap: Number of samples shared by consecutive windows (default
            ``OVERLAP`` seconds at 16000 samples/s).

    Returns:
        List of arrays, each exactly ``chunk_size`` long. Empty for empty audio.

    Raises:
        ValueError: if ``chunk_size`` is not positive or ``overlap`` is not
            smaller than ``chunk_size``, because the window would never advance.
    """
    step = chunk_size - overlap
    if chunk_size <= 0 or step <= 0:
        raise ValueError(
            f"chunk_size must be positive and larger than overlap "
            f"(got chunk_size={chunk_size}, overlap={overlap})"
        )
    total = len(audio)
    chunks = []
    for start in range(0, total, step):
        chunk = audio[start:start + chunk_size]
        if len(chunk) < chunk_size:
            chunk = np.pad(chunk, (0, chunk_size - len(chunk)))
        chunks.append(chunk)
        if start + chunk_size >= total:
            # This window already reaches the end of the signal.
            break
    return chunks
def _merge_overlapping_text(previous, current, similarity_threshold):
    """Join two strings, dropping text in ``current`` that duplicates ``previous``.

    The longest common substring of the two strings is located. If it covers
    more than ``similarity_threshold`` of ``current``, the function returns
    ``previous`` followed by whatever in ``current`` comes after that
    substring. Otherwise, and always when ``current`` is empty, it returns
    ``None``.
    """
    if not current:
        return None
    matcher = SequenceMatcher(None, previous, current)
    match = matcher.find_longest_match(0, len(previous), 0, len(current))
    if match.size / len(current) <= similarity_threshold:
        return None
    return previous + current[match.b + match.size:]


def merge_nearby_segments(segments, time_threshold=0.5, similarity_threshold=0.7):
    """Collapse consecutive segments whose text repeats the previous segment.

    A segment is merged into the previous output segment when two conditions
    hold:

    * It starts no more than ``time_threshold`` after that segment ends.
    * The longest common substring of the two texts covers more than
      ``similarity_threshold`` of the new segment's text.

    Merging appends the non-duplicated tail of the new text and extends
    ``end``. A ``translated`` field is merged with the same rule computed on
    the translations themselves. When the translations do not overlap enough,
    they are concatenated, so that no translated content is dropped.
    Segments with empty text are never merged.

    Args:
        segments: Iterable of dicts with at least ``start``, ``end`` and
            ``text``, and optionally ``translated``. Assumed to be sorted by
            ``start``.
        time_threshold: Maximum gap allowed between segments for a merge.
        similarity_threshold: Fraction of the new text that must be shared
            for a merge.

    Returns:
        A new list of (shallow-copied) segment dicts. The input dicts are not
        modified.
    """
    merged = []
    for segment in segments:
        # Copy so that merging never mutates the caller's dicts.
        segment = dict(segment)
        if merged and segment['start'] - merged[-1]['end'] <= time_threshold:
            last = merged[-1]
            merged_text = _merge_overlapping_text(last['text'], segment['text'], similarity_threshold)
            if merged_text is not None:
                last['text'] = merged_text
                # max(): a segment nested inside `last` must not shrink it.
                last['end'] = max(last['end'], segment['end'])
                if 'translated' in segment:
                    # Overlap is located on the translations themselves;
                    # indices from the source text do not apply to them.
                    prev_translated = last.get('translated', '')
                    merged_translated = _merge_overlapping_text(
                        prev_translated, segment['translated'], similarity_threshold
                    )
                    if merged_translated is None:
                        merged_translated = prev_translated + segment['translated']
                    last['translated'] = merged_translated
                continue
        # Too far apart, or no significant overlap: keep it as its own segment.
        merged.append(segment)
    return merged
# Helper function to pick the dominant speaker in a time range
def get_most_common_speaker(diarization_result, start_time, end_time):
    """Return the diarization label that best covers ``[start_time, end_time]``.

    The function considers every turn yielded by
    ``diarization_result.itertracks(yield_label=True)`` that overlaps or
    touches the interval. Each speaker is scored by the total duration of its
    overlap with the interval, and the highest score wins. Ties, including the
    all-zero scores produced by merely touching turns, go to the speaker seen
    first, so the result is deterministic.

    ``start_time``/``end_time`` are assumed to be in the same unit as
    ``turn.start``/``turn.end`` (seconds, per the caller). TODO confirm.

    Returns:
        The winning speaker label, or ``"Unknown"`` when no turn touches the
        interval.
    """
    overlap_by_speaker = {}
    for turn, _, speaker in diarization_result.itertracks(yield_label=True):
        if turn.start <= end_time and turn.end >= start_time:
            overlap = min(turn.end, end_time) - max(turn.start, start_time)
            overlap_by_speaker[speaker] = overlap_by_speaker.get(speaker, 0.0) + max(overlap, 0.0)
    if not overlap_by_speaker:
        return "Unknown"
    # max() returns the first maximal key in insertion order -> stable ties.
    return max(overlap_by_speaker, key=overlap_by_speaker.get)
# Helper function to split long audio files
def split_audio(audio, max_duration=30):
    """Cut ``audio`` into consecutive, non-overlapping pieces of at most
    ``max_duration`` seconds, assuming 16000 samples per second.

    Audio that already fits is returned as a one-element list containing the
    very same object (no copy). Otherwise the pieces are slices in order, and
    only the last one may be shorter.
    """
    piece_len = max_duration * 16000
    if len(audio) > piece_len:
        return [audio[pos:pos + piece_len] for pos in range(0, len(audio), piece_len)]
    return [audio]
def _run_diarization(audio):
    """Best-effort speaker diarization of the whole signal.

    Loads the pyannote pipeline on first use and feeds it the samples as a
    ``(1, n_samples)`` tensor at 16000 Hz. Returns the pipeline output, or
    ``None`` when the pipeline is unavailable or raises.
    """
    if diarization_pipeline is None:
        load_diarization_pipeline()
    if diarization_pipeline is None:
        return None
    try:
        return diarization_pipeline({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": 16000})
    except Exception as e:
        logger.warning(f"Diarization failed: {str(e)}. Proceeding without diarization.")
        return None


def _translate_segment(audio_split, seg_start, seg_end, lang):
    """Run Whisper's translate task on one segment of ``audio_split``.

    ``seg_start``/``seg_end`` are in seconds relative to the split. The
    language already detected for the split is passed explicitly, so the model
    does not re-detect it on a short clip. Returns ``""`` for an empty clip.
    """
    clip = audio_split[int(seg_start * 16000):int(seg_end * 16000)]
    if len(clip) == 0:
        return ""
    translation = whisper_model.transcribe(clip, task="translate", language=lang)
    # NOTE(review): whisperx 3.x transcribe() appears to return only
    # {"segments", "language"} with no top-level "text"; accept both shapes.
    if "text" in translation:
        return translation["text"]
    return " ".join(seg["text"].strip() for seg in translation.get("segments", []))


# Main processing function with optimizations
@spaces.GPU(duration=600)
def process_audio_optimized(audio_file, translate=False, model_size="small", use_diarization=True):
    """Transcribe an audio file, with optional translation and speaker diarization.

    The audio is loaded with ``whisperx.load_audio``, cut into 30-second
    splits, and each split is transcribed separately. Segment timestamps are
    shifted by the split's offset so they refer to the whole file.

    Args:
        audio_file: Input accepted by ``whisperx.load_audio`` (e.g. a file path).
        translate: Also run Whisper's translate task per segment and store the
            result under ``"translated"``.
        model_size: Whisper model name. It only takes effect when no model is
            loaded yet; later calls reuse the loaded model regardless of this
            value.
        use_diarization: Attempt speaker diarization. On failure, speakers are
            reported as ``"Unknown"``.

    Returns:
        A tuple ``(language_segments, merged_segments)``:

        * ``language_segments``: one dict per split with ``language``,
          ``start`` and ``end`` (seconds).
        * ``merged_segments``: segment dicts with ``start``, ``end``,
          ``language``, ``speaker`` and ``text`` (plus ``translated`` when
          requested). They are sorted by ``start`` and passed through
          ``merge_nearby_segments``.

    Raises:
        Whatever loading or transcription raises. The error is logged with its
        traceback and then re-raised.
    """
    if whisper_model is None:
        load_models(model_size)
    split_seconds = 30  # split length; also the per-split timestamp offset
    sample_rate = 16000
    start_time = time.time()
    try:
        audio = whisperx.load_audio(audio_file)
        audio_splits = split_audio(audio, max_duration=split_seconds)
        total_seconds = len(audio) / sample_rate

        # Perform diarization if requested and the pipeline is available.
        diarization_result = _run_diarization(audio) if use_diarization else None

        language_segments = []
        final_segments = []
        for i, audio_split in enumerate(audio_splits):
            logger.info(f"Processing split {i+1}/{len(audio_splits)}")
            offset = i * split_seconds
            # Explicit task so a preceding translate call cannot leave the
            # model in translate mode.
            result = whisper_model.transcribe(audio_split, task="transcribe")
            lang = result["language"]
            for segment in result["segments"]:
                segment_start = segment["start"] + offset
                segment_end = segment["end"] + offset
                speaker = "Unknown"
                if diarization_result is not None:
                    speaker = get_most_common_speaker(diarization_result, segment_start, segment_end)
                final_segment = {
                    "start": segment_start,
                    "end": segment_end,
                    "language": lang,
                    "speaker": speaker,
                    "text": segment["text"],
                }
                if translate:
                    final_segment["translated"] = _translate_segment(
                        audio_split, segment["start"], segment["end"], lang
                    )
                final_segments.append(final_segment)
            language_segments.append({
                "language": lang,
                "start": offset,
                "end": min(offset + split_seconds, total_seconds),
            })

        final_segments.sort(key=lambda x: x["start"])
        merged_segments = merge_nearby_segments(final_segments)
        end_time = time.time()
        logger.info(f"Total processing time: {end_time - start_time:.2f} seconds")
        return language_segments, merged_segments
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception(f"An error occurred during audio processing: {str(e)}")
        raise
# Public entry point: `process_audio` is an alias of process_audio_optimized,
# kept for backwards compatibility with callers that import `process_audio`.
process_audio = process_audio_optimized