"""Gradio demo: download an English video, transcribe it with Whisper, optionally
translate the transcript to French, summarize it, and build an SRT subtitle file."""

import logging
import math
import os
import re
import shutil
import time

import gradio as gr
import moviepy.editor as mp
import numpy as np
import pysrt
import torch
import yt_dlp
from datasets import load_dataset
from transformers import pipeline

os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    force=True)
LOG = logging.getLogger(__name__)

BASEDIR = '/tmp/demo'
os.makedirs(BASEDIR, exist_ok=True)

CLIP_SECONDS = 20
SLICES = 4
# SLICE_DURATION = CLIP_SECONDS / SLICES

# Transcribe at most MAX_CHUNKS clips, i.e. 45 * 20s = 15 minutes of audio
MAX_CHUNKS = 45

SENTENCE_SPLIT = re.compile(r'([^.?!]*[.?!]+)([^.?!].*|$)')

asr_kwargs = {
    "task": "automatic-speech-recognition",
    "model": "openai/whisper-medium.en"
}

translator_kwargs = {
    "task": "translation_en_to_fr",
    "model": "Helsinki-NLP/opus-mt-en-fr"
}

summarizer_kwargs = {
    "task": "summarization",
    "model": "facebook/bart-large-cnn"
}

if torch.cuda.is_available():
    LOG.info("GPU available")
    asr_kwargs['device'] = 'cuda:0'
    translator_kwargs['device'] = 'cuda:0'
    summarizer_kwargs['device'] = 'cuda:0'

# All three models should fit together on a single T4 GPU
LOG.info("Fetching ASR model from the Hub if not already there")
asr = pipeline(**asr_kwargs)
LOG.info("Fetching translation model from the Hub if not already there")
translator = pipeline(**translator_kwargs)
LOG.info("Fetching summarization model from the Hub if not already there")
summarizer = pipeline(**summarizer_kwargs)


def demo(url: str, translate: bool):
    # Free disk space leak
    basedir = BASEDIR
    LOG.info("Base directory %s", basedir)
    video_path, video = download(url, os.path.join(basedir, 'video.mp4'))
    audio_clips(video, basedir)
    srt_file, full_transcription, summary = process_video(
        basedir, video.duration, translate)
    # Return order matches the gr.Interface outputs: summary, SRT file,
    # (video, subtitles) for the video player, full transcription
    return summary, srt_file, [video_path, srt_file], full_transcription


def download(url, dst):
    LOG.info("Downloading provided url %s", url)
    opts = {
        'skip_download': False,
        'overwrites': True,
        'format': 'mp4',
        'outtmpl': {'default': dst}
    }
    with yt_dlp.YoutubeDL(opts) as dl:
        dl.download([url])
    return dst, mp.VideoFileClip(dst)


def audiodir(basedir):
    return os.path.join(basedir, 'audio')


def audio_clips(video: mp.VideoFileClip, basedir: str):
    LOG.info("Building audio clips")
    clips_dir = audiodir(basedir)
    shutil.rmtree(clips_dir, ignore_errors=True)
    os.makedirs(clips_dir, exist_ok=True)
    audio = video.audio
    end = audio.duration
    # Zero-pad the clip index so the "audiofolder" dataset loads the clips in order
    digits = int(math.log(end / CLIP_SECONDS, 10)) + 1
    for idx, i in enumerate(range(0, int(end), CLIP_SECONDS)):
        sub_end = min(i + CLIP_SECONDS, end)
        # print(sub_end)
        sub_clip = audio.subclip(t_start=i, t_end=sub_end)
        audio_file = os.path.join(clips_dir, f"audio_{idx:0{digits}d}.ogg")
        # audio_file = os.path.join(AUDIO_CLIPS, "audio_" + str(idx))
        sub_clip.write_audiofile(audio_file, fps=16000)


def process_video(basedir: str, duration, translate: bool):
    audio_dir = audiodir(basedir)
    transcriptions = transcription(audio_dir, duration)
    subs = translation(transcriptions, translate)
    srt_file = build_srt_clips(subs, basedir)
    summary = summarize(transcriptions, translate)
    return srt_file, ' '.join([s['text'].strip() for s in subs]).strip(), summary


def transcription(audio_dir: str, duration):
    LOG.info("Audio transcription")
    # Approximate clip count, it does not need to be exact
    chunks = int(duration / CLIP_SECONDS + 1)
    chunks = min(chunks, MAX_CHUNKS)
    LOG.debug("Loading audio clips dataset")
    dataset = load_dataset("audiofolder", data_dir=audio_dir)
    dataset = dataset['train']
    dataset = dataset['audio'][0:chunks]
    start = time.time()
    transcriptions = []
    # Run the ASR pipeline over the clips in 5 batches
    for i, d in enumerate(np.array_split(dataset, 5)):
        d = list(d)
        LOG.info("ASR batch %d / 5, samples %d", i, len(d))
        t = asr(d, max_new_tokens=10000)
        transcriptions.extend(t)

    # Attach start/end timestamps (in milliseconds) derived from the fixed clip length
    transcriptions = [
        {
            'text': t['text'].strip(),
            'start': i * CLIP_SECONDS * 1000,
            'end': (i + 1) * CLIP_SECONDS * 1000
        }
        for i, t in enumerate(transcriptions)
    ]
    if transcriptions:
        # Push back the start of the first subtitle by 2.5s
        transcriptions[0]['start'] += 2500

    # Re-segmenting along sentence boundaries will improve the translation
    segments = segments_on_sentence_boundaries(transcriptions)

    elapsed = time.time() - start
    LOG.info("Transcription done, elapsed %.2f seconds", elapsed)
    return segments


def segments_on_sentence_boundaries(segments):
    LOG.info("Segmenting along sentence boundaries for better translations")
    # Merge trailing sentence fragments into the previous segment and shift the
    # boundary timestamp proportionally to the number of characters moved
    new_segments = []
    i = 0
    while i < len(segments):
        s = segments[i]
        text = s['text'].strip()
        if not text:
            i += 1
            continue
        if i == len(segments) - 1:
            new_segments.append(s)
            break
        next_s = segments[i + 1]
        next_text = next_s['text'].strip()
        if not next_text or (text[-1] in ['.', '?', '!']):
            # Current segment already ends a sentence (or nothing follows)
            new_segments.append(s)
            i += 1
            continue
        m = SENTENCE_SPLIT.match(next_s['text'].strip())
        if not m:
            LOG.warning("Bad pattern matching on segment [%s], "
                        "this should not be possible", next_s['text'])
            s['end'] = next_s['end']
            s['text'] = '{} {}'.format(s['text'].strip(), next_s['text'].strip())
            new_segments.append(s)
            i += 2
        else:
            before = m.group(1)
            after = m.group(2)
            next_segment_duration = next_s['end'] - next_s['start']
            ratio = len(before) / len(next_text)
            add_time = int(next_segment_duration * ratio)
            s['end'] = s['end'] + add_time
            s['text'] = '{} {}'.format(text, before)
            next_s['start'] = next_s['start'] + add_time
            next_s['text'] = after.strip()
            new_segments.append(s)
            i += 1
    return new_segments


def translation(transcriptions, translate):
    translations_d = []
    if translate:
        LOG.info("Performing translation")
        start = time.time()
        translations = translator([t['text'] for t in transcriptions])
        for i, t in enumerate(transcriptions):
            tsl = t.copy()
            tsl['text'] = translations[i]['translation_text'].strip()
            translations_d.append(tsl)
        elapsed = time.time() - start
        LOG.info("Translation done, elapsed %.2f seconds", elapsed)
        LOG.info('Translations %s', translations_d)
    else:
        translations_d = transcriptions
    return translations_d


def summarize(transcriptions, translate):
    LOG.info("Generating video summary")
    whole_text = ' '.join([t['text'].strip() for t in transcriptions])
    # word_count = len(whole_text.split())
    summary = summarizer(whole_text)
    # min_length=word_count // 4 + 1,
    # max_length=word_count // 2 + 1)
    summary = translation([{'text': summary[0]['summary_text']}], translate)[0]
    return summary['text']


def segment_slices(subtitles: list[dict]):
    LOG.info("Building srt segments slices")
    slices = []
    for sub in subtitles:
        # Split each segment into SLICES equal-duration sub-segments with
        # roughly equal word counts
        chunks = np.array_split(sub['text'].split(' '), SLICES)
        start = sub['start']
        duration = sub['end'] - start
        for i in range(0, SLICES):
            s = {
                'text': ' '.join(chunks[i]),
                'start': start + i * duration / SLICES,
                'end': start + (i + 1) * duration / SLICES
            }
            slices.append(s)
    return slices


def build_srt_clips(segments, basedir):
    LOG.info("Generating subtitles")
    segments = segment_slices(segments)
    LOG.info("Building srt clips")
    max_text_len = 45
    subtitles = pysrt.SubRipFile()
    for segment in segments:
        start = segment['start']
        end = segment['end']
        text = segment['text']
        text = text.strip()
        if len(text) < max_text_len:
            o = pysrt.SubRipItem()
            o.start = pysrt.SubRipTime(0, 0, 0, start)
            o.end = pysrt.SubRipTime(0, 0, 0, end)
            o.text = text
            subtitles.append(o)
        else:
            # Just split in two, should be ok in most cases
            words = text.split()
            o = pysrt.SubRipItem()
            o.text = ' '.join(words[0:len(words) // 2])
            o.start = pysrt.SubRipTime(0, 0, 0, start)
            chkpt = (start + end) / 2
            o.end = pysrt.SubRipTime(0, 0, 0, chkpt)
            subtitles.append(o)

            o = pysrt.SubRipItem()
            o.text = ' '.join(words[len(words) // 2:])
            o.start = pysrt.SubRipTime(0, 0, 0, chkpt)
            o.end = pysrt.SubRipTime(0, 0, 0, end)
            subtitles.append(o)

    srt_path = os.path.join(basedir, 'video.srt')
    subtitles.save(srt_path, encoding='utf-8')
    LOG.info("Subtitles saved in srt file %s", srt_path)
    return srt_path


iface = gr.Interface(
    fn=demo,
    inputs=[
        gr.Text(value="https://youtu.be/tiZFewofSLM", label="English video url"),
        gr.Checkbox(value=True, label='Translate to French')],
    outputs=[
        gr.Text(label="Video summary"),
        gr.File(label="SRT file"),
        gr.Video(label="Video with subtitles"),
        gr.Text(label="Full transcription")
    ])

# iface.launch(server_name="0.0.0.0", server_port=6443)
iface.launch()
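
# A minimal, commented-out sketch of how the segmentation helpers could be exercised
# on their own, with hypothetical hand-written segments (times in milliseconds).
# It is not part of the app and is only meant as a usage illustration.
#
# sample_segments = [
#     {'text': 'Hello world. This is', 'start': 0, 'end': 20000},
#     {'text': 'a short test clip.', 'start': 20000, 'end': 40000},
# ]
# merged = segments_on_sentence_boundaries(sample_segments)
# print(merged)  # the trailing fragment is pulled into the first segment
# print(build_srt_clips(merged, BASEDIR))  # writes /tmp/demo/video.srt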