import os
import gc
import sys
import time
import torch
import spaces
import torchaudio
import numpy as np
from df.enhance import enhance, init_df
from dotenv import load_dotenv
load_dotenv()
from scipy.signal import resample
from pyannote.audio import Pipeline
from difflib import SequenceMatcher
from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor, Wav2Vec2ForCTC, AutoProcessor, AutoTokenizer, AutoModelForSeq2SeqLM
import logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)
class ChunkedTranscriber:
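    """
    Chunked multilingual transcription pipeline: splits audio into overlapping
    chunks, identifies the spoken language per chunk (MMS-LID), transcribes with
    MMS, optionally translates to English with NLLB, and attributes speakers via
    pyannote diarization.
    """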
def __init__(self, chunk_size=30, overlap=5, sample_rate=16000):
self.chunk_size = chunk_size
self.overlap = overlap
self.sample_rate = sample_rate
self.previous_text = ""
self.previous_lang = None
self.speaker_diarization_pipeline = self.load_speaker_diarization_pipeline()
def load_speaker_diarization_pipeline(self):
"""
Load the pre-trained speaker diarization pipeline from pyannote-audio.
"""
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=os.getenv("HF_TOKEN"))
return pipeline
@spaces.GPU(duration=180)
def diarize_audio(self, audio_path):
"""
Perform speaker diarization on the input audio.
"""
diarization_result = self.speaker_diarization_pipeline({"uri": "audio", "audio": audio_path})
return diarization_result
def load_lid_mms(self):
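        """
        Load the facebook/mms-lid-256 feature extractor and classifier used for
        spoken language identification. Returns (processor, model).
        """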
model_id = "facebook/mms-lid-256"
processor = AutoFeatureExtractor.from_pretrained(model_id)
model = Wav2Vec2ForSequenceClassification.from_pretrained(model_id)
return processor, model
@spaces.GPU(duration=180)
def language_identification(self, model, processor, chunk, device="cuda"):
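        """
        Run MMS language identification on a 16 kHz audio chunk and return the
        predicted language code (e.g. "eng"); the model is freed afterwards.
        """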
inputs = processor(chunk, sampling_rate=16_000, return_tensors="pt")
        model.to(device)
        inputs = inputs.to(device)
with torch.no_grad():
outputs = model(**inputs).logits
lang_id = torch.argmax(outputs, dim=-1)[0].item()
detected_lang = model.config.id2label[lang_id]
del model
del inputs
torch.cuda.empty_cache()
gc.collect()
return detected_lang
    def load_mms(self):
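        """
        Load the facebook/mms-1b-all CTC model and processor for multilingual
        speech recognition. Returns (model, processor).
        """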
model_id = "facebook/mms-1b-all"
processor = AutoProcessor.from_pretrained(model_id)
model = Wav2Vec2ForCTC.from_pretrained(model_id)
return model, processor
@spaces.GPU(duration=180)
def mms_transcription(self, model, processor, chunk, device="cuda"):
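        """
        Transcribe a 16 kHz audio chunk with MMS and return the decoded text;
        the model reference is dropped and the CUDA cache cleared afterwards.
        """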
inputs = processor(chunk, sampling_rate=16_000, return_tensors="pt")
        model.to(device)
        inputs = inputs.to(device)
with torch.no_grad():
outputs = model(**inputs).logits
ids = torch.argmax(outputs, dim=-1)[0]
transcription = processor.decode(ids)
del model
del inputs
torch.cuda.empty_cache()
gc.collect()
return transcription
    def load_T2T_translation_model(self):
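        """
        Load the facebook/nllb-200-distilled-600M model and tokenizer for
        text-to-text translation. Returns (model, tokenizer).
        """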
model_id = "facebook/nllb-200-distilled-600M"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
return model, tokenizer
@spaces.GPU(duration=180)
def text2text_translation(self, translation_model, translation_tokenizer, transcript, device="cuda"):
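        """
        Translate a transcript into English (eng_Latn) with NLLB and return the
        decoded string. Expects the model/tokenizer from load_T2T_translation_model().
        """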
tokenized_inputs = translation_tokenizer(transcript, return_tensors='pt')
translation_model.to(device)
        tokenized_inputs = tokenized_inputs.to(device)
        translated_tokens = translation_model.generate(
            **tokenized_inputs,
            forced_bos_token_id=translation_tokenizer.convert_tokens_to_ids("eng_Latn"),
            max_length=100
        )
del translation_model
del tokenized_inputs
torch.cuda.empty_cache()
gc.collect()
return translation_tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
def preprocess_audio(self, audio):
"""
Create overlapping chunks with improved timing logic
"""
chunk_samples = int(self.chunk_size * self.sample_rate)
overlap_samples = int(self.overlap * self.sample_rate)
chunks_with_times = []
start_idx = 0
while start_idx < len(audio):
end_idx = min(start_idx + chunk_samples, len(audio))
# Add padding for first chunk
if start_idx == 0:
chunk = audio[start_idx:end_idx]
padding = torch.zeros(int(1 * self.sample_rate))
chunk = torch.cat([padding, chunk])
else:
# Include overlap from previous chunk
actual_start = max(0, start_idx - overlap_samples)
chunk = audio[actual_start:end_idx]
# Pad if necessary
if len(chunk) < chunk_samples:
chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - len(chunk)))
# Adjust time ranges to account for overlaps
chunk_start_time = max(0, (start_idx / self.sample_rate) - self.overlap)
chunk_end_time = min((end_idx / self.sample_rate) + self.overlap, len(audio) / self.sample_rate)
chunks_with_times.append({
'chunk': chunk,
'start_time': start_idx / self.sample_rate,
'end_time': end_idx / self.sample_rate,
'transcribe_start': chunk_start_time,
'transcribe_end': chunk_end_time
})
# Move to next chunk with smaller step size for better continuity
start_idx += (chunk_samples - overlap_samples)
return chunks_with_times
def merge_close_segments(self, results):
"""
Merge segments that are close in time and have the same language
"""
if not results:
return results
merged = []
current = results[0]
for next_segment in results[1:]:
# Skip empty segments
if not next_segment['text'].strip():
continue
# If segments are in the same language and close in time
if (current['detected_language'] == next_segment['detected_language'] and
abs(next_segment['start_time'] - current['end_time']) <= self.overlap):
# Merge the segments
current['text'] = current['text'] + ' ' + next_segment['text']
current['end_time'] = next_segment['end_time']
if 'translated' in current and 'translated' in next_segment:
current['translated'] = current['translated'] + ' ' + next_segment['translated']
else:
if current['text'].strip(): # Only add non-empty segments
merged.append(current)
current = next_segment
if current['text'].strip(): # Add the last segment if non-empty
merged.append(current)
return merged
def clean_overlapping_text(self, current_text, prev_text, current_lang, prev_lang, min_overlap=3):
"""
Improved text cleaning with language awareness and better sentence boundary handling
"""
if not prev_text or not current_text:
return current_text
# If languages are different, don't try to merge
if prev_lang and current_lang and prev_lang != current_lang:
return current_text
# Split into words
prev_words = prev_text.split()
curr_words = current_text.split()
if len(prev_words) < 2 or len(curr_words) < 2:
return current_text
# Find matching sequences at the end of prev_text and start of current_text
matcher = SequenceMatcher(None, prev_words, curr_words)
matches = list(matcher.get_matching_blocks())
        # Look for a significant overlap anchored at the start of the current text
        best_overlap = 0
        for match in matches:
            # Only matches that begin at the start of the current chunk count as overlap
            if match.b == 0 and match.size >= min_overlap and match.size > best_overlap:
                best_overlap = match.size
if best_overlap > 0:
# Remove overlapping content while preserving sentence integrity
cleaned_words = curr_words[best_overlap:]
if not cleaned_words: # If everything was overlapping
return ""
return ' '.join(cleaned_words).strip()
return current_text
def process_chunk(self, chunk_data, mms_model, mms_processor, translation_model=None, translation_tokenizer=None):
"""
Process chunk with improved language handling
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
# Language detection
lid_processor, lid_model = self.load_lid_mms()
lid_lang = self.language_identification(lid_model, lid_processor, chunk_data['chunk'])
# Configure processor
mms_processor.tokenizer.set_target_lang(lid_lang)
mms_model.load_adapter(lid_lang)
# Transcribe
inputs = mms_processor(chunk_data['chunk'], sampling_rate=self.sample_rate, return_tensors="pt")
inputs = inputs.to(device)
mms_model = mms_model.to(device)
with torch.no_grad():
outputs = mms_model(**inputs).logits
ids = torch.argmax(outputs, dim=-1)[0]
transcription = mms_processor.decode(ids)
# Clean overlapping text with language awareness
cleaned_transcription = self.clean_overlapping_text(
transcription,
self.previous_text,
lid_lang,
self.previous_lang,
min_overlap=3
)
# Update previous state
self.previous_text = transcription
self.previous_lang = lid_lang
if not cleaned_transcription.strip():
return None
result = {
'start_time': chunk_data['start_time'],
'end_time': chunk_data['end_time'],
'text': cleaned_transcription,
'detected_language': lid_lang
}
# Handle translation
if translation_model and translation_tokenizer and cleaned_transcription.strip():
translation = self.text2text_translation(
translation_model,
translation_tokenizer,
cleaned_transcription
)
result['translated'] = translation
return result
except Exception as e:
print(f"Error processing chunk: {str(e)}")
return None
finally:
torch.cuda.empty_cache()
gc.collect()
def translate_text(self, text, translation_model, translation_tokenizer, device):
"""
Translate cleaned text using the provided translation model.
"""
tokenized_inputs = translation_tokenizer(text, return_tensors='pt')
tokenized_inputs = tokenized_inputs.to(device)
translation_model = translation_model.to(device)
translated_tokens = translation_model.generate(
**tokenized_inputs,
forced_bos_token_id=translation_tokenizer.convert_tokens_to_ids("eng_Latn"),
max_length=100
)
translation = translation_tokenizer.batch_decode(
translated_tokens,
skip_special_tokens=True
)[0]
del translation_model
del tokenized_inputs
torch.cuda.empty_cache()
gc.collect()
return translation
    def audio_denoising(self, noisy_audio):
        """
        Denoise audio using DeepFilterNet's enhance() before transcription.
        """
        model, df_state = init_df()
        enhanced_audio = enhance(model, df_state, noisy_audio)
        return enhanced_audio
def transcribe_audio(self, audio_path, translate=False):
"""
Main transcription function with improved segment merging
"""
# Perform speaker diarization
diarization_result = self.diarize_audio(audio_path)
# Extract speaker segments
speaker_segments = []
for turn, _, speaker in diarization_result.itertracks(yield_label=True):
speaker_segments.append({
'start_time': turn.start,
'end_time': turn.end,
'speaker': speaker
})
audio = self.load_audio(audio_path)
chunks = self.preprocess_audio(audio)
mms_model, mms_processor = self.load_mms()
translation_model, translation_tokenizer = None, None
if translate:
translation_model, translation_tokenizer = self.load_T2T_translation_model()
# Process chunks
results = []
for chunk_data in chunks:
result = self.process_chunk(
chunk_data,
mms_model,
mms_processor,
translation_model,
translation_tokenizer
)
if result:
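                # Attach the speaker whose diarization turn contains this chunk's start time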
for segment in speaker_segments:
if int(segment['start_time']) <= int(chunk_data['start_time']) < int(segment['end_time']):
result['speaker'] = segment['speaker']
break
results.append(result)
# Merge close segments and clean up
merged_results = self.merge_close_segments(results)
        _translation = ""
        _output = ""
        for res in merged_results:
            # Guard against segments with no speaker match or no translation
            translated = res.get('translated', '')
            speaker = res.get('speaker', 'SPEAKER_unknown').split('_')[-1]
            _translation += translated + ' '
            _output += (f"{res['start_time']}-{res['end_time']} - Speaker: {speaker} - "
                        f"Language: {res['detected_language']}\n Text: {res['text']}\n Translation: {translated}\n\n")
logger.info(f"\n\n TRANSLATION: {_translation}")
return _translation, _output
def load_audio(self, audio_path):
"""
Load and preprocess audio file.
"""
waveform, sample_rate = torchaudio.load(audio_path)
# Convert to mono if stereo
if waveform.shape[0] > 1:
waveform = torch.mean(waveform, dim=0)
else:
waveform = waveform.squeeze(0)
# Resample if necessary
if sample_rate != self.sample_rate:
resampler = torchaudio.transforms.Resample(
orig_freq=sample_rate,
new_freq=self.sample_rate
)
waveform = resampler(waveform)
return waveform.float()
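

# Minimal usage sketch (illustrative, not part of the original Space code): it
# assumes a 16 kHz-readable file named "sample.wav" in the working directory,
# an HF_TOKEN in the environment (or a .env file) for the gated pyannote
# pipeline, and a GPU-backed Spaces runtime for the @spaces.GPU-decorated methods.
if __name__ == "__main__":
    transcriber = ChunkedTranscriber(chunk_size=30, overlap=5, sample_rate=16000)
    translation, report = transcriber.transcribe_audio("sample.wav", translate=True)
    print(report)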