import time
from typing import Any, Dict

import pytube
import torch
import whisper
from sentence_transformers import SentenceTransformer
from tqdm import tqdm


class EndpointHandler:
    def __init__(self, path=""):
        WHISPER_MODEL_NAME = "tiny.en"
        SENTENCE_TRANSFORMER_MODEL_NAME = "multi-qa-mpnet-base-dot-v1"

        # Whisper is much faster on a GPU, so use CUDA whenever it is available.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f'whisper will use: {device}')

        # Load the Whisper speech-to-text model and report how long it took.
        t0 = time.time()
        self.whisper_model = whisper.load_model(WHISPER_MODEL_NAME).to(device)
        t1 = time.time()
        print(f'Finished loading whisper_model in {t1 - t0} seconds')

        # Load the sentence-transformer model used to embed transcript segments.
        t0 = time.time()
        self.sentence_transformer_model = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL_NAME)
        t1 = time.time()
        print(f'Finished loading sentence_transformer_model in {t1 - t0} seconds')
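
    # Example request payloads accepted by __call__ below (an illustrative
    # sketch; the URL and id values are placeholders, not real data):
    #   {"video_url": "https://www.youtube.com/watch?v=<VIDEO_ID>",
    #    "encode_transcript": True}
    # or, to embed already-transcribed segments directly:
    #   {"segments": [{"id": "<VIDEO_ID>-t0", "text": "..."}]}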
    def __call__(self, data: Dict[str, Any]) -> Dict:
        """
        Args:
            data (:obj:`dict`):
                includes the URL of the video to transcribe under
                ``video_url``, or pre-computed ``segments`` to encode
        Return:
            A :obj:`dict` with the transcript and/or the encoded segments
        """
        print('data', data)

        video_url = data.pop("video_url", None)
        segments = data.pop("segments", None)
        encoded_segments = {}
        if video_url:
            video_with_transcript = self.transcribe_video(video_url)
            encode_transcript = data.pop("encode_transcript", True)
            if encode_transcript:
                # Merge Whisper's short segments into overlapping windows,
                # then embed each window with the sentence transformer.
                video_with_transcript['transcript']['segments'] = \
                    self.combine_transcripts(video_with_transcript)
                encoded_segments = {
                    "encoded_segments": self.encode_sentences(
                        video_with_transcript['transcript']['segments'])
                }
            return {
                **video_with_transcript,
                **encoded_segments
            }
        elif segments:
            encoded_segments = self.encode_sentences(segments)

        return {
            "encoded_segments": encoded_segments
        }
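
    # For reference, whisper's transcribe() used below returns a dict shaped
    # roughly like this (an abridged sketch of openai/whisper's output):
    #   {"text": "...", "language": "en",
    #    "segments": [{"id": 0, "start": 0.0, "end": 4.2,
    #                  "text": "...", "tokens": [...], ...}]}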
    def transcribe_video(self, video_url):
        decode_options = {
            # Whisper auto-detects language by default; pin it to English
            # and print progress while decoding.
            "language": "en",
            "verbose": True
        }
        yt = pytube.YouTube(video_url)
        video_info = {
            'id': yt.video_id,
            'thumbnail': yt.thumbnail_url,
            'title': yt.title,
            'views': yt.views,
            'length': yt.length,
            'url': f"https://www.youtube.com/watch?v={yt.video_id}"
        }
        # Whisper only needs the audio track, so download an audio-only stream.
        stream = yt.streams.filter(only_audio=True)[0]
        path_to_audio = f"{yt.video_id}.mp3"
        stream.download(filename=path_to_audio)

        t0 = time.time()
        transcript = self.whisper_model.transcribe(path_to_audio, **decode_options)
        t1 = time.time()

        # Token ids are bulky and unused downstream, so drop them from the
        # response payload.
        for segment in transcript['segments']:
            segment.pop('tokens', None)

        print(f'Finished transcription in {t1 - t0} seconds')

        return {"transcript": transcript, 'video': video_info}

    def encode_sentences(self, transcripts, batch_size=64):
        """
        Encoding every segment in a single call (or holding all the
        embeddings in memory at once) would require too much compute, so we
        encode in batches of 64.
        :param transcripts: list of segment dicts, each with a 'text' field
        :param batch_size: number of segments to embed per batch
        :return: the segments, each augmented with its embedding under 'vectors'
        """
        all_batches = []
        for i in tqdm(range(0, len(transcripts), batch_size)):
            # End of the current batch; min() keeps the final, shorter batch
            # within bounds.
            i_end = min(len(transcripts), i + batch_size)

            batch_meta = [{
                **transcripts[x]
            } for x in range(i, i_end)]
            batch_text = [
                row['text'] for row in transcripts[i:i_end]
            ]
            # Embed the whole batch in one forward pass, then convert the
            # numpy matrix to plain lists so the result is JSON-serializable.
            batch_vectors = self.sentence_transformer_model.encode(batch_text).tolist()

            batch_details = [
                {
                    **batch_meta[x],
                    'vectors': batch_vectors[x]
                } for x in range(0, len(batch_meta))
            ]
            all_batches.extend(batch_details)

        return all_batches
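
    # Worked example (hypothetical counts): with 130 segments and
    # batch_size=64, the loop covers slices [0:64], [64:128], and [128:130],
    # so every segment is embedded exactly once.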

    @staticmethod
    def combine_transcripts(video, window=6, stride=3):
        """
        Merge consecutive Whisper segments into larger, overlapping chunks so
        that each embedded passage carries more context.

        :param video: dict holding the 'video' info and its 'transcript'
        :param window: number of segments to combine into one chunk
        :param stride: number of segments to 'stride' over between chunks,
            used to create overlap
        :return: list of combined segment dicts
        """
        new_transcript_segments = []

        video_info = video['video']
        transcript_segments = video['transcript']['segments']
        for i in tqdm(range(0, len(transcript_segments), stride)):
            # End of the current window; min() keeps the last window in bounds.
            i_end = min(len(transcript_segments), i + window)
            text = ' '.join(transcript['text']
                            for transcript in
                            transcript_segments[i:i_end])

            # A chunk starts where its first segment starts and ends where
            # its last segment ends.
            start = int(transcript_segments[i]['start'])
            end = int(transcript_segments[i_end - 1]['end'])
            new_transcript_segments.append({
                **video_info,
                **{
                    'start': start,
                    'end': end,
                    'title': video_info['title'],
                    'text': text,
                    'id': f"{video_info['id']}-t{start}",
                    'url': f"https://youtu.be/{video_info['id']}?t={start}",
                    'video_id': video_info['id'],
                }
            })
        return new_transcript_segments
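
    # Worked example (hypothetical counts): with window=6 and stride=3 the
    # chunks cover segments [0:6], [3:9], [6:12], ..., so consecutive chunks
    # overlap by three segments and a query landing near a chunk boundary
    # still matches a chunk containing its full context.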
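

# A minimal local smoke test, assuming network access; the URL is a
# placeholder, and in production the inference toolkit instantiates
# EndpointHandler and calls it with the deserialized request body instead.
if __name__ == "__main__":
    handler = EndpointHandler()
    response = handler({"video_url": "https://www.youtube.com/watch?v=<VIDEO_ID>"})
    print(response["video"]["title"])
    print(len(response["encoded_segments"]), "encoded segments")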