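# Page header captured together with this file (author, commit message, commit id, page labels, file size):
# Raphael
# Improve translation and subtitles sync
# 7d79fa6 unverified
# raw
# history blame
# 9.51 kB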
import logging
import math
import os
import shutil
import tempfile
import time
from datasets import load_dataset
import gradio as gr
import moviepy.editor as mp
import numpy as np
import pysrt
import re
import torch
from transformers import pipeline
import yt_dlp
# NOTE(review): huggingface_hub appears to read this variable when it is first
# imported, and `datasets` / `transformers` are imported above this line, so
# setting it here may be too late to enable hf_transfer downloads - confirm,
# and move it above the third-party imports if that is the intent.
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    force=True,
)
LOG = logging.getLogger(__name__)

# Length, in seconds, of each audio clip fed to the ASR model.
CLIP_SECONDS = 20
# Number of consecutive subtitles every timed segment is split into.
SLICES = 4
# Only the first MAX_CHUNKS clips are transcribed:
# 45 clips x 20 s = 900 s, i.e. at most 15 minutes of audio.
MAX_CHUNKS = 45
# Matches (leading text up to and including its first run of '.', '?' or '!',
# remainder); does not match text containing none of those characters.
SENTENCE_SPLIT = re.compile(r'([^.?!]*[.?!]+)([^.?!].*|$)')
# Keyword arguments for the three Hugging Face pipelines. A 'device' entry is
# added only when CUDA is available; otherwise no device is passed at all.
asr_kwargs = dict(
    task="automatic-speech-recognition",
    model="openai/whisper-medium.en",
)
translator_kwargs = dict(
    task="translation_en_to_fr",
    model="Helsinki-NLP/opus-mt-en-fr",
)
summarizer_kwargs = dict(
    task="summarization",
    model="facebook/bart-large-cnn",
)

if torch.cuda.is_available():
    LOG.info("GPU available")
    asr_kwargs.update(device='cuda:0')
    translator_kwargs.update(device='cuda:0')
    summarizer_kwargs.update(device='cuda:0')

# All three models should fit together on a single T4 GPU.
# The pipelines are built once, at import time.
LOG.info("Fetching ASR model from the Hub if not already there")
asr = pipeline(**asr_kwargs)

LOG.info("Fetching translation model from the Hub if not already there")
translator = pipeline(**translator_kwargs)

LOG.info("Fetching summarization model from the Hub if not already there")
summarizer = pipeline(**summarizer_kwargs)
def demo(url: str, translate: bool):
    """Gradio handler: download a video, transcribe it and build subtitles.

    Args:
        url: Address of the video, downloaded with yt-dlp.
        translate: When true, the subtitles, full transcription and summary
            are translated to French.

    Returns:
        ``(summary, srt_path, [video_path, srt_path], full_transcription)``,
        in the order of the interface's four outputs.
    """
    # TODO: basedir is never deleted, so every request leaves its files on
    # disk. It presumably cannot be removed here because the returned paths
    # still have to be served by Gradio after this function returns.
    basedir = tempfile.mkdtemp()
    LOG.info("Base directory %s", basedir)
    video_path, video = download(url, os.path.join(basedir, 'video.mp4'))
    try:
        audio_clips(video, basedir)
        duration = video.duration
    finally:
        # Release the readers moviepy keeps open for this clip; from here on
        # only the file path and the duration are needed.
        video.close()
    srt_file, full_transcription, summary = process_video(basedir, duration, translate)
    return summary, srt_file, [video_path, srt_file], full_transcription
def download(url, dst):
    """Fetch *url* with yt-dlp into *dst* (mp4 format, overwriting).

    Returns:
        ``(dst, clip)`` where ``clip`` is a moviepy ``VideoFileClip`` opened
        on the downloaded file.
    """
    LOG.info("Downloading provided url %s", url)
    ydl_options = dict(
        skip_download=False,
        overwrites=True,
        format='mp4',
        outtmpl={'default': dst},
    )
    with yt_dlp.YoutubeDL(ydl_options) as downloader:
        downloader.download([url])
    clip = mp.VideoFileClip(dst)
    return dst, clip
def audiodir(basedir):
    """Return the sub-directory of *basedir* that holds the audio clips."""
    clips_folder = 'audio'
    return os.path.join(basedir, clips_folder)
def audio_clips(video: mp.VideoFileClip, basedir: str):
    """Cut the video's audio track into CLIP_SECONDS-long .ogg files.

    The clips are written at 16 kHz to ``audiodir(basedir)``, which is emptied
    first, as ``audio_<idx>.ogg``. ``idx`` is zero-padded so that file names
    sort in chronological order; the transcription step assigns timestamps by
    position, so the order matters.

    Raises:
        ValueError: if the video has no audio track, or the track is too
            short to produce a single clip.
    """
    LOG.info("Building audio clips")
    clips_dir = audiodir(basedir)
    shutil.rmtree(clips_dir, ignore_errors=True)
    os.makedirs(clips_dir, exist_ok=True)
    audio = video.audio
    if audio is None:
        raise ValueError("The downloaded video has no audio track")
    end = audio.duration
    # Clips start on multiples of CLIP_SECONDS below int(end). When int(end)
    # is itself a multiple, the sub-second remainder gets no clip.
    starts = range(0, int(end), CLIP_SECONDS)
    if not starts:
        raise ValueError(f"Audio track is too short to transcribe ({end:.2f} s)")
    # Pad every index to the width of the largest one, so lexicographic order
    # equals chronological order for any number of clips.
    digits = len(str(len(starts) - 1))
    for idx, clip_start in enumerate(starts):
        clip_end = min(clip_start + CLIP_SECONDS, end)
        sub_clip = audio.subclip(t_start=clip_start, t_end=clip_end)
        audio_file = os.path.join(clips_dir, f"audio_{idx:0{digits}d}.ogg")
        sub_clip.write_audiofile(audio_file, fps=16000)
def process_video(basedir: str, duration, translate: bool):
    """Run transcription, translation, subtitle generation and summarisation.

    Returns:
        ``(srt_path, full_text, summary)``. ``full_text`` is the (possibly
        translated) segment texts joined with single spaces.
    """
    segments = transcription(audiodir(basedir), duration)
    subs = translation(segments, translate)
    srt_path = build_srt_clips(subs, basedir)
    # The summary is computed from the untranslated segments.
    summary = summarize(segments, translate)
    full_text = ' '.join(sub['text'].strip() for sub in subs).strip()
    return srt_path, full_text, summary
def _split_batches(items, n_batches):
    """Split *items* into at most *n_batches* contiguous lists of near-equal size.

    Sizes follow ``numpy.array_split``: the first ``len(items) % n_batches``
    batches get one extra item. Empty batches are omitted.
    """
    size, extra = divmod(len(items), n_batches)
    batches = []
    pos = 0
    for k in range(n_batches):
        step = size + (1 if k < extra else 0)
        if step:
            batches.append(list(items[pos:pos + step]))
        pos += step
    return batches


def transcription(audio_dir: str, duration):
    """Transcribe the audio clips in *audio_dir* and attach timestamps.

    Only the first MAX_CHUNKS clips are transcribed. Clip ``i`` is given the
    window ``[i * CLIP_SECONDS, (i + 1) * CLIP_SECONDS)`` in milliseconds,
    with its end clamped to *duration*. The segments are then re-cut on
    sentence boundaries.

    Args:
        audio_dir: Directory holding the clips written by ``audio_clips``.
        duration: Video duration in seconds.

    Returns:
        List of ``{'text', 'start', 'end'}`` dicts, with times in milliseconds.
    """
    LOG.info("Audio transcription")
    # Not exact, nvm, doesn't need to be
    chunks = min(int(duration / CLIP_SECONDS + 1), MAX_CHUNKS)
    LOG.debug("Loading audio clips dataset")
    dataset = load_dataset("audiofolder", data_dir=audio_dir)['train']
    # Slice rows before picking the column, so only the clips that will be
    # transcribed get decoded rather than every clip in the folder.
    samples = dataset[0:chunks]['audio']
    n_batches = 5
    # Whisper's decoder has 448 positions (max_target_positions). Recent
    # transformers releases reject max_new_tokens that exceed it, and 400
    # tokens is still far more than a CLIP_SECONDS-long clip produces.
    max_new_tokens = 400
    start = time.time()
    outputs = []
    batches = _split_batches(samples, n_batches)
    for i, batch in enumerate(batches):
        LOG.info("ASR batch %d / %d, samples %d", i, len(batches), len(batch))
        outputs.extend(asr(batch, max_new_tokens=max_new_tokens))
    duration_ms = int(duration * 1000)
    clip_ms = CLIP_SECONDS * 1000
    transcriptions = []
    for i, out in enumerate(outputs):
        seg_start = i * clip_ms
        # The last clip is usually shorter than CLIP_SECONDS. Clamp its end
        # to the video length so its subtitles are not spread past the end.
        seg_end = max(seg_start, min(seg_start + clip_ms, duration_ms))
        transcriptions.append({'text': out['text'].strip(), 'start': seg_start, 'end': seg_end})
    if transcriptions:
        first = transcriptions[0]
        # Hard-coded 2.5 s delay on the first segment (rationale not visible
        # here). It is never pushed past the end of its own window.
        first['start'] = min(first['start'] + 2500, first['end'])
    # Will improve the translation
    segments = segments_on_sentence_boundaries(transcriptions)
    elapsed = time.time() - start
    LOG.info("Transcription done, elapsed %.2f seconds", elapsed)
    return segments
def segments_on_sentence_boundaries(segments):
    """Re-cut timed segments so that each one tends to end on a sentence boundary.

    Walks *segments* in order; each is a dict with 'text', 'start' and 'end'.
    A segment whose text does not end with '.', '?' or '!' takes the leading
    sentence of the following segment. That text is moved over, and the
    boundary between the two segments shifts by a share of the next segment's
    duration proportional to the number of characters moved. If the next
    segment contains no terminator at all, both segments are merged whole.

    The input dicts are modified in place and reused in the result.

    Returns:
        A new list of the (mutated) segment dicts. Segments with empty text
        are dropped.
    """
    LOG.info("Segmenting along sentence boundaries for better translations")
    new_segments = []
    i = 0
    while i < len(segments):
        s = segments[i]
        text = s['text'].strip()
        # Drop segments with no text, e.g. one whose whole text was moved
        # into its predecessor on the previous iteration.
        if not text:
            i += 1
            continue
        # The last segment is kept as-is: there is nothing left to borrow from.
        if i == len(segments)-1:
            new_segments.append(s)
            break
        next_s = segments[i+1]
        next_text = next_s['text'].strip()
        # Already ends a sentence, or the next segment is empty: keep unchanged.
        if not next_text or (text[-1] in ['.', '?', '!']):
            new_segments.append(s)
            i += 1
            continue
        # Split the next segment into (first sentence, remainder).
        m = SENTENCE_SPLIT.match(next_s['text'].strip())
        if not m:
            # NOTE(review): this branch is reachable whenever the next segment
            # contains no '.', '?' or '!', so the warning text overstates it.
            # Both segments are merged whole and the next one is consumed
            # (i += 2).
            LOG.warning("Bad pattern matching on segment [%s], "
                        "this should not be possible", next_s['text'])
            s['end'] = next_s['end']
            s['text'] = '{} {}'.format(s['text'].strip(), next_s['text'].strip())
            new_segments.append(s)
            i += 2
        else:
            before = m.group(1)
            after = m.group(2)
            # Move the first sentence of next_s into s, and shift the time
            # boundary by the fraction of next_s's characters that moved.
            next_segment_duration = next_s['end'] - next_s['start']
            ratio = len(before) / len(next_text)
            add_time = int(next_segment_duration * ratio)
            s['end'] = s['end'] + add_time
            s['text'] = '{} {}'.format(text, before)
            next_s['start'] = next_s['start'] + add_time
            next_s['text'] = after.strip()
            new_segments.append(s)
            # next_s, which now holds only the remainder, is examined next.
            i += 1
    return new_segments
def translation(transcriptions, translate):
    """Translate segment texts to French when *translate* is true.

    With *translate* false, the very same list object is returned untouched.
    Otherwise a new list of shallow-copied segment dicts is returned, with
    'text' replaced by the stripped translation.
    """
    if not translate:
        return transcriptions
    LOG.info("Performing translation")
    began = time.time()
    outputs = translator([seg['text'] for seg in transcriptions])
    translated = [
        {**seg, 'text': outputs[idx]['translation_text'].strip()}
        for idx, seg in enumerate(transcriptions)
    ]
    LOG.info("Translation done, elapsed %.2f seconds", time.time() - began)
    LOG.info('Translations %s', translated)
    return translated
def summarize(transcriptions, translate):
    """Summarise the whole transcript, optionally translating the summary.

    Args:
        transcriptions: Segment dicts whose 'text' values are joined into one
            document.
        translate: Passed to ``translation`` to translate the summary.

    Returns:
        The summary text, or '' when the transcript is empty.
    """
    LOG.info("Generating video summary")
    whole_text = ' '.join([t['text'].strip() for t in transcriptions])
    if not whole_text.strip():
        # Nothing was transcribed, so there is nothing to summarise.
        return ''
    # bart-large-cnn accepts at most 1024 input tokens, and the summarization
    # pipeline does not truncate by default. A transcript of a few minutes
    # would otherwise make the model fail on out-of-range positions. With
    # truncation, only the beginning of a long transcript is summarised.
    summary = summarizer(whole_text, truncation=True)
    summary = translation([{'text': summary[0]['summary_text']}], translate)[0]
    return summary['text']
def segment_slices(subtitles: list[dict]) -> list[dict]:
    """Cut each timed segment into SLICES consecutive equal-duration slices.

    The words of a segment (split on single spaces) are distributed over
    SLICES parts with ``np.array_split``, so earlier parts get the extra
    words. Part ``i`` covers ``start + i * d / SLICES`` to
    ``start + (i + 1) * d / SLICES``, where ``d = end - start`` in ms.

    Parts that receive no words, for example in segments shorter than SLICES
    words, are skipped instead of producing empty subtitles.

    Args:
        subtitles: Dicts with 'text', 'start' and 'end' (milliseconds).

    Returns:
        List of ``{'text', 'start', 'end'}`` slice dicts.
    """
    LOG.info("Building srt segments slices")
    slices = []
    for sub in subtitles:
        chunks = np.array_split(sub['text'].split(' '), SLICES)
        start = sub['start']
        duration = sub['end'] - start
        for i, chunk in enumerate(chunks):
            text = ' '.join(chunk)
            if not text.strip():
                continue
            slices.append({
                'text': text,
                'start': start + i * duration / SLICES,
                'end': start + (i + 1) * duration / SLICES,
            })
    return slices
def _srt_item(index, start_ms, end_ms, text):
    """Build one numbered SubRip item; times are in milliseconds, rounded."""
    item = pysrt.SubRipItem()
    item.index = index
    item.start = pysrt.SubRipTime(0, 0, 0, int(round(start_ms)))
    item.end = pysrt.SubRipTime(0, 0, 0, int(round(end_ms)))
    item.text = text
    return item


def build_srt_clips(segments, basedir):
    """Write the subtitles for *segments* to ``<basedir>/video.srt``.

    Each segment is first cut with ``segment_slices``. Every slice becomes
    one SubRip item. A slice becomes two items (words and time window halved)
    when its text has ``max_text_len`` characters or more and at least two
    words. Empty slices are skipped, and items are numbered from 1 because
    SRT numbering starts at 1.

    Returns:
        Path of the saved UTF-8 .srt file.
    """
    LOG.info("Generating subtitles")
    segments = segment_slices(segments)
    LOG.info("Building srt clips")
    max_text_len = 45
    subtitles = pysrt.SubRipFile()
    for segment in segments:
        start = segment['start']
        end = segment['end']
        text = segment['text'].strip()
        if not text:
            continue
        words = text.split()
        if len(text) < max_text_len or len(words) < 2:
            pieces = [(start, end, text)]
        else:
            # Just split in two, should be ok in most cases.
            half = len(words) // 2
            middle = (start + end) / 2
            pieces = [
                (start, middle, ' '.join(words[:half])),
                (middle, end, ' '.join(words[half:])),
            ]
        for piece_start, piece_end, piece_text in pieces:
            subtitles.append(_srt_item(len(subtitles) + 1, piece_start, piece_end, piece_text))
    srt_path = os.path.join(basedir, 'video.srt')
    subtitles.save(srt_path, encoding='utf-8')
    LOG.info("Subtitles saved in srt file %s", srt_path)
    return srt_path
# Web UI: a video URL and a translate toggle in; summary, SRT file, video
# with subtitles and full transcription out, in the order demo() returns them.
iface = gr.Interface(
    demo,
    inputs=[
        gr.Text(value="https://youtu.be/tiZFewofSLM", label="English video url"),
        gr.Checkbox(value=True, label='Translate to French'),
    ],
    outputs=[
        gr.Text(label="Video summary"),
        gr.File(label="SRT file"),
        # Receives [video_path, srt_path] from demo(); presumably shown with
        # the SRT as its subtitle track.
        gr.Video(label="Video with subtitles"),
        gr.Text(label="Full transcription"),
    ],
)
# Launch with Gradio's default host/port; to expose it on all interfaces,
# pass e.g. server_name="0.0.0.0", server_port=6443.
iface.launch()