# Spaces: Running on Zero
import os | |
import gc | |
import sys | |
import time | |
import torch | |
import spaces | |
import torchaudio | |
import numpy as np | |
from scipy.signal import resample | |
from pyannote.audio import Pipeline | |
from dotenv import load_dotenv | |
load_dotenv() | |
from difflib import SequenceMatcher | |
from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor, Wav2Vec2ForCTC, AutoProcessor, AutoTokenizer, AutoModelForSeq2SeqLM | |
from difflib import SequenceMatcher | |
import logging | |
# Configure the root logger (a no-op if it already has handlers):
# INFO and above, "time - logger - level - message" layout, written to stdout.
logging.basicConfig(
    handlers=[logging.StreamHandler(stream=sys.stdout)],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Logger for this module, named after the module's import name.
logger = logging.getLogger(name=__name__)
class ChunkedTranscriber:
    """Chunked, multilingual speech transcription with speaker labels.

    The flow implemented by ``transcribe_audio`` is:

    1. speaker diarization of the whole file (pyannote ``speaker-diarization``);
    2. splitting the audio into overlapping chunks (``preprocess_audio``);
    3. per chunk: language identification with ``facebook/mms-lid-256``,
       transcription with ``facebook/mms-1b-all`` using the adapter for the
       detected language, and optional translation to English with
       ``facebook/nllb-200-distilled-600M``;
    4. merging adjacent chunks that share language and speaker.
    """

    def __init__(self, chunk_size=5, overlap=1, sample_rate=16000):
        """Create a transcriber and load the diarization pipeline.

        Args:
            chunk_size: Chunk length in seconds.
            overlap: Overlap between consecutive chunks in seconds; must be
                smaller than ``chunk_size`` for ``preprocess_audio`` to work.
            sample_rate: Sample rate (Hz) that ``load_audio`` resamples to.
        """
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.sample_rate = sample_rate
        # Raw transcription and language of the previously processed chunk;
        # used to strip words repeated across the chunk overlap.
        self.previous_text = ""
        self.previous_lang = None
        self.speaker_diarization_pipeline = self.load_speaker_diarization_pipeline()

    @staticmethod
    def _default_device():
        """Return the CUDA device when available, otherwise the CPU."""
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_speaker_diarization_pipeline(self):
        """Load the pre-trained pyannote speaker-diarization pipeline.

        The Hugging Face access token is read from the ``HF_TOKEN``
        environment variable.
        """
        return Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=os.getenv("HF_TOKEN"))

    def diarize_audio(self, audio_path):
        """Run speaker diarization on the audio file at ``audio_path``."""
        return self.speaker_diarization_pipeline({"uri": "audio", "audio": audio_path})

    def load_lid_mms(self):
        """Load the MMS language-identification model.

        Returns:
            ``(feature_extractor, model)`` for ``facebook/mms-lid-256``
            (note: processor first, unlike ``load_mms``).
        """
        model_id = "facebook/mms-lid-256"
        processor = AutoFeatureExtractor.from_pretrained(model_id)
        model = Wav2Vec2ForSequenceClassification.from_pretrained(model_id)
        return processor, model

    def language_identification(self, model, processor, chunk, device=None):
        """Return the language label ``model`` predicts for ``chunk``.

        Args:
            model: Sequence-classification model (e.g. from ``load_lid_mms``).
            processor: Matching feature extractor.
            chunk: Audio samples, assumed to be at 16 kHz.
            device: Device to run on; defaults to CUDA when available, else CPU.

        Returns:
            The ``model.config.id2label`` entry of the highest-scoring class.
        """
        if device is None:
            device = self._default_device()
        inputs = processor(chunk, sampling_rate=16_000, return_tensors="pt").to(device)
        model.to(device)
        with torch.no_grad():
            logits = model(**inputs).logits
        lang_id = torch.argmax(logits, dim=-1)[0].item()
        detected_lang = model.config.id2label[lang_id]
        # Only the per-call tensors can be released here; the model itself is
        # still referenced by the caller.
        del inputs, logits
        torch.cuda.empty_cache()
        gc.collect()
        return detected_lang

    def load_mms(self):
        """Load the MMS-1b-all CTC model.

        Returns:
            ``(model, processor)`` for ``facebook/mms-1b-all``
            (note: model first, unlike ``load_lid_mms``).
        """
        model_id = "facebook/mms-1b-all"
        processor = AutoProcessor.from_pretrained(model_id)
        model = Wav2Vec2ForCTC.from_pretrained(model_id)
        return model, processor

    def mms_transcription(self, model, processor, chunk, device=None):
        """Greedy (argmax) CTC transcription of ``chunk``.

        Args:
            device: Device to run on; defaults to CUDA when available, else CPU.
        """
        if device is None:
            device = self._default_device()
        inputs = processor(chunk, sampling_rate=16_000, return_tensors="pt").to(device)
        model.to(device)
        with torch.no_grad():
            logits = model(**inputs).logits
        ids = torch.argmax(logits, dim=-1)[0]
        transcription = processor.decode(ids)
        del inputs, logits
        torch.cuda.empty_cache()
        gc.collect()
        return transcription

    def load_T2T_translation_model(self):
        """Load the NLLB-200 (distilled 600M) translation model.

        Returns:
            ``(model, tokenizer)``.
        """
        model_id = "facebook/nllb-200-distilled-600M"
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        return model, tokenizer

    def text2text_translation(self, translation_model, translation_tokenizer, transcript, device=None):
        """Translate ``transcript`` to English (``eng_Latn``).

        Generation is capped at ``max_length=100`` tokens.

        NOTE(review): the tokenizer's ``src_lang`` is never set here, so the
        source-language token is the tokenizer's default rather than the
        detected language; confirm whether LID codes should be mapped to NLLB
        language codes.

        Args:
            device: Device to run on; defaults to CUDA when available, else CPU.

        Returns:
            The first decoded translation string.
        """
        if device is None:
            device = self._default_device()
        tokenized_inputs = translation_tokenizer(transcript, return_tensors='pt').to(device)
        translation_model.to(device)
        translated_tokens = translation_model.generate(
            **tokenized_inputs,
            forced_bos_token_id=translation_tokenizer.convert_tokens_to_ids("eng_Latn"),
            max_length=100,
        )
        del tokenized_inputs
        torch.cuda.empty_cache()
        gc.collect()
        return translation_tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]

    def preprocess_audio(self, audio):
        """Split ``audio`` into overlapping chunks.

        The first chunk is prefixed with one second of silence. Every later
        chunk also includes ``overlap`` seconds of audio before its nominal
        start. Chunks shorter than ``chunk_size`` seconds are right-padded
        with zeros. Chunking stops once a chunk reaches the end of the audio,
        so no trailing chunk made only of already-covered audio is produced.

        Args:
            audio: 1-D tensor of samples at ``self.sample_rate``.

        Returns:
            A list of dicts with keys ``chunk``, ``start_time``, ``end_time``
            (nominal span, seconds), and ``transcribe_start``/``transcribe_end``
            (span widened by ``overlap``, clamped to the audio).

        Raises:
            ValueError: if ``overlap`` is not smaller than ``chunk_size``
                (the chunk start would never advance).
        """
        chunk_samples = int(self.chunk_size * self.sample_rate)
        overlap_samples = int(self.overlap * self.sample_rate)
        step = chunk_samples - overlap_samples
        if step <= 0:
            raise ValueError(
                f"overlap ({self.overlap}s) must be smaller than chunk_size ({self.chunk_size}s)"
            )
        total_samples = len(audio)
        total_seconds = total_samples / self.sample_rate

        chunks_with_times = []
        start_idx = 0
        while start_idx < total_samples:
            end_idx = min(start_idx + chunk_samples, total_samples)
            if start_idx == 0:
                padding = torch.zeros(int(1 * self.sample_rate), dtype=audio.dtype)
                chunk = torch.cat([padding, audio[start_idx:end_idx]])
            else:
                # Include overlap from the previous chunk.
                actual_start = max(0, start_idx - overlap_samples)
                chunk = audio[actual_start:end_idx]
            if len(chunk) < chunk_samples:
                chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - len(chunk)))
            chunks_with_times.append({
                'chunk': chunk,
                'start_time': start_idx / self.sample_rate,
                'end_time': end_idx / self.sample_rate,
                'transcribe_start': max(0, (start_idx / self.sample_rate) - self.overlap),
                'transcribe_end': min((end_idx / self.sample_rate) + self.overlap, total_seconds),
            })
            if end_idx >= total_samples:
                # The next step would only re-cover audio already in this chunk.
                break
            start_idx += step
        return chunks_with_times

    def merge_close_segments(self, results):
        """Merge consecutive segments that share language and speaker and are close in time.

        Two segments merge when their ``detected_language`` and ``speaker``
        (missing speaker counts as ``None``) match, and the gap between the
        current end and the next start is at most ``self.overlap`` seconds.
        Segments with blank text are dropped. Input dicts are not modified.

        Returns:
            A new list of merged segment dicts (the input itself if it is empty).
        """
        if not results:
            return results
        merged = []
        current = dict(results[0])
        for next_segment in results[1:]:
            if not next_segment['text'].strip():
                continue
            same_language = current['detected_language'] == next_segment['detected_language']
            same_speaker = current.get('speaker') == next_segment.get('speaker')
            close_in_time = abs(next_segment['start_time'] - current['end_time']) <= self.overlap
            if same_language and same_speaker and close_in_time:
                current['text'] = current['text'] + ' ' + next_segment['text']
                current['end_time'] = next_segment['end_time']
                if 'translated' in current and 'translated' in next_segment:
                    current['translated'] = current['translated'] + ' ' + next_segment['translated']
            else:
                if current['text'].strip():
                    merged.append(current)
                current = dict(next_segment)
        if current['text'].strip():
            merged.append(current)
        return merged

    def clean_overlapping_text(self, current_text, prev_text, current_lang, prev_lang, min_overlap=3):
        """Drop leading words of ``current_text`` that repeat words of ``prev_text``.

        Nothing is removed when either text is empty, when both languages are
        known and differ, when either text has fewer than two words, or when
        no matching word run of at least ``min_overlap`` words starts at the
        first word of ``current_text``.

        NOTE(review): the matching run is only required to start at the
        beginning of ``current_text``; it is not required to sit at the end of
        ``prev_text``.

        Returns:
            The remaining words joined by single spaces, ``""`` if every word
            was removed, or ``current_text`` unchanged when nothing matched.
        """
        if not prev_text or not current_text:
            return current_text
        if prev_lang and current_lang and prev_lang != current_lang:
            return current_text
        prev_words = prev_text.split()
        curr_words = current_text.split()
        if len(prev_words) < 2 or len(curr_words) < 2:
            return current_text
        matcher = SequenceMatcher(None, prev_words, curr_words)
        best_overlap = max(
            (block.size for block in matcher.get_matching_blocks()
             if block.b == 0 and block.size >= min_overlap),
            default=0,
        )
        if best_overlap == 0:
            return current_text
        return ' '.join(curr_words[best_overlap:]).strip()

    def process_chunk(self, chunk_data, mms_model, mms_processor, translation_model=None,
                      translation_tokenizer=None, lid_model=None, lid_processor=None):
        """Identify the language of, transcribe, and optionally translate one chunk.

        Args:
            chunk_data: One entry produced by ``preprocess_audio``.
            mms_model, mms_processor: From ``load_mms``.
            translation_model, translation_tokenizer: From
                ``load_T2T_translation_model``; translation is skipped when
                either is ``None``.
            lid_model, lid_processor: From ``load_lid_mms``; loaded on demand
                when not supplied (pass them to avoid reloading per chunk).

        Returns:
            A dict with ``start_time``, ``end_time``, ``text``,
            ``detected_language`` and, when translating, ``translated``.
            Returns ``None`` if the de-duplicated text is blank or any step
            raised; the error is logged.
        """
        device = self._default_device()
        try:
            if lid_model is None or lid_processor is None:
                lid_processor, lid_model = self.load_lid_mms()
            lid_lang = self.language_identification(lid_model, lid_processor, chunk_data['chunk'], device=device)

            mms_processor.tokenizer.set_target_lang(lid_lang)
            # ``load_adapter`` records the active language in ``target_lang``;
            # skip the reload when consecutive chunks share a language.
            if getattr(mms_model, "target_lang", None) != lid_lang:
                mms_model.load_adapter(lid_lang)

            inputs = mms_processor(chunk_data['chunk'], sampling_rate=self.sample_rate, return_tensors="pt")
            inputs = inputs.to(device)
            # Move after ``load_adapter``, which may replace the LM head.
            mms_model = mms_model.to(device)
            with torch.no_grad():
                logits = mms_model(**inputs).logits
            ids = torch.argmax(logits, dim=-1)[0]
            transcription = mms_processor.decode(ids)

            cleaned_transcription = self.clean_overlapping_text(
                transcription,
                self.previous_text,
                lid_lang,
                self.previous_lang,
                min_overlap=3,
            )
            # Keep the raw (uncleaned) text so the next chunk is compared with
            # everything this chunk actually heard.
            self.previous_text = transcription
            self.previous_lang = lid_lang

            if not cleaned_transcription.strip():
                return None

            result = {
                'start_time': chunk_data['start_time'],
                'end_time': chunk_data['end_time'],
                'text': cleaned_transcription,
                'detected_language': lid_lang,
            }
            if translation_model is not None and translation_tokenizer is not None:
                result['translated'] = self.text2text_translation(
                    translation_model,
                    translation_tokenizer,
                    cleaned_transcription,
                    device=device,
                )
            return result
        except Exception:
            # Best-effort per chunk: one failing chunk must not abort the file.
            logger.exception("Error processing chunk")
            return None
        finally:
            torch.cuda.empty_cache()
            gc.collect()

    def translate_text(self, text, translation_model, translation_tokenizer, device):
        """Translate ``text`` to English; thin wrapper over ``text2text_translation``."""
        return self.text2text_translation(translation_model, translation_tokenizer, text, device=device)

    @staticmethod
    def _assign_speaker(start_time, end_time, speaker_segments):
        """Return the speaker whose turn overlaps ``[start_time, end_time]`` the most, or ``None``."""
        best_speaker, best_overlap = None, 0.0
        for segment in speaker_segments:
            overlap = min(end_time, segment['end_time']) - max(start_time, segment['start_time'])
            if overlap > best_overlap:
                best_speaker, best_overlap = segment['speaker'], overlap
        return best_speaker

    @staticmethod
    def _format_results(merged_results):
        """Build ``(joined_translation, report)`` from merged segments.

        Segments without a speaker are reported as ``unknown`` and segments
        without a translation get an empty ``Translation:`` line. The speaker
        id shown is the part of the label after its last ``_``.
        """
        translations = []
        report_parts = []
        for res in merged_results:
            translated = res.get('translated', '')
            if translated:
                translations.append(translated)
            speaker = res.get('speaker')
            speaker_id = speaker.split('_')[-1] if speaker else 'unknown'
            report_parts.append(
                f"{res['start_time']}-{res['end_time']} - Speaker: {speaker_id} - Language: {res['detected_language']}\n"
                f" Text: {res['text']}\n Translation: {translated}\n\n"
            )
        return ' '.join(translations), ''.join(report_parts)

    def transcribe_audio(self, audio_path, translate=False):
        """Diarize, transcribe and (optionally) translate the file at ``audio_path``.

        Args:
            audio_path: Path of the audio file.
            translate: When true, also translate each segment to English.

        Returns:
            ``(translation, report)``: the segment translations joined by
            spaces (empty when ``translate`` is false), and a human-readable
            per-segment report with times, speaker, language, text and
            translation.
        """
        diarization_result = self.diarize_audio(audio_path)
        speaker_segments = [
            {'start_time': turn.start, 'end_time': turn.end, 'speaker': speaker}
            for turn, _, speaker in diarization_result.itertracks(yield_label=True)
        ]

        audio = self.load_audio(audio_path)
        chunks = self.preprocess_audio(audio)

        # Each file is independent: do not de-duplicate its first chunk
        # against the last chunk of a previously transcribed file.
        self.previous_text = ""
        self.previous_lang = None

        mms_model, mms_processor = self.load_mms()
        # Load the LID model once instead of once per chunk.
        lid_processor, lid_model = self.load_lid_mms()
        translation_model, translation_tokenizer = None, None
        if translate:
            translation_model, translation_tokenizer = self.load_T2T_translation_model()

        results = []
        for chunk_data in chunks:
            result = self.process_chunk(
                chunk_data,
                mms_model,
                mms_processor,
                translation_model,
                translation_tokenizer,
                lid_model=lid_model,
                lid_processor=lid_processor,
            )
            if result is None:
                continue
            speaker = self._assign_speaker(result['start_time'], result['end_time'], speaker_segments)
            if speaker is not None:
                result['speaker'] = speaker
            results.append(result)

        merged_results = self.merge_close_segments(results)
        _translation, _output = self._format_results(merged_results)
        logger.info("\n\n TRANSLATION: %s", _translation)
        return _translation, _output

    def load_audio(self, audio_path):
        """Load ``audio_path`` as a 1-D float tensor at ``self.sample_rate``.

        Multi-channel audio is down-mixed by averaging the channels.
        """
        waveform, sample_rate = torchaudio.load(audio_path)
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0)
        else:
            waveform = waveform.squeeze(0)
        if sample_rate != self.sample_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate,
                new_freq=self.sample_rate,
            )
            waveform = resampler(waveform)
        return waveform.float()