# Multilingual transcription with WhisperX, with optional translation and
# pyannote speaker diarization.
import whisperx
import torch
import numpy as np
from scipy.signal import resample
from pyannote.audio import Pipeline
import os
from dotenv import load_dotenv

load_dotenv()

import logging
import time
from difflib import SequenceMatcher
import spaces

hf_token = os.getenv("HF_TOKEN")

CHUNK_LENGTH = 5  # seconds
OVERLAP = 2       # seconds

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables for models
device = "cuda" if torch.cuda.is_available() else "cpu"
compute_type = "float16" if device == "cuda" else "int8"
whisper_model = None
diarization_pipeline = None


def load_models(model_size="small"):
    """Load the Whisper model, falling back to CPU if the GPU load fails."""
    global whisper_model, diarization_pipeline, device, compute_type
    try:
        whisper_model = whisperx.load_model(model_size, device, compute_type=compute_type)
    except RuntimeError as e:
        logger.warning(f"Failed to load Whisper model on {device}. Falling back to CPU. Error: {str(e)}")
        device = "cpu"
        compute_type = "int8"
        whisper_model = whisperx.load_model(model_size, device, compute_type=compute_type)


def load_diarization_pipeline():
    """Try to initialize the diarization pipeline; disable diarization on failure."""
    global diarization_pipeline, device
    try:
        diarization_pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization", use_auth_token=hf_token
        )
        if device == "cuda":
            diarization_pipeline = diarization_pipeline.to(torch.device(device))
    except Exception as e:
        logger.warning(f"Diarization pipeline initialization failed: {str(e)}. Diarization will not be available.")
        diarization_pipeline = None


def preprocess_audio(audio, chunk_size=CHUNK_LENGTH * 16000, overlap=OVERLAP * 16000):
    """Split audio into fixed-size, overlapping chunks, zero-padding the last one."""
    chunks = []
    for i in range(0, len(audio), chunk_size - overlap):
        chunk = audio[i:i + chunk_size]
        if len(chunk) < chunk_size:
            chunk = np.pad(chunk, (0, chunk_size - len(chunk)))
        chunks.append(chunk)
    return chunks


def merge_nearby_segments(segments, time_threshold=0.5, similarity_threshold=0.7):
    """Merge adjacent segments whose text overlaps significantly."""
    merged = []
    for segment in segments:
        if not merged or segment['start'] - merged[-1]['end'] > time_threshold:
            merged.append(segment)
        else:
            # Find the overlap between the previous text and the current one
            matcher = SequenceMatcher(None, merged[-1]['text'], segment['text'])
            match = matcher.find_longest_match(0, len(merged[-1]['text']), 0, len(segment['text']))
            if match.size / len(segment['text']) > similarity_threshold:
                # Merge the segments, appending only the non-overlapping tail
                merged_text = merged[-1]['text'] + segment['text'][match.b + match.size:]
                merged_translated = merged[-1].get('translated', '') + segment.get('translated', '')[match.b + match.size:]
                merged[-1]['end'] = segment['end']
                merged[-1]['text'] = merged_text
                if 'translated' in segment:
                    merged[-1]['translated'] = merged_translated
            else:
                # No significant overlap: append as a new segment
                merged.append(segment)
    return merged


# Helper function to get the most common speaker in a time range
def get_most_common_speaker(diarization_result, start_time, end_time):
    speakers = []
    for turn, _, speaker in diarization_result.itertracks(yield_label=True):
        if turn.start <= end_time and turn.end >= start_time:
            speakers.append(speaker)
    return max(set(speakers), key=speakers.count) if speakers else "Unknown"


# Helper function to split long audio files into 30-second pieces
def split_audio(audio, max_duration=30):
    sample_rate = 16000
    max_samples = max_duration * sample_rate
    if len(audio) <= max_samples:
        return [audio]
    splits = []
    for i in range(0, len(audio), max_samples):
        splits.append(audio[i:i + max_samples])
    return splits


# Main processing function with optimizations
@spaces.GPU(duration=600)
def process_audio_optimized(audio_file, translate=False, model_size="small", use_diarization=True):
    global whisper_model, diarization_pipeline

    if whisper_model is None:
        load_models(model_size)

    start_time = time.time()
    try:
        audio = whisperx.load_audio(audio_file)
        audio_splits = split_audio(audio)

        # Perform diarization if requested and the pipeline is available
        diarization_result = None
        if use_diarization:
            if diarization_pipeline is None:
                load_diarization_pipeline()
            if diarization_pipeline is not None:
                try:
                    diarization_result = diarization_pipeline(
                        {"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": 16000}
                    )
                except Exception as e:
                    logger.warning(f"Diarization failed: {str(e)}. Proceeding without diarization.")

        language_segments = []
        final_segments = []

        for i, audio_split in enumerate(audio_splits):
            logger.info(f"Processing split {i+1}/{len(audio_splits)}")
            result = whisper_model.transcribe(audio_split)
            lang = result["language"]

            for segment in result["segments"]:
                segment_start = segment["start"] + (i * 30)  # Adjust start time based on split
                segment_end = segment["end"] + (i * 30)      # Adjust end time based on split

                speaker = "Unknown"
                if diarization_result is not None:
                    speaker = get_most_common_speaker(diarization_result, segment_start, segment_end)

                final_segment = {
                    "start": segment_start,
                    "end": segment_end,
                    "language": lang,
                    "speaker": speaker,
                    "text": segment["text"],
                }

                if translate:
                    # whisperx transcription results are {"segments": [...], "language": ...},
                    # so join the translated segment texts instead of reading a "text" key.
                    translation = whisper_model.transcribe(
                        audio_split[int(segment["start"] * 16000):int(segment["end"] * 16000)],
                        task="translate",
                    )
                    final_segment["translated"] = " ".join(
                        s["text"].strip() for s in translation["segments"]
                    )

                final_segments.append(final_segment)

            language_segments.append({
                "language": lang,
                "start": i * 30,
                "end": min((i + 1) * 30, len(audio) / 16000),
            })

        final_segments.sort(key=lambda x: x["start"])
        merged_segments = merge_nearby_segments(final_segments)

        end_time = time.time()
        logger.info(f"Total processing time: {end_time - start_time:.2f} seconds")

        return language_segments, merged_segments
    except Exception as e:
        logger.error(f"An error occurred during audio processing: {str(e)}")
        raise


# You can keep the original process_audio function for backwards compatibility
# or replace it with the optimized version.
process_audio = process_audio_optimized
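

# --- Illustrative usage (a minimal sketch, not part of the original module) ---
# "sample.wav" is a hypothetical file path; in a Space this function is normally
# invoked from a UI callback rather than run directly, and the @spaces.GPU
# decorator assumes a Hugging Face ZeroGPU environment.
if __name__ == "__main__":
    language_segments, segments = process_audio("sample.wav", translate=False, use_diarization=False)
    for seg in segments:
        print(f"[{seg['start']:.2f}-{seg['end']:.2f}] {seg['speaker']} ({seg['language']}): {seg['text']}")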