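"""Gradio demo that downloads an English video, transcribes it with
Whisper, optionally translates the subtitles to French, summarizes the
transcription, and renders the video with generated srt subtitles."""
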
import logging
import math
import os
import re
import shutil
import tempfile
import time

from datasets import load_dataset
import gradio as gr
import moviepy.editor as mp
import numpy as np
import pysrt
import torch
from transformers import pipeline
import yt_dlp

# Use hf_transfer for faster model downloads from the Hub.
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    force=True)

LOG = logging.getLogger(__name__)

# The video's audio track is cut into CLIP_SECONDS-long clips for ASR.
CLIP_SECONDS = 20
# Each subtitle segment is later divided into SLICES shorter slices.
SLICES = 4
# Cap on the number of clips sent to the ASR model (45 * 20 s = 15 min).
MAX_CHUNKS = 45
# Splits a string into its first complete sentence and the remainder.
SENTENCE_SPLIT = re.compile(r'([^.?!]*[.?!]+)([^.?!].*|$)')

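# For example, SENTENCE_SPLIT.match("today. Thank you") yields
# group(1) == "today." and group(2) == " Thank you": the completed
# sentence stays with the current segment, the remainder is pushed
# into the next one.
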
asr_kwargs = {
    "task": "automatic-speech-recognition",
    "model": "openai/whisper-medium.en"
}

translator_kwargs = {
    "task": "translation_en_to_fr",
    "model": "Helsinki-NLP/opus-mt-en-fr"
}

summarizer_kwargs = {
    "task": "summarization",
    "model": "facebook/bart-large-cnn"
}

if torch.cuda.is_available():
    LOG.info("GPU available")
    # Run all three pipelines on the first GPU.
    asr_kwargs['device'] = 'cuda:0'
    translator_kwargs['device'] = 'cuda:0'
    summarizer_kwargs['device'] = 'cuda:0'

LOG.info("Fetching ASR model from the Hub if not already there") |
|
asr = pipeline(**asr_kwargs) |
|
|
|
LOG.info("Fetching translation model from the Hub if not already there") |
|
translator = pipeline(**translator_kwargs) |
|
|
|
LOG.info("Fetching summarization model from the Hub if not already there") |
|
summarizer = pipeline(**summarizer_kwargs) |
|
|
|
|
|
def demo(url: str, translate: bool):
    """Download a video, transcribe it, optionally translate it to
    French, and return the summary, the srt file, the (video, srt)
    pair for the subtitled player, and the full transcription."""
    basedir = tempfile.mkdtemp()
    LOG.info("Base directory %s", basedir)
    video_path, video = download(url, os.path.join(basedir, 'video.mp4'))
    audio_clips(video, basedir)
    srt_file, full_transcription, summary = process_video(basedir, video.duration, translate)
    return summary, srt_file, [video_path, srt_file], full_transcription


def download(url, dst):
    """Download the video at `url` to `dst` with yt-dlp and return the
    path along with the opened moviepy clip."""
    LOG.info("Downloading provided url %s", url)

    opts = {
        'skip_download': False,
        'overwrites': True,
        'format': 'mp4',
        'outtmpl': {'default': dst}
    }

    with yt_dlp.YoutubeDL(opts) as dl:
        dl.download([url])

    return dst, mp.VideoFileClip(dst)


def audiodir(basedir):
    return os.path.join(basedir, 'audio')


def audio_clips(video: mp.VideoFileClip, basedir: str):
    """Split the video's audio track into CLIP_SECONDS-long ogg clips."""
    LOG.info("Building audio clips")

    clips_dir = audiodir(basedir)
    shutil.rmtree(clips_dir, ignore_errors=True)
    os.makedirs(clips_dir, exist_ok=True)

    audio = video.audio
    end = audio.duration

    # Zero-pad the clip index so lexicographic order matches
    # chronological order when the clips are loaded back as a dataset.
    digits = int(math.log(end / CLIP_SECONDS, 10)) + 1

    for idx, i in enumerate(range(0, int(end), CLIP_SECONDS)):
        sub_end = min(i + CLIP_SECONDS, end)
        sub_clip = audio.subclip(t_start=i, t_end=sub_end)
        audio_file = os.path.join(clips_dir, f"audio_{idx:0{digits}d}.ogg")
        # 16 kHz matches the sampling rate Whisper expects.
        sub_clip.write_audiofile(audio_file, fps=16000)

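# A 70-second video, for instance, yields audio/audio_0.ogg through
# audio/audio_3.ogg, covering 0-20 s, 20-40 s, 40-60 s and 60-70 s.

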
def process_video(basedir: str, duration, translate: bool):
    """Transcribe, optionally translate, and summarize the audio clips,
    returning the srt path, the full transcription and the summary."""
    audio_dir = audiodir(basedir)
    transcriptions = transcription(audio_dir, duration)
    subs = translation(transcriptions, translate)
    srt_file = build_srt_clips(subs, basedir)
    summary = summarize(transcriptions, translate)
    return srt_file, ' '.join([s['text'].strip() for s in subs]).strip(), summary


def transcription(audio_dir: str, duration):
    """Transcribe the audio clips and attach approximate millisecond
    timings to each resulting segment."""
    LOG.info("Audio transcription")

    chunks = int(duration / CLIP_SECONDS + 1)
    chunks = min(chunks, MAX_CHUNKS)

    LOG.debug("Loading audio clips dataset")
    dataset = load_dataset("audiofolder", data_dir=audio_dir)
    dataset = dataset['train']
    dataset = dataset['audio'][0:chunks]

    start = time.time()
    transcriptions = []
    # Feed the clips to the ASR pipeline in 5 batches.
    for i, d in enumerate(np.array_split(dataset, 5)):
        d = list(d)
        LOG.info("ASR batch %d / 5, samples %d", i + 1, len(d))
        t = asr(d, max_new_tokens=10000)
        transcriptions.extend(t)

    # Each clip covers a fixed CLIP_SECONDS window of the video.
    transcriptions = [
        {
            'text': t['text'].strip(),
            'start': i * CLIP_SECONDS * 1000,
            'end': (i + 1) * CLIP_SECONDS * 1000
        } for i, t in enumerate(transcriptions)
    ]

    if transcriptions:
        # Nudge the first subtitle so it starts 2.5 seconds in.
        transcriptions[0]['start'] += 2500

    segments = segments_on_sentence_boundaries(transcriptions)

    elapsed = time.time() - start
    LOG.info("Transcription done, elapsed %.2f seconds", elapsed)
    return segments


def segments_on_sentence_boundaries(segments):
    """Realign segment boundaries so that each segment ends on a
    sentence boundary, shifting timings proportionally to the amount
    of text moved between neighbours."""
    LOG.info("Segmenting along sentence boundaries for better translations")

    new_segments = []
    i = 0
    while i < len(segments):
        s = segments[i]
        text = s['text'].strip()
        if not text:
            i += 1
            continue

        # The last segment has no successor to borrow from.
        if i == len(segments) - 1:
            new_segments.append(s)
            break

        next_s = segments[i + 1]
        next_text = next_s['text'].strip()
        # Nothing to borrow, or the segment already ends a sentence.
        if not next_text or text[-1] in ['.', '?', '!']:
            new_segments.append(s)
            i += 1
            continue

        m = SENTENCE_SPLIT.match(next_text)
        if not m:
            # The next segment contains no sentence boundary at all:
            # merge it entirely into the current one.
            LOG.warning("No sentence boundary in segment [%s], merging it "
                        "into the previous segment", next_s['text'])
            s['end'] = next_s['end']
            s['text'] = '{} {}'.format(text, next_text)
            new_segments.append(s)
            i += 2
        else:
            # Move the completed sentence into the current segment and
            # shift the boundary proportionally to the text moved.
            before = m.group(1)
            after = m.group(2)
            next_segment_duration = next_s['end'] - next_s['start']
            ratio = len(before) / len(next_text)
            add_time = int(next_segment_duration * ratio)
            s['end'] = s['end'] + add_time
            s['text'] = '{} {}'.format(text, before)
            next_s['start'] = next_s['start'] + add_time
            next_s['text'] = after.strip()
            new_segments.append(s)
            i += 1

    return new_segments

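# Example: given [{'text': 'so we went', 'start': 0, 'end': 20000},
#                 {'text': 'home. Later that day', 'start': 20000, 'end': 40000}],
# 'home.' is 5 of the 20 characters of the second segment, so the first
# segment becomes 'so we went home.' ending at 25000 ms and the second
# becomes 'Later that day' starting at 25000 ms.

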
def translation(transcriptions, translate):
    """Translate segment texts to French when `translate` is set,
    otherwise return the transcriptions untouched."""
    translations_d = []
    if translate:
        LOG.info("Performing translation")
        start = time.time()
        translations = translator([t['text'] for t in transcriptions])
        for i, t in enumerate(transcriptions):
            tsl = t.copy()
            tsl['text'] = translations[i]['translation_text'].strip()
            translations_d.append(tsl)
        elapsed = time.time() - start
        LOG.info("Translation done, elapsed %.2f seconds", elapsed)
        LOG.info('Translations %s', translations_d)
    else:
        translations_d = transcriptions
    return translations_d


def summarize(transcriptions, translate):
    """Summarize the whole transcription, translating the summary if
    requested."""
    LOG.info("Generating video summary")
    whole_text = ' '.join([t['text'].strip() for t in transcriptions])

    # Truncate inputs longer than the model's maximum sequence length
    # instead of erroring out on long transcriptions.
    summary = summarizer(whole_text, truncation=True)

    summary = translation([{'text': summary[0]['summary_text']}], translate)[0]
    return summary['text']


def segment_slices(subtitles: list[dict]):
    """Cut each subtitle into SLICES shorter slices, interpolating
    timings linearly over the words."""
    LOG.info("Building srt segments slices")
    slices = []
    for sub in subtitles:
        chunks = np.array_split(sub['text'].split(' '), SLICES)
        start = sub['start']
        duration = sub['end'] - start
        for i in range(0, SLICES):
            s = {
                'text': ' '.join(chunks[i]),
                # Keep millisecond timings integral for the srt file.
                'start': int(start + i * duration / SLICES),
                'end': int(start + (i + 1) * duration / SLICES)
            }
            slices.append(s)
    return slices

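# For instance a segment spanning 20000-40000 ms with 8 words yields
# 4 slices of 2 words each, covering 20000-25000, 25000-30000,
# 30000-35000 and 35000-40000 ms.

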
def build_srt_clips(segments, basedir):
    """Build the srt file from the subtitle slices, splitting any slice
    whose text is too long to display in two."""
    LOG.info("Generating subtitles")
    segments = segment_slices(segments)

    LOG.info("Building srt clips")
    max_text_len = 45
    subtitles = pysrt.SubRipFile()
    for segment in segments:
        start = segment['start']
        end = segment['end']
        text = segment['text'].strip()
        if len(text) < max_text_len:
            o = pysrt.SubRipItem()
            o.start = pysrt.SubRipTime(0, 0, 0, start)
            o.end = pysrt.SubRipTime(0, 0, 0, end)
            o.text = text
            subtitles.append(o)
        else:
            # Too long to display comfortably: show the first half of
            # the words, then the second half.
            words = text.split()
            o = pysrt.SubRipItem()
            o.text = ' '.join(words[0:len(words) // 2])
            o.start = pysrt.SubRipTime(0, 0, 0, start)
            chkpt = (start + end) // 2
            o.end = pysrt.SubRipTime(0, 0, 0, chkpt)
            subtitles.append(o)
            o = pysrt.SubRipItem()
            o.text = ' '.join(words[len(words) // 2:])
            o.start = pysrt.SubRipTime(0, 0, 0, chkpt)
            o.end = pysrt.SubRipTime(0, 0, 0, end)
            subtitles.append(o)

    srt_path = os.path.join(basedir, 'video.srt')
    subtitles.save(srt_path, encoding='utf-8')
    LOG.info("Subtitles saved in srt file %s", srt_path)
    return srt_path


iface = gr.Interface(
    fn=demo,
    inputs=[
        gr.Text(value="https://youtu.be/tiZFewofSLM", label="English video url"),
        gr.Checkbox(value=True, label='Translate to French')
    ],
    outputs=[
        gr.Text(label="Video summary"),
        gr.File(label="SRT file"),
        gr.Video(label="Video with subtitles"),
        gr.Text(label="Full transcription")
    ])

iface.launch()