import gc
import torch
import torchaudio
import numpy as np
from transformers import (
    Wav2Vec2ForSequenceClassification,
    AutoFeatureExtractor,
    Wav2Vec2ForCTC,
    AutoProcessor,
    AutoTokenizer,
    AutoModelForSeq2SeqLM
)
import spaces
import logging
from difflib import SequenceMatcher

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class AudioProcessor:
    def __init__(self, chunk_size=5, overlap=1, sample_rate=16000):
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.sample_rate = sample_rate
        self.previous_text = ""
        self.previous_lang = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_models(self):
        """Load all required models"""
        logger.info("Loading MMS models...")

        # Language identification model
        lid_processor = AutoFeatureExtractor.from_pretrained("facebook/mms-lid-256")
        lid_model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/mms-lid-256")

        # Transcription model
        mms_processor = AutoProcessor.from_pretrained("facebook/mms-1b-all")
        mms_model = Wav2Vec2ForCTC.from_pretrained("facebook/mms-1b-all")

        # Translation model
        translation_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
        translation_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")

        return {
            'lid': (lid_model, lid_processor),
            'mms': (mms_model, mms_processor),
            'translation': (translation_model, translation_tokenizer)
        }

    @spaces.GPU(duration=60)
    def identify_language(self, audio_chunk, models):
        """Identify language of audio chunk"""
        lid_model, lid_processor = models['lid']
        inputs = lid_processor(audio_chunk, sampling_rate=16000, return_tensors="pt")
        lid_model.to(self.device)

        with torch.no_grad():
            outputs = lid_model(inputs.input_values.to(self.device)).logits

        lang_id = torch.argmax(outputs, dim=-1)[0].item()
        detected_lang = lid_model.config.id2label[lang_id]
        return detected_lang

    @spaces.GPU(duration=60)
    def transcribe_chunk(self, audio_chunk, language, models):
        """Transcribe audio chunk"""
        mms_model, mms_processor = models['mms']
        mms_processor.tokenizer.set_target_lang(language)
        mms_model.load_adapter(language)
        mms_model.to(self.device)

        inputs = mms_processor(audio_chunk, sampling_rate=16000, return_tensors="pt")

        with torch.no_grad():
            outputs = mms_model(inputs.input_values.to(self.device)).logits

        ids = torch.argmax(outputs, dim=-1)[0]
        transcription = mms_processor.decode(ids)
        return transcription

    @spaces.GPU(duration=60)
    def translate_text(self, text, models):
        """Translate text to English"""
        translation_model, translation_tokenizer = models['translation']
        inputs = translation_tokenizer(text, return_tensors="pt")
        inputs = inputs.to(self.device)
        translation_model.to(self.device)

        with torch.no_grad():
            outputs = translation_model.generate(
                **inputs,
                forced_bos_token_id=translation_tokenizer.convert_tokens_to_ids("eng_Latn"),
                max_length=100
            )

        translation = translation_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        return translation

    def preprocess_audio(self, audio):
        """
        Create overlapping chunks with improved timing logic
        """
        chunk_samples = int(self.chunk_size * self.sample_rate)
        overlap_samples = int(self.overlap * self.sample_rate)

        chunks_with_times = []
        start_idx = 0

        while start_idx < len(audio):
            end_idx = min(start_idx + chunk_samples, len(audio))

            # Add padding for first chunk
            if start_idx == 0:
                chunk = audio[start_idx:end_idx]
                padding = torch.zeros(int(1 * self.sample_rate))
                chunk = torch.cat([padding, chunk])
            else:
                # Include overlap from previous chunk
                actual_start = max(0, start_idx - overlap_samples)
                chunk = audio[actual_start:end_idx]

            # Pad if necessary
            if len(chunk) < chunk_samples:
                chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - len(chunk)))

            # Adjust time ranges to account for overlaps
            chunk_start_time = max(0, (start_idx / self.sample_rate) - self.overlap)
            chunk_end_time = min((end_idx / self.sample_rate) + self.overlap, len(audio) / self.sample_rate)

            chunks_with_times.append({
                'chunk': chunk,
                'start_time': start_idx / self.sample_rate,
                'end_time': end_idx / self.sample_rate,
                'transcribe_start': chunk_start_time,
                'transcribe_end': chunk_end_time
            })

            # Move to next chunk with smaller step size for better continuity
            start_idx += (chunk_samples - overlap_samples)

        return chunks_with_times

    @spaces.GPU(duration=60)
    def process_audio(self, audio_path, translate=False):
        """Main processing function"""
        try:
            # Load audio
            waveform, sample_rate = torchaudio.load(audio_path)
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0)
            else:
                waveform = waveform.squeeze(0)

            # Resample if necessary
            if sample_rate != self.sample_rate:
                resampler = torchaudio.transforms.Resample(
                    orig_freq=sample_rate,
                    new_freq=self.sample_rate
                )
                waveform = resampler(waveform)

            # Load models
            models = self.load_models()

            # Process in chunks
            chunk_samples = int(self.chunk_size * self.sample_rate)
            overlap_samples = int(self.overlap * self.sample_rate)

            segments = []
            language_segments = []

            for i in range(0, len(waveform), chunk_samples - overlap_samples):
                chunk = waveform[i:i + chunk_samples]
                if len(chunk) < chunk_samples:
                    chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - len(chunk)))

                # Process chunk
                start_time = i / self.sample_rate
                end_time = (i + len(chunk)) / self.sample_rate

                # Identify language
                language = self.identify_language(chunk, models)

                # Record language segment
                language_segments.append({
                    "language": language,
                    "start": start_time,
                    "end": end_time
                })

                # Transcribe
                transcription = self.transcribe_chunk(chunk, language, models)

                segment = {
                    "start": start_time,
                    "end": end_time,
                    "language": language,
                    "text": transcription,
                    "speaker": "Speaker"  # Simple speaker assignment
                }

                if translate:
                    translation = self.translate_text(transcription, models)
                    segment["translated"] = translation

                segments.append(segment)

                # Clean up GPU memory
                torch.cuda.empty_cache()
                gc.collect()

            # Merge nearby segments
            merged_segments = self.merge_segments(segments)

            return language_segments, merged_segments

        except Exception as e:
            logger.error(f"Error processing audio: {str(e)}")
            raise

    def merge_segments(self, segments, time_threshold=0.5, similarity_threshold=0.7):
        """Merge similar nearby segments"""
        if not segments:
            return segments

        merged = []
        current = segments[0]

        for next_segment in segments[1:]:
            if (next_segment['start'] - current['end'] <= time_threshold and
                    current['language'] == next_segment['language']):
                # Check text similarity
                matcher = SequenceMatcher(None, current['text'], next_segment['text'])
                similarity = matcher.ratio()

                if similarity > similarity_threshold:
                    # Merge segments
                    current['end'] = next_segment['end']
                    current['text'] = current['text'] + ' ' + next_segment['text']
                    if 'translated' in current and 'translated' in next_segment:
                        current['translated'] = current['translated'] + ' ' + next_segment['translated']
                else:
                    merged.append(current)
                    current = next_segment
            else:
                merged.append(current)
                current = next_segment
        merged.append(current)
        return merged
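

# --- Illustrative usage sketch (an assumption, not part of the original app) ---
# A minimal example of how AudioProcessor might be driven end to end: load a
# file, chunk/transcribe/translate it, and print the merged segments. The path
# "sample.wav" is hypothetical, and this console output is only one possible
# way to present the returned segments.
if __name__ == "__main__":
    processor = AudioProcessor(chunk_size=5, overlap=1, sample_rate=16000)
    language_segments, segments = processor.process_audio("sample.wav", translate=True)

    for seg in segments:
        line = f"[{seg['start']:.1f}s - {seg['end']:.1f}s] ({seg['language']}) {seg['text']}"
        if 'translated' in seg:
            line += f" | EN: {seg['translated']}"
        print(line)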